/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_micro.c
Line | Count | Source (jump to first uncovered line) |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_micro.h" |
6 | | #include "3rdparty/khash/khash.h" |
7 | | |
8 | | // MARK - Level-1 API |
9 | | |
10 | | KHASH_MAP_INIT_STR(ccv_nnc_micro_bind_scalar, uint32_t) |
11 | | |
12 | | static uint32_t _scalars_lookup(const void* const context, const char* const name) |
13 | 10 | { |
14 | 10 | const khash_t(ccv_nnc_micro_bind_scalar)* const bind_scalars = (const khash_t(ccv_nnc_micro_bind_scalar)*)context; |
15 | 10 | khiter_t k = kh_get(ccv_nnc_micro_bind_scalar, bind_scalars, name); |
16 | 10 | assert(k != kh_end(bind_scalars)); |
17 | 10 | return kh_val(bind_scalars, k); |
18 | 10 | } |
19 | | |
20 | | KHASH_SET_INIT_INT64(ccv_nnc_ids) |
21 | | |
22 | | CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*) ccv_nnc_micro_combine_new(const ccv_nnc_micro_io_t* const inputs, const int input_size, const char* const* const parameters, const int parameter_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_nnc_micro_io_t* const ingrads, const int ingrad_size, const ccv_nnc_micro_io_t* const outgrads, const int outgrad_size) |
23 | 3 | { |
24 | 3 | assert(output_size > 0); |
25 | 3 | assert(input_size > 0); |
26 | 3 | int i, j, k; |
27 | | // First, do reverse topological sort (from output and then reverse the order). |
28 | | // We can do this simple thing because there is no overlaps of the outputs, thus, no cases where |
29 | | // output[0] is the input for output[1]. Otherwise we need to detect this, see ccv_cnnp_model_new |
30 | | // for more details on why. |
31 | 3 | for (i = 0; i < output_size - 1; i++0 ) |
32 | 0 | for (j = i + 1; j < output_size; j++) |
33 | 0 | { assert(outputs[i] != outputs[j]); } |
34 | 3 | uint64_t input_bitmask[((input_size - 1) >> 6) + 1]; |
35 | 3 | memset(input_bitmask, 0, sizeof(uint64_t) * (((input_size - 1) >> 6) + 1)); |
36 | 3 | ccv_array_t* const reverse_top = ccv_array_new(sizeof(ccv_nnc_micro_io_t), output_size + input_size, 0); |
37 | 3 | ccv_array_resize(reverse_top, output_size); |
38 | 3 | memcpy(ccv_array_get(reverse_top, 0), outputs, sizeof(ccv_nnc_micro_io_t) * output_size); |
39 | 3 | khash_t(ccv_nnc_ids)* const ids = kh_init(ccv_nnc_ids); |
40 | 15 | for (i = 0; i < reverse_top->rnum; i++12 ) |
41 | 12 | { |
42 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i); |
43 | 37 | for (j = 0; j < output->input_size; j++25 ) |
44 | 25 | if (!CCV_NNC_IS_MICRO_IO_INPUT(output->inputs[j])) |
45 | 9 | { |
46 | 9 | int ret; |
47 | 9 | kh_put(ccv_nnc_ids, ids, (int64_t)(intptr_t)output->inputs[j], &ret); |
48 | 9 | if (ret != 0) |
49 | 9 | ccv_array_push(reverse_top, &output->inputs[j]); |
50 | 16 | } else { |
51 | | // This is an input, it must be represented in inputs, try to find it. |
52 | 23 | for (k = 0; k < input_size; k++7 ) |
53 | 23 | if (inputs[k] == output->inputs[j]) |
54 | 16 | break; |
55 | 16 | assert(k < input_size); // Cannot find the inputs, error! |
56 | 16 | input_bitmask[k >> 6] |= ((uint64_t)1 << (k & 63)); |
57 | 16 | } |
58 | 12 | } |
59 | 3 | kh_destroy(ccv_nnc_ids, ids); |
60 | 9 | for (i = 0; i < input_size; i++6 ) |
61 | 6 | { assert((input_bitmask[i >> 6] & ((uint64_t)1 << (i & 63)))); } // Assuming they all match. |
62 | | // Second, binding parameters (bounded scalars). |
63 | 3 | khash_t(ccv_nnc_micro_bind_scalar)* const bind_scalars = kh_init(ccv_nnc_micro_bind_scalar); |
64 | 6 | for (i = 0; i < parameter_size; i++3 ) |
65 | 3 | { |
66 | 3 | int ret; |
67 | 3 | khiter_t k = kh_put(ccv_nnc_micro_bind_scalar, bind_scalars, parameters[i], &ret); |
68 | 3 | assert(ret != 0); |
69 | 3 | kh_val(bind_scalars, k) = i; |
70 | 3 | } |
71 | 15 | for (i = 0; 3 i < reverse_top->rnum; i++12 ) |
72 | 12 | { |
73 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i); |
74 | 12 | ccv_nnc_micro_bind_scalars(output, _scalars_lookup, bind_scalars); |
75 | 12 | } |
76 | 3 | kh_destroy(ccv_nnc_micro_bind_scalar, bind_scalars); |
77 | 3 | const int var_count = reverse_top->rnum + input_size; |
78 | | // Applying numbering for the inputs. Note that our variables are numbered in reverse topological order. |
79 | 9 | for (i = 0; i < input_size; i++6 ) |
80 | 6 | ccv_nnc_micro_numbering(inputs[i], i, var_count); |
81 | 3 | ccv_array_t* const equal_assertions = ccv_array_new(sizeof(ccv_nnc_micro_id_equal_assertion_t), 0, 0); |
82 | | // Applying numbering for the outputs and collect equal assertions. |
83 | 15 | for (i = reverse_top->rnum - 1; i >= 0; i--12 ) |
84 | 12 | { |
85 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i); |
86 | 12 | ccv_nnc_micro_numbering(output, i + input_size, var_count); |
87 | 12 | ccv_nnc_micro_equal_assertions(output, equal_assertions); |
88 | 12 | } |
89 | 12 | for (i = 0; i < ingrad_size; i++9 ) |
90 | 9 | ccv_nnc_micro_numbering(ingrads[i], -1, var_count); |
91 | 9 | for (i = 0; i < outgrad_size; i++6 ) |
92 | 6 | ccv_nnc_micro_numbering(outgrads[i], -1, var_count); |
93 | | // Fill in shapes for variables. |
94 | 3 | ccv_nnc_micro_tensor_t* const vars = (ccv_nnc_micro_tensor_t*)cccalloc(var_count * 2, sizeof(ccv_nnc_micro_tensor_t)); |
95 | 9 | for (i = 0; i < input_size; i++6 ) |
96 | 6 | { |
97 | 6 | vars[i].dimensions = inputs[i]->dimensions; |
98 | 6 | vars[i].input = -1; |
99 | 6 | } |
100 | 15 | for (i = 0; i < reverse_top->rnum; i++12 ) |
101 | 12 | { |
102 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i); |
103 | 12 | vars[i + input_size] = ccv_nnc_micro_return_shape(output); |
104 | 12 | } |
105 | 21 | for (i = var_count; i < 2 * var_count; i++18 ) |
106 | 18 | { |
107 | 18 | vars[i].dimensions = vars[2 * var_count - 1 - i].dimensions; |
108 | 18 | vars[i].input = 2 * var_count - 1 - i; |
109 | 18 | } |
110 | | // Lower each ccv_nnc_micro_io_t (except the input) op into nested loops such that we can |
111 | | // apply optimizations later. |
112 | 3 | int function_count = reverse_top->rnum; |
113 | 3 | ccv_nnc_micro_function_t* functions = (ccv_nnc_micro_function_t*)ccmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); |
114 | 15 | for (i = 0; i < function_count; i++12 ) |
115 | 12 | { |
116 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, function_count - 1 - i); |
117 | 12 | functions[i] = ccv_nnc_micro_emit(output); |
118 | 12 | } |
119 | 3 | ccv_nnc_micro_combine_t* const combine = (ccv_nnc_micro_combine_t*)ccmalloc(sizeof(ccv_nnc_micro_combine_t)); |
120 | 3 | combine->parameter_size = parameter_size; |
121 | 3 | combine->forward.input_size = input_size; |
122 | 3 | combine->forward.inputs = (int*)ccmalloc(sizeof(int) * (input_size + output_size)); |
123 | 9 | for (i = 0; i < input_size; i++6 ) |
124 | 6 | combine->forward.inputs[i] = inputs[i]->id; |
125 | 3 | combine->forward.output_size = output_size; |
126 | 3 | combine->forward.outputs = combine->forward.inputs + input_size; |
127 | 6 | for (i = 0; i < output_size; i++3 ) |
128 | 3 | combine->forward.outputs[i] = outputs[i]->id; |
129 | 3 | combine->forward.var_count = var_count; |
130 | | // We copied forward.vars so backward.vars and forward.vars can maintain separate states. |
131 | | // However, shape and related allocations are shared because these are not going to be mutated. |
132 | 3 | combine->forward.vars = (ccv_nnc_micro_tensor_t*)ccmalloc(sizeof(ccv_nnc_micro_tensor_t) * var_count); |
133 | 3 | memcpy(combine->forward.vars, vars, sizeof(ccv_nnc_micro_tensor_t) * var_count); |
134 | 3 | combine->forward.function_count = function_count; |
135 | 3 | combine->forward.functions = functions; |
136 | 3 | ccv_nnc_micro_program_simplify(&combine->forward, inputs, input_size, outputs, output_size, equal_assertions); |
137 | 3 | function_count = reverse_top->rnum * 2; |
138 | 3 | functions = (ccv_nnc_micro_function_t*)ccmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); |
139 | 15 | for (i = 0; i < reverse_top->rnum; i++12 ) |
140 | 12 | { |
141 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i); |
142 | 12 | functions[i] = ccv_nnc_micro_emit(output); |
143 | 12 | } |
144 | 15 | for (i = reverse_top->rnum; i < function_count; i++12 ) |
145 | 12 | { |
146 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i - reverse_top->rnum); |
147 | 12 | functions[i] = ccv_nnc_micro_emit_grad(output, var_count); |
148 | 12 | } |
149 | 3 | combine->backward.input_size = ingrad_size; |
150 | 3 | combine->backward.inputs = ingrad_size + outgrad_size > 0 ? (int*)ccmalloc(sizeof(int) * (ingrad_size + outgrad_size)) : 00 ; |
151 | 12 | for (i = 0; i < ingrad_size; i++9 ) |
152 | 9 | combine->backward.inputs[i] = ingrads[i]->id; |
153 | 3 | combine->backward.output_size = outgrad_size; |
154 | 3 | combine->backward.outputs = outgrad_size > 0 ? combine->backward.inputs + ingrad_size : 00 ; |
155 | 9 | for (i = 0; i < outgrad_size; i++6 ) |
156 | 6 | combine->backward.outputs[i] = outgrads[i]->id; |
157 | 3 | combine->backward.var_count = var_count * 2; |
158 | 3 | combine->backward.vars = vars; |
159 | 3 | combine->backward.function_count = function_count; |
160 | 3 | combine->backward.functions = functions; |
161 | 3 | ccv_nnc_micro_program_simplify(&combine->backward, ingrads, ingrad_size, outgrads, outgrad_size, equal_assertions); |
162 | 3 | combine->equal_assertions = equal_assertions; |
163 | 15 | for (i = 0; i < reverse_top->rnum; i++12 ) |
164 | 12 | { |
165 | 12 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i); |
166 | 12 | ccv_nnc_micro_deinit(output); |
167 | 12 | ccfree(output); |
168 | 12 | } |
169 | 3 | ccv_array_free(reverse_top); |
170 | | // It may overlap with inputs, in that case, skip. |
171 | 12 | for (i = 0; i < ingrad_size; i++9 ) |
172 | 9 | { |
173 | 9 | int flag = 0; |
174 | 24 | for (j = 0; !flag && j < input_size18 ; j++15 ) |
175 | 15 | flag = (inputs[j] == ingrads[i]); |
176 | 9 | if (!flag) |
177 | 3 | { |
178 | 3 | ccv_nnc_micro_deinit(ingrads[i]); |
179 | 3 | ccfree(ingrads[i]); |
180 | 3 | } |
181 | 9 | } |
182 | 9 | for (i = 0; i < input_size; i++6 ) |
183 | 6 | { |
184 | 6 | ccv_nnc_micro_deinit(inputs[i]); |
185 | 6 | ccfree(inputs[i]); |
186 | 6 | } |
187 | 9 | for (i = 0; i < outgrad_size; i++6 ) // Should be no overlap on outgrads. |
188 | 6 | { |
189 | 6 | ccv_nnc_micro_deinit(outgrads[i]); |
190 | 6 | ccfree(outgrads[i]); |
191 | 6 | } |
192 | 3 | return combine; |
193 | 3 | } |
194 | | |
195 | | void ccv_nnc_micro_loop_index_free(ccv_nnc_micro_loop_index_term_t* const term) |
196 | 760 | { |
197 | 760 | if (term->type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY) |
198 | 80 | { |
199 | 80 | ccv_nnc_micro_loop_index_free(&term->binary->left); |
200 | 80 | ccv_nnc_micro_loop_index_free(&term->binary->right); |
201 | 80 | ccfree(term->binary); |
202 | 80 | } |
203 | 760 | } |
204 | | |
205 | | void ccv_nnc_micro_loop_variable_free(ccv_nnc_micro_loop_variable_t* const var) |
206 | 99 | { |
207 | 99 | int i; |
208 | 604 | for (i = 0; i < var->index_count; i++505 ) |
209 | 505 | ccv_nnc_micro_loop_index_free(&var->index[i]); |
210 | 99 | } |
211 | | |
212 | | static void _ccv_nnc_micro_loop_expression_free(ccv_nnc_micro_loop_expression_t* const expr) |
213 | 54 | { |
214 | 54 | switch (expr->type) |
215 | 54 | { |
216 | 30 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: { |
217 | 30 | ccv_nnc_micro_loop_variable_free(&expr->variable); |
218 | 30 | break; |
219 | 0 | } |
220 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: { |
221 | 0 | _ccv_nnc_micro_loop_expression_free(expr->unary.x); |
222 | 0 | ccfree(expr->unary.x); |
223 | 0 | break; |
224 | 0 | } |
225 | 12 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { |
226 | 12 | _ccv_nnc_micro_loop_expression_free(expr->binary.left); |
227 | 12 | ccfree(expr->binary.left); |
228 | 12 | _ccv_nnc_micro_loop_expression_free(expr->binary.right); |
229 | 12 | ccfree(expr->binary.right); |
230 | 12 | break; |
231 | 0 | } |
232 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { |
233 | 0 | _ccv_nnc_micro_loop_expression_free(expr->ternary.pivot); |
234 | 0 | ccfree(expr->ternary.pivot); |
235 | 0 | _ccv_nnc_micro_loop_expression_free(expr->ternary.left); |
236 | 0 | ccfree(expr->ternary.left); |
237 | 0 | _ccv_nnc_micro_loop_expression_free(expr->ternary.right); |
238 | 0 | ccfree(expr->ternary.right); |
239 | 0 | break; |
240 | 0 | } |
241 | 54 | } |
242 | 54 | } |
243 | | |
244 | | void ccv_nnc_micro_loop_statement_lvalue_free(ccv_nnc_micro_loop_statement_t* const statement) |
245 | 21 | { |
246 | 21 | switch (statement->type) |
247 | 21 | { |
248 | 0 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { |
249 | 0 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) |
250 | 0 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); |
251 | 0 | break; |
252 | 0 | } |
253 | 21 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { |
254 | 21 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); |
255 | 21 | break; |
256 | 0 | } |
257 | 21 | } |
258 | 21 | } |
259 | | |
260 | | void ccv_nnc_micro_loop_statement_free(ccv_nnc_micro_loop_statement_t* const statement) |
261 | 30 | { |
262 | 30 | switch (statement->type) |
263 | 30 | { |
264 | 12 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { |
265 | 12 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) |
266 | 6 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); |
267 | 12 | _ccv_nnc_micro_loop_expression_free(&statement->compound_assignment.rvalue); |
268 | 12 | break; |
269 | 0 | } |
270 | 18 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { |
271 | 18 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); |
272 | 18 | _ccv_nnc_micro_loop_expression_free(&statement->assignment.rvalue); |
273 | 18 | break; |
274 | 0 | } |
275 | 30 | } |
276 | 30 | } |
277 | | |
278 | | void ccv_nnc_micro_loops_free(ccv_nnc_micro_loop_t* const loops, const int loop_count) |
279 | 27 | { |
280 | 27 | int i, j; |
281 | 124 | for (i = 0; i < loop_count; i++97 ) |
282 | 97 | { |
283 | 127 | for (j = 0; j < loops[i].statement_count; j++30 ) |
284 | 30 | ccv_nnc_micro_loop_statement_free(&loops[i].statements[j]); |
285 | 97 | if (loops[i].statements) |
286 | 27 | ccfree(loops[i].statements); |
287 | 97 | if (loops[i].carrieds) |
288 | 6 | ccfree(loops[i].carrieds); |
289 | 97 | } |
290 | 27 | } |
291 | | |
292 | | void ccv_nnc_micro_combine_free(ccv_nnc_micro_combine_t* const combine) |
293 | 3 | { |
294 | 3 | int i, j; |
295 | 3 | const int var_count = combine->forward.var_count; |
296 | 21 | for (i = 0; i < var_count; i++18 ) |
297 | 18 | if (combine->forward.vars[i].shape) |
298 | 9 | { |
299 | 60 | for (j = 0; j < combine->forward.vars[i].dimensions; j++51 ) |
300 | 51 | ccv_nnc_micro_loop_index_free(&combine->forward.vars[i].shape[j]); |
301 | 9 | ccfree(combine->forward.vars[i].shape); |
302 | 9 | } |
303 | 3 | ccfree(combine->forward.vars); |
304 | 3 | ccfree(combine->backward.vars); |
305 | 3 | int function_count = combine->forward.function_count; |
306 | 6 | for (i = 0; i < function_count; i++3 ) |
307 | 3 | { |
308 | 3 | const int block_count = combine->forward.functions[i].block_count; |
309 | 3 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->forward.functions[i].one_block2 : combine->forward.functions[i].blocks1 ; |
310 | 7 | for (j = 0; j < block_count; j++4 ) |
311 | 4 | { |
312 | 4 | ccv_nnc_micro_loop_block_t block = blocks[j]; |
313 | 4 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); |
314 | 4 | ccfree(block.loops); |
315 | 4 | } |
316 | 3 | if (block_count > 1) |
317 | 1 | ccfree(combine->forward.functions[i].blocks); |
318 | 3 | } |
319 | 3 | ccfree(combine->forward.functions); |
320 | 3 | ccfree(combine->forward.inputs); |
321 | | // Backward and forward share the same vars. |
322 | 3 | function_count = combine->backward.function_count; |
323 | 6 | for (i = 0; i < function_count; i++3 ) |
324 | 3 | { |
325 | 3 | const int block_count = combine->backward.functions[i].block_count; |
326 | 3 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->backward.functions[i].one_block0 : combine->backward.functions[i].blocks; |
327 | 14 | for (j = 0; j < block_count; j++11 ) |
328 | 11 | { |
329 | 11 | ccv_nnc_micro_loop_block_t block = blocks[j]; |
330 | 11 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); |
331 | 11 | ccfree(block.loops); |
332 | 11 | } |
333 | 3 | if (block_count > 1) |
334 | 3 | ccfree(combine->backward.functions[i].blocks); |
335 | 3 | } |
336 | 3 | ccfree(combine->backward.functions); |
337 | 3 | if (combine->backward.inputs) |
338 | 3 | ccfree(combine->backward.inputs); |
339 | 3 | ccv_array_free(combine->equal_assertions); |
340 | 3 | ccfree(combine); |
341 | 3 | } |
342 | | |
343 | | char* ccv_nnc_micro_combine_c(ccv_nnc_micro_combine_t* const combine) |
344 | 0 | { |
345 | 0 | return 0; |
346 | 0 | } |