lib/nnc/cmd/partition/ccv_nnc_partition_cpu_ref.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #ifdef USE_OPENMP |
7 | | #include <omp.h> |
8 | | #endif |
9 | | #ifdef USE_DISPATCH |
10 | | #include <dispatch/dispatch.h> |
11 | | #endif |
12 | | |
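| | // CPU reference kernel for CCV_NNC_PARTITION_FORWARD: along the chosen axis it selects the first kth values of a full sort (largest first when descending, smallest first otherwise), writing the values to outputs[0] and their original indices to outputs[1]. |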
13 | | static int _ccv_nnc_partition_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
14 | 39 | { |
15 | 39 | assert(input_size == 1); |
16 | 39 | assert(output_size == 2); |
17 | 39 | const ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0]; |
18 | 39 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
19 | 39 | ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0]; |
20 | 39 | assert(ccv_nnc_tensor_nd(b->info.dim) == a_nd); |
21 | 39 | ccv_nnc_tensor_view_t* const indices = (ccv_nnc_tensor_view_t*)outputs[1]; |
22 | 39 | assert(ccv_nnc_tensor_nd(indices->info.dim) == a_nd); |
23 | 39 | assert(indices->info.datatype == CCV_32S); |
24 | 39 | assert(CCV_IS_TENSOR_CONTIGUOUS(a)); |
25 | 39 | assert(CCV_IS_TENSOR_CONTIGUOUS(b)); |
26 | 39 | assert(CCV_IS_TENSOR_CONTIGUOUS(indices)); |
27 | 39 | assert(a->info.datatype == b->info.datatype); |
28 | 39 | const int count = ccv_nnc_tensor_count(a->info); |
29 | 39 | if (a_nd == 1) |
30 | 6 | { |
31 | 6 | int i, j; |
32 | 6 | void* workmem = ccv_nnc_stream_context_get_workspace(stream_context, ((a->info.datatype == CCV_32F) ? sizeof(float) : sizeof(int)) * count + sizeof(int) * count, CCV_TENSOR_CPU_MEMORY); |
33 | | // This is the fast path: a partial selection sort over the whole vector extracts the kth values and their original indices. |
34 | 6 | assert(ccv_nnc_tensor_count(b->info) == cmd.info.partition.kth); |
35 | 6 | assert(ccv_nnc_tensor_count(indices->info) == cmd.info.partition.kth); |
36 | 6 | if (a->info.datatype == CCV_32F) |
37 | 4 | { |
38 | 4 | memcpy(workmem, a->data.f32, sizeof(float) * count); |
39 | 4 | float* const a_ptr = (float*)workmem; |
40 | 4 | int* const idx_ptr = (int*)((char*)workmem + sizeof(float) * count); |
41 | 20.0k | for (i = 0; i < count; i++) |
42 | 20.0k | idx_ptr[i] = i; |
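| | // Partial selection sort: for each of the first kth slots, scan the unsorted tail for the extreme value (max when descending, min otherwise), emit it together with its original index, then move the displaced front element into the vacated slot. |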
43 | 4 | if (cmd.info.partition.descending) |
44 | 2 | { |
45 | 7 | for (i = 0; i < cmd.info.partition.kth; i++) |
46 | 5 | { |
47 | 5 | float k = a_ptr[i]; |
48 | 5 | int v = i; |
49 | 10.0k | for (j = i + 1; j < count; j++) |
50 | 10.0k | if (a_ptr[j] > k) |
51 | 12 | k = a_ptr[j], v = j; |
52 | 5 | b->data.f32[i] = k; |
53 | 5 | indices->data.i32[i] = idx_ptr[v]; |
54 | 5 | if (i != v) |
55 | 4 | a_ptr[v] = a_ptr[i], idx_ptr[v] = idx_ptr[i]; |
56 | 5 | } |
57 | 2 | } else { |
58 | 7 | for (i = 0; i < cmd.info.partition.kth; i++) |
59 | 5 | { |
60 | 5 | float k = a_ptr[i]; |
61 | 5 | int v = i; |
62 | 10.0k | for (j = i + 1; j < count; j++) |
63 | 10.0k | if (a_ptr[j] < k) |
64 | 11 | k = a_ptr[j], v = j; |
65 | 5 | b->data.f32[i] = k; |
66 | 5 | indices->data.i32[i] = idx_ptr[v]; |
67 | 5 | if (i != v) |
68 | 2 | a_ptr[v] = a_ptr[i], idx_ptr[v] = idx_ptr[i]; |
69 | 5 | } |
70 | 2 | } |
71 | 4 | } else { |
72 | 2 | assert(a->info.datatype == CCV_32S); |
73 | 2 | memcpy(workmem, a->data.i32, sizeof(int) * count); |
74 | 2 | int* const a_ptr = (int*)workmem; |
75 | 2 | int* const idx_ptr = (int*)((char*)workmem + sizeof(int) * count); |
76 | 12 | for (i = 0; i < count; i++) |
77 | 10 | idx_ptr[i] = i; |
78 | 2 | if (cmd.info.partition.descending) |
79 | 1 | { |
80 | 5 | for (i = 0; i < cmd.info.partition.kth; i++) |
81 | 4 | { |
82 | 4 | int k = a_ptr[i]; |
83 | 4 | int v = i; |
84 | 14 | for (j = i + 1; j < count; j++) |
85 | 10 | if (a_ptr[j] > k) |
86 | 5 | k = a_ptr[j], v = j; |
87 | 4 | b->data.i32[i] = k; |
88 | 4 | indices->data.i32[i] = idx_ptr[v]; |
89 | 4 | if (i != v) |
90 | 3 | a_ptr[v] = a_ptr[i], idx_ptr[v] = idx_ptr[i]; |
91 | 4 | } |
92 | 1 | } else { |
93 | 5 | for (i = 0; i < cmd.info.partition.kth; i++) |
94 | 4 | { |
95 | 4 | int k = a_ptr[i]; |
96 | 4 | int v = i; |
97 | 14 | for (j = i + 1; j < count; j++) |
98 | 10 | if (a_ptr[j] < k) |
99 | 1 | k = a_ptr[j], v = j; |
100 | 4 | b->data.i32[i] = k; |
101 | 4 | indices->data.i32[i] = idx_ptr[v]; |
102 | 4 | if (i != v) |
103 | 1 | a_ptr[v] = a_ptr[i], idx_ptr[v] = idx_ptr[i]; |
104 | 4 | } |
105 | 1 | } |
106 | 2 | } |
107 | 33 | } else { |
108 | 33 | int i, j, k, f; |
109 | 33 | int sort_runs = 1; |
110 | 33 | int sort_stride = 1; |
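| | // General path: view the tensor as sort_runs x dim x sort_stride, where dim is the extent of the partition axis; axes before the partition axis fold into sort_runs, axes after it into sort_stride. |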
111 | 111 | for (i = 0; i < a_nd; i++) |
112 | 78 | { |
113 | 78 | if (i < cmd.info.partition.along_axis) // Skip this. |
114 | 29 | sort_runs *= a->info.dim[i]; |
115 | 49 | else if (i > cmd.info.partition.along_axis) |
116 | 16 | sort_stride *= a->info.dim[i]; |
117 | 78 | } |
118 | 33 | const int skip_stride = sort_stride * a->info.dim[cmd.info.partition.along_axis]; |
119 | 33 | const int dim = a->info.dim[cmd.info.partition.along_axis]; |
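| | // skip_stride is the distance between consecutive runs in the flat input; within a run, the dim elements of one slice sit sort_stride apart. |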
120 | 33 | void* workmem = ccv_nnc_stream_context_get_workspace(stream_context, ((a->info.datatype == CCV_32F) ? sizeof(float) : sizeof(int)) * count + sizeof(int) * dim, CCV_TENSOR_CPU_MEMORY); |
121 | 33 | if (a->info.datatype == CCV_32F) |
122 | 20 | { |
123 | 20 | memcpy(workmem, a->data.f32, sizeof(float) * count); |
124 | 20 | float* const a_ptr = (float*)workmem; |
125 | 20 | int* const idx_ptr = (int*)((char*)workmem + sizeof(float) * count); |
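| | // Each (i, j) pair below addresses one 1-D slice of length dim with element stride sort_stride and applies the same partial selection sort as the 1-D fast path. |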
126 | 20 | if (cmd.info.partition.descending) |
127 | 814 | for (i = 0; i < sort_runs; i++) |
128 | 5.20k | for (j = 0; j < sort_stride; j++) |
129 | 4.40k | { |
130 | 4.40k | float* const a_ptr_0 = a_ptr + skip_stride * i + j; |
131 | 61.4k | for (k = 0; k < dim; k++) |
132 | 57.0k | idx_ptr[k] = k; |
133 | 4.40k | float* const b_ptr = b->data.f32 + sort_stride * cmd.info.partition.kth * i + j; |
134 | 4.40k | int* const indices_ptr = indices->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
135 | 11.0k | for (k = 0; k < cmd.info.partition.kth; k++) |
136 | 6.60k | { |
137 | 6.60k | float key = a_ptr_0[k * sort_stride]; |
138 | 6.60k | int val = k; |
139 | 78.8k | for (f = k + 1; f < dim; f++) |
140 | 72.2k | if (a_ptr_0[f * sort_stride] > key) |
141 | 9.88k | key = a_ptr_0[f * sort_stride], val = f; |
142 | 6.60k | b_ptr[k * sort_stride] = key; |
143 | 6.60k | indices_ptr[k * sort_stride] = idx_ptr[val]; |
144 | 6.60k | if (k != val) |
145 | 5.13k | a_ptr_0[val * sort_stride] = a_ptr_0[k * sort_stride], idx_ptr[val] = idx_ptr[k]; |
146 | 6.60k | } |
147 | 4.40k | } |
148 | 10 | else |
149 | 814 | for (i = 0; i < sort_runs; i++) |
150 | 5.20k | for (j = 0; j < sort_stride; j++) |
151 | 4.40k | { |
152 | 4.40k | float* const a_ptr_0 = a_ptr + skip_stride * i + j; |
153 | 84.4k | for (k = 0; k < dim; k++) |
154 | 80.0k | idx_ptr[k] = k; |
155 | 4.40k | float* const b_ptr = b->data.f32 + sort_stride * cmd.info.partition.kth * i + j; |
156 | 4.40k | int* const indices_ptr = indices->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
157 | 11.0k | for (k = 0; k < cmd.info.partition.kth; k++) |
158 | 6.60k | { |
159 | 6.60k | float key = a_ptr_0[k * sort_stride]; |
160 | 6.60k | int val = k; |
161 | 117k | for (f = k + 1; f < dim; f++) |
162 | 111k | if (a_ptr_0[f * sort_stride] < key) |
163 | 14.2k | key = a_ptr_0[f * sort_stride], val = f; |
164 | 6.60k | b_ptr[k * sort_stride] = key; |
165 | 6.60k | indices_ptr[k * sort_stride] = idx_ptr[val]; |
166 | 6.60k | if (k != val) |
167 | 5.99k | a_ptr_0[val * sort_stride] = a_ptr_0[k * sort_stride], idx_ptr[val] = idx_ptr[k]; |
168 | 6.60k | } |
169 | 4.40k | } |
170 | 20 | } else { |
171 | 13 | assert(a->info.datatype == CCV_32S); |
172 | 13 | memcpy(workmem, a->data.i32, sizeof(int) * count); |
173 | 13 | int* const a_ptr = (int*)workmem; |
174 | 13 | int* const idx_ptr = (int*)((char*)workmem + sizeof(int) * count); |
175 | 13 | if (cmd.info.partition.descending) |
176 | 414 | for (i = 0; i < sort_runs; i++) |
177 | 2.61k | for (j = 0; j < sort_stride; j++) |
178 | 2.20k | { |
179 | 2.20k | int* const a_ptr_0 = a_ptr + skip_stride * i + j; |
180 | 42.2k | for (k = 0; k < dim; k++) |
181 | 40.0k | idx_ptr[k] = k; |
182 | 2.20k | int* const b_ptr = b->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
183 | 2.20k | int* const indices_ptr = indices->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
184 | 5.51k | for (k = 0; k < cmd.info.partition.kth; k++) |
185 | 3.31k | { |
186 | 3.31k | int key = a_ptr_0[k * sort_stride]; |
187 | 3.31k | int val = k; |
188 | 58.9k | for (f = k + 1; f < dim; f++) |
189 | 55.6k | if (a_ptr_0[f * sort_stride] > key) |
190 | 7.15k | key = a_ptr_0[f * sort_stride], val = f; |
191 | 3.31k | b_ptr[k * sort_stride] = key; |
192 | 3.31k | indices_ptr[k * sort_stride] = idx_ptr[val]; |
193 | 3.31k | if (k != val) |
194 | 3.02k | a_ptr_0[val * sort_stride] = a_ptr_0[k * sort_stride], idx_ptr[val] = idx_ptr[k]; |
195 | 3.31k | } |
196 | 2.20k | } |
197 | 6 | else |
198 | 410 | for (i = 0; i < sort_runs; i++) |
199 | 2.60k | for (j = 0; j < sort_stride; j++) |
200 | 2.20k | { |
201 | 2.20k | int* const a_ptr_0 = a_ptr + skip_stride * i + j; |
202 | 42.2k | for (k = 0; k < dim; k++) |
203 | 40.0k | idx_ptr[k] = k; |
204 | 2.20k | int* const b_ptr = b->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
205 | 2.20k | int* const indices_ptr = indices->data.i32 + sort_stride * cmd.info.partition.kth * i + j; |
206 | 5.51k | for (k = 0; k < cmd.info.partition.kth; k++) |
207 | 3.30k | { |
208 | 3.30k | int key = a_ptr_0[k * sort_stride]; |
209 | 3.30k | int val = k; |
210 | 58.9k | for (f = k + 1; f < dim; f++) |
211 | 55.6k | if (a_ptr_0[f * sort_stride] < key) |
212 | 6.77k | key = a_ptr_0[f * sort_stride], val = f; |
213 | 3.30k | b_ptr[k * sort_stride] = key; |
214 | 3.30k | indices_ptr[k * sort_stride] = idx_ptr[val]; |
215 | 3.30k | if (k != val) |
216 | 2.98k | a_ptr_0[val * sort_stride] = a_ptr_0[k * sort_stride], idx_ptr[val] = idx_ptr[k]; |
217 | 3.30k | } |
218 | 2.20k | } |
219 | 13 | } |
220 | 33 | } |
221 | 39 | return CCV_NNC_EXEC_SUCCESS; |
222 | 39 | } |
223 | | |
224 | | static int _ccv_nnc_partition_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
225 | 0 | { |
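| | // Partition has no gradient defined in this reference backend, so the backward pass reports invalid. |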
226 | 0 | return CCV_NNC_EXEC_INVALID; |
227 | 0 | } |
228 | | |
229 | | REGISTER_COMMAND_BACKEND(CCV_NNC_PARTITION_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
230 | 1 | { |
231 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW; |
232 | 1 | registry->tensor_datatypes = CCV_32F | CCV_32S; |
233 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
234 | 1 | registry->algorithms = 1; |
235 | 1 | registry->exec = _ccv_nnc_partition_forw; |
236 | 1 | } |
237 | | |
238 | | REGISTER_COMMAND_BACKEND(CCV_NNC_PARTITION_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
239 | 1 | { |
240 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW; |
241 | 1 | registry->tensor_datatypes = CCV_32F | CCV_32S; |
242 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
243 | 1 | registry->algorithms = 1; |
244 | 1 | registry->exec = _ccv_nnc_partition_back; |
245 | 1 | } |
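
For reference, below is a minimal sketch of driving this backend through the public nnc API on the 1-D fast path. The parameter fields (kth, along_axis, descending) are exactly the ones the kernel above reads off cmd.info.partition; the ccv_nnc_init() call and the designated-initializer style are assumptions borrowed from how the nnc test suite is typically set up, not something this file shows.

#include <stdio.h>
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"

int main(void)
{
	ccv_nnc_init(); // Assumed: registers command backends, as the nnc tests do in their setup.
	// 1-D float input of 10 elements: 0, 7, 4, 1, 8, 5, 2, 9, 6, 3.
	ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	int i;
	for (i = 0; i < 10; i++)
		a->data.f32[i] = (float)(i * 7 % 10);
	ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
	ccv_nnc_tensor_t* const indices = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32S, 3), 0);
	// Fill in the same fields _ccv_nnc_partition_forw reads.
	const ccv_nnc_cmd_param_t params = {
		.partition = {
			.kth = 3,
			.along_axis = 0,
			.descending = 1,
		},
	};
	ccv_nnc_cmd_exec(ccv_nnc_cmd(CCV_NNC_PARTITION_FORWARD, 0, params, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(a), TENSOR_LIST(b, indices), 0);
	// Expect b = {9, 8, 7} and indices = {7, 4, 1}: the three largest values and their source positions.
	for (i = 0; i < 3; i++)
		printf("%f @ %d\n", b->data.f32[i], indices->data.i32[i]);
	ccv_nnc_tensor_free(a);
	ccv_nnc_tensor_free(b);
	ccv_nnc_tensor_free(indices);
	return 0;
}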