/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/sort/ccv_nnc_sort_cpu_ref.c
#include "ccv.h"
#include "ccv_internal.h"
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"
#include "nnc/ccv_nnc_internal.h"
#ifdef USE_OPENMP
#include <omp.h>
#endif
#ifdef USE_DISPATCH
#include <dispatch/dispatch.h>
#endif

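// These comparison/swap macros specialize CCV_IMPLEMENT_QSORT_EX for a contiguous
// array with a parallel int index array passed through aux: whenever the sorter
// swaps two values, swap_func recovers their positions via pointer arithmetic
// (&(a) - array) and swaps the matching entries of aux as well, so after sorting,
// aux[i] holds the original position of the i-th sorted element.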
#define less_than(a, b, aux) ((a) < (b))
#define greater_than(a, b, aux) ((a) > (b))
#define swap_func(a, b, array, aux, t) do { \
	(t) = (a); \
	(a) = (b); \
	(b) = (t); \
	int _t = aux[&(a) - array]; \
	aux[&(a) - array] = aux[&(b) - array]; \
	aux[&(b) - array] = _t; \
} while (0)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_less_than_f32, float, less_than, swap_func, int*)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_less_than_i32, int, less_than, swap_func, int*)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_greater_than_f32, float, greater_than, swap_func, int*)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_greater_than_i32, int, greater_than, swap_func, int*)
#undef less_than
#undef greater_than
#undef swap_func
typedef struct {
	float* array;
	int stride;
	int* idx;
} ccv_nnc_sort_aux_f32_t;
typedef struct {
	int* array;
	int stride;
	int* idx;
} ccv_nnc_sort_aux_i32_t;
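// Strided variants for sorting along a non-innermost axis. The generic qsort still
// walks a logically contiguous range of length dim, but consecutive elements of one
// sort run actually sit aux.stride apart in memory; the comparisons and swaps
// therefore remap each logical position (&(a) - aux.array) to its physical offset
// by multiplying with aux.stride, updating the index array aux.idx in lockstep.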
#define less_than(a, b, aux) (aux.array[(&(a) - aux.array) * aux.stride] < aux.array[(&(b) - aux.array) * aux.stride])
#define greater_than(a, b, aux) (aux.array[(&(a) - aux.array) * aux.stride] > aux.array[(&(b) - aux.array) * aux.stride])
#define swap_func(a, b, array, aux, t) do { \
	(t) = aux.array[(&(a) - array) * aux.stride]; \
	aux.array[(&(a) - array) * aux.stride] = aux.array[(&(b) - array) * aux.stride]; \
	aux.array[(&(b) - array) * aux.stride] = (t); \
	int _t = aux.idx[(&(a) - array) * aux.stride]; \
	aux.idx[(&(a) - array) * aux.stride] = aux.idx[(&(b) - array) * aux.stride]; \
	aux.idx[(&(b) - array) * aux.stride] = _t; \
} while (0)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_with_stride_less_than_f32, float, less_than, swap_func, ccv_nnc_sort_aux_f32_t)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_with_stride_less_than_i32, int, less_than, swap_func, ccv_nnc_sort_aux_i32_t)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_with_stride_greater_than_f32, float, greater_than, swap_func, ccv_nnc_sort_aux_f32_t)
static CCV_IMPLEMENT_QSORT_EX(_ccv_nnc_sort_with_stride_greater_than_i32, int, greater_than, swap_func, ccv_nnc_sort_aux_i32_t)
#undef less_than
#undef greater_than
#undef swap_func

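// Forward sort: writes the sorted values to outputs[0] and the permutation of the
// original positions (32-bit ints) to outputs[1]. A 1-D input takes the fast
// contiguous path below; higher-rank inputs are decomposed into independent runs
// along cmd.info.sort.along_axis, each sorted through the strided qsort variants.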
static int _ccv_nnc_sort_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 1);
	assert(output_size == 2);
	const ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
	const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
	ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
	assert(ccv_nnc_tensor_nd(b->info.dim) == a_nd);
	ccv_nnc_tensor_view_t* const indices = (ccv_nnc_tensor_view_t*)outputs[1];
	assert(ccv_nnc_tensor_nd(indices->info.dim) == a_nd);
	assert(indices->info.datatype == CCV_32S);
	assert(CCV_IS_TENSOR_CONTIGUOUS(a));
	assert(CCV_IS_TENSOR_CONTIGUOUS(b));
	assert(CCV_IS_TENSOR_CONTIGUOUS(indices));
	const int count = ccv_nnc_tensor_count(a->info);
	int i;
	if (a_nd == 1)
	{
		// This is the fast path: we just do a regular sort and extract the index.
		assert(ccv_nnc_tensor_count(b->info) == count);
		assert(ccv_nnc_tensor_count(indices->info) == count);
		for (i = 0; i < count; i++)
			indices->data.i32[i] = i;
		if (a->info.datatype == CCV_32F)
		{
			memcpy(b->data.f32, a->data.f32, sizeof(float) * count);
			if (cmd.info.sort.descending)
				_ccv_nnc_sort_greater_than_f32(b->data.f32, count, indices->data.i32);
			else
				_ccv_nnc_sort_less_than_f32(b->data.f32, count, indices->data.i32);
		} else {
			assert(a->info.datatype == CCV_32S);
			memcpy(b->data.i32, a->data.i32, sizeof(int) * count);
			if (cmd.info.sort.descending)
				_ccv_nnc_sort_greater_than_i32(b->data.i32, count, indices->data.i32);
			else
				_ccv_nnc_sort_less_than_i32(b->data.i32, count, indices->data.i32);
		}
	} else {
		assert(ccv_nnc_tensor_count(b->info) == count);
		assert(ccv_nnc_tensor_count(indices->info) == count);
		if (a->info.datatype == CCV_32F)
			memcpy(b->data.f32, a->data.f32, sizeof(float) * count);
		else
			memcpy(b->data.i32, a->data.i32, sizeof(int) * count);
		// sort_runs is the product of the dimensions before the sort axis (how many
		// independent runs there are); sort_stride is the product of the dimensions
		// after it (how far apart consecutive elements of one run sit in memory).
		int sort_runs = 1;
		int sort_stride = 1;
		for (i = 0; i < a_nd; i++)
		{
			if (i < cmd.info.sort.along_axis)
				sort_runs *= a->info.dim[i];
			else if (i > cmd.info.sort.along_axis)
				sort_stride *= a->info.dim[i];
		}
		const int skip_stride = sort_stride * a->info.dim[cmd.info.sort.along_axis];
		int j, k;
		const int dim = a->info.dim[cmd.info.sort.along_axis];
		if (a->info.datatype == CCV_32F)
		{
			ccv_nnc_sort_aux_f32_t aux = {
				.array = b->data.f32,
				.stride = sort_stride,
				.idx = indices->data.i32,
			};
			if (cmd.info.sort.descending)
				for (i = 0; i < sort_runs; i++)
					for (j = 0; j < sort_stride; j++)
					{
						aux.array = b->data.f32 + skip_stride * i + j;
						aux.idx = indices->data.i32 + skip_stride * i + j;
						for (k = 0; k < dim; k++)
							indices->data.i32[i * skip_stride + k * sort_stride + j] = k;
						_ccv_nnc_sort_with_stride_greater_than_f32(aux.array, dim, aux);
					}
			else
				for (i = 0; i < sort_runs; i++)
					for (j = 0; j < sort_stride; j++)
					{
						aux.array = b->data.f32 + skip_stride * i + j;
						aux.idx = indices->data.i32 + skip_stride * i + j;
						for (k = 0; k < dim; k++)
							indices->data.i32[i * skip_stride + k * sort_stride + j] = k;
						_ccv_nnc_sort_with_stride_less_than_f32(aux.array, dim, aux);
					}
		} else {
			assert(a->info.datatype == CCV_32S);
			ccv_nnc_sort_aux_i32_t aux = {
				.array = b->data.i32,
				.stride = sort_stride,
				.idx = indices->data.i32,
			};
			if (cmd.info.sort.descending)
				for (i = 0; i < sort_runs; i++)
					for (j = 0; j < sort_stride; j++)
					{
						aux.array = b->data.i32 + skip_stride * i + j;
						aux.idx = indices->data.i32 + skip_stride * i + j;
						for (k = 0; k < dim; k++)
							indices->data.i32[i * skip_stride + k * sort_stride + j] = k;
						_ccv_nnc_sort_with_stride_greater_than_i32(aux.array, dim, aux);
					}
			else
				for (i = 0; i < sort_runs; i++)
					for (j = 0; j < sort_stride; j++)
					{
						aux.array = b->data.i32 + skip_stride * i + j;
						aux.idx = indices->data.i32 + skip_stride * i + j;
						for (k = 0; k < dim; k++)
							indices->data.i32[i * skip_stride + k * sort_stride + j] = k;
						_ccv_nnc_sort_with_stride_less_than_i32(aux.array, dim, aux);
					}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}

static int _ccv_nnc_sort_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	// No backward pass is implemented for sort in the CPU reference backend.
	return CCV_NNC_EXEC_INVALID;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_SORT_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
	registry->tensor_datatypes = CCV_32F | CCV_32S;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_sort_forw;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_SORT_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
	registry->tensor_datatypes = CCV_32F | CCV_32S;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_sort_back;
}
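
/* A minimal usage sketch, not part of this file: in ccv_nnc a registered command is
 * normally driven through ccv_nnc_cmd_exec(). The CMD_SORT_FORWARD(along_axis,
 * descending) parameter order is an assumption here; consult the generated cmd easy
 * headers for the authoritative macro.
 *
 *   ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
 *   ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
 *   ccv_nnc_tensor_t* const indices = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32S, 10), 0);
 *   // Fill a->data.f32[0..9], then sort ascending along axis 0:
 *   ccv_nnc_cmd_exec(CMD_SORT_FORWARD(0, 0), ccv_nnc_no_hint, 0,
 *       TENSOR_LIST(a), TENSOR_LIST(b, indices), 0);
 *   // b holds the sorted values; indices->data.i32[i] is the original position of
 *   // b->data.f32[i].
 */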