/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_micro_simplify.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_micro.h" |
6 | | #include "3rdparty/khash/khash.h" |
7 | | |
8 | 1.33k | #define MICRO_ID_TO_INT(x) (((x).id << 8) | ((x).d)) |
9 | | KHASH_MAP_INIT_INT(ccv_nnc_axis_id_group, int) |
10 | | |
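/* A minimal sketch, with hypothetical values, of what MICRO_ID_TO_INT computes:
 * it folds an (id, d) pair into one int so the pair can serve as a key in the
 * ccv_nnc_axis_id_group map. The packing is collision-free as long as the
 * dimension index d fits in the low 8 bits. */
static inline int _example_pack_micro_id(void)
{
	ccv_nnc_micro_id_t id;
	id.id = 3;
	id.d = 1;
	return MICRO_ID_TO_INT(id); /* (3 << 8) | 1 == 0x301 */
}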
11 | | static int _ccv_nnc_same_index_term(const ccv_nnc_micro_loop_index_term_t a_index, const ccv_nnc_micro_loop_index_term_t b_index, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
12 | 1.57k | { |
13 | 1.57k | if (a_index.type != b_index.type) |
14 | 108 | return 0; |
15 | 1.46k | const int type = a_index.type; |
16 | 1.46k | switch (type) |
17 | 1.46k | { |
18 | 539 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: |
19 | 539 | return a_index.immediate_value == b_index.immediate_value; |
20 | 912 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: |
21 | 912 | if (a_index.id.type != b_index.id.type) |
22 | 33 | return 0; |
23 | | // Check within axis_id_groups to see if there is a match; if there is no match, we can proceed (and use the group table again to check).
24 | 879 | if (axis_id_groups && a_index.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID)
25 | 659 | { |
26 | 659 | ccv_nnc_micro_id_t a_id = a_index.id; |
27 | 1.08k | while (groups && groups[a_id.id] != a_id.id) |
28 | 429 | a_id.id = groups[a_id.id]; |
29 | 659 | int a_root = MICRO_ID_TO_INT(a_id); |
30 | 659 | khiter_t k; |
31 | 677 | for (;;) { |
32 | 677 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, a_root); |
33 | 677 | if (k == kh_end(axis_id_groups)) |
34 | 659 | break; |
35 | 18 | a_root = kh_val(axis_id_groups, k); |
36 | 18 | } |
37 | 659 | ccv_nnc_micro_id_t b_id = b_index.id; |
38 | 1.16k | while (groups && groups[b_id.id] != b_id.id) |
39 | 506 | b_id.id = groups[b_id.id]; |
40 | 659 | int b_root = MICRO_ID_TO_INT(b_id); |
41 | 692 | for (;;) { |
42 | 692 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, b_root); |
43 | 692 | if (k == kh_end(axis_id_groups)) |
44 | 659 | break; |
45 | 33 | b_root = kh_val(axis_id_groups, k); |
46 | 33 | } |
47 | 659 | if (a_root == b_root) |
48 | 271 | return 1; |
49 | 659 | } |
50 | 608 | if (groups && (a_index.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID || a_index.id.type == CCV_NNC_MICRO_TENSOR_ID))
51 | 388 | { |
52 | 388 | if (a_index.id.d != b_index.id.d) |
53 | 291 | return 0; |
54 | 97 | switch (a_index.id.type) |
55 | 97 | { |
56 | 0 | case CCV_NNC_MICRO_TENSOR_ID: |
57 | 97 | case CCV_NNC_MICRO_AXIS_SIZE_ID: { |
58 | | // Find their group identifier and then compare. |
59 | 97 | int a_root = groups[a_index.id.id]; |
60 | 97 | while (groups[a_root] != a_root) |
61 | 0 | a_root = groups[a_root]; |
62 | 97 | int b_root = groups[b_index.id.id]; |
63 | 97 | while (groups[b_root] != b_root) |
64 | 0 | b_root = groups[b_root]; |
65 | 97 | return a_root == b_root; |
66 | 0 | } |
67 | 97 | } |
68 | 0 | return a_index.id.id == b_index.id.id; |
69 | 97 | } else |
70 | 220 | return (a_index.id.d == b_index.id.d && a_index.id.id == b_index.id.id);
71 | 16 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { |
72 | 16 | return a_index.binary->op == b_index.binary->op && _ccv_nnc_same_index_term(a_index.binary->left, b_index.binary->left, groups, axis_id_groups) && _ccv_nnc_same_index_term(a_index.binary->right, b_index.binary->right, groups, axis_id_groups); |
73 | 608 | } |
74 | 1.46k | } |
75 | 0 | return 0; |
76 | 1.46k | } |
77 | | |
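/* A standalone sketch of the two-level root lookup used in
 * _ccv_nnc_same_index_term above; the helper name is hypothetical. It first
 * chases the groups[] union-find array, then follows any alias chain recorded
 * in the axis_id_groups map, mirroring the lookups done for a_root and b_root. */
static int _example_axis_root(ccv_nnc_micro_id_t id, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
	while (groups && groups[id.id] != id.id) /* Chase the array-based union-find. */
		id.id = groups[id.id];
	int root = MICRO_ID_TO_INT(id);
	for (;;) { /* Follow alias links until the key is absent from the map. */
		khiter_t k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, root);
		if (k == kh_end(axis_id_groups))
			return root;
		root = kh_val(axis_id_groups, k);
	}
}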
78 | | static int _ccv_nnc_same_shape(const ccv_nnc_micro_loop_index_term_t* const a_shape, const ccv_nnc_micro_loop_index_term_t* const b_shape, const int dimensions) |
79 | 18 | { |
80 | 18 | int i; |
81 | 48 | for (i = 0; i < dimensions; i++)
82 | 44 | if (!_ccv_nnc_same_index_term(a_shape[i], b_shape[i], 0, 0)) |
83 | 14 | return 0; |
84 | 4 | return 1; |
85 | 18 | } |
86 | | |
87 | | static int _ccv_nnc_same_loop(const ccv_nnc_micro_loop_block_t* const left_block, const ccv_nnc_micro_loop_block_t* const right_block, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups, int* const left_loop_idx, int* const right_loop_idx) |
88 | 47 | { |
89 | 47 | assert(left_block->loop_count > 0); |
90 | 47 | assert(right_block->loop_count > 0); |
91 | 47 | int i, j; |
92 | 47 | int left_right_link[left_block->loop_count]; |
93 | 47 | int right_left_link[right_block->loop_count]; |
94 | 47 | enum { |
95 | 47 | ONE = -2, |
96 | 47 | UNASSIGNED = -1, |
97 | 47 | }; |
98 | 279 | for (i = 0; i < left_block->loop_count; i++)
99 | 232 | if (left_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].start_index.immediate_value == 0 && |
100 | 232 | left_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].end_index.immediate_value == 1)
101 | 0 | left_right_link[i] = ONE; |
102 | 232 | else |
103 | 232 | left_right_link[i] = UNASSIGNED; |
104 | 277 | for (i = 0; i < right_block->loop_count; i++)
105 | 230 | if (right_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].start_index.immediate_value == 0 && |
106 | 230 | right_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].end_index.immediate_value == 1)
107 | 0 | right_left_link[i] = ONE; |
108 | 230 | else |
109 | 230 | right_left_link[i] = UNASSIGNED; |
110 | 279 | for (i = 0; i < left_block->loop_count; i++) // Find corresponding loop on the right.
111 | 232 | { |
112 | 232 | if (left_right_link[i] != UNASSIGNED) |
113 | 0 | continue; |
114 | 232 | int flag = UNASSIGNED; |
115 | 1.12k | for (j = 0; j < right_block->loop_count && flag == UNASSIGNED; j++)
116 | 889 | { |
117 | 889 | if (right_left_link[j] != UNASSIGNED) |
118 | 382 | continue; |
119 | 507 | if (_ccv_nnc_same_index_term(left_block->loops[i].start_index, right_block->loops[j].start_index, groups, axis_id_groups) && |
120 | 507 | _ccv_nnc_same_index_term(left_block->loops[i].end_index, right_block->loops[j].end_index, groups, axis_id_groups)) |
121 | 146 | flag = j; |
122 | 507 | } |
123 | 232 | if (flag != UNASSIGNED) |
124 | 146 | { |
125 | 146 | left_right_link[i] = flag; |
126 | 146 | right_left_link[flag] = i; |
127 | 146 | } |
128 | 232 | } |
129 | | // If any loops remain unmatched, the two blocks don't share the same loop nest.
130 | 191 | for (i = 0; i < left_block->loop_count; i++)
131 | 167 | if (left_right_link[i] == UNASSIGNED) |
132 | 23 | return 0; |
133 | 168 | for (i = 0; i < right_block->loop_count; i++)
134 | 144 | if (right_left_link[i] == UNASSIGNED) |
135 | 0 | return 0; |
136 | | // We don't want to deal with constant loops; hence, if anything other than the outer-most loop
137 | | // is a constant loop (0..<1), we cannot merge.
138 | 144 | for (i = 1; i < left_block->loop_count; i++)
139 | 120 | if (left_right_link[i] == ONE) |
140 | 0 | return 0; |
141 | 144 | for (i = 1; i < right_block->loop_count; i++)
142 | 120 | if (right_left_link[i] == ONE) |
143 | 0 | return 0; |
144 | 24 | assert((left_block->loop_count == right_block->loop_count) || |
145 | 24 | (left_block->loop_count == right_block->loop_count + 1) || |
146 | 24 | (left_block->loop_count + 1 == right_block->loop_count)); |
147 | | // The loops match, but the ordering probably doesn't. We reorder loops based on statements.
148 | | // Hence, two loop nests can only merge if, using the statements as pivot points, they can still
149 | | // match things before / after each statement.
150 | | // If both have statements, check whether order is preserved within the statement loop (we could be fancier
151 | | // and call this recursively with each statement as the pivot point, but that is too much for my taste).
152 | 24 | const int left_start_idx = left_right_link[0] == ONE ? 1 : 0;
153 | 24 | const int right_start_idx = right_left_link[0] == ONE ? 1 : 0;
154 | 168 | for (i = 0; i < left_block->loop_count; i++)
155 | 144 | left_loop_idx[i] = UNASSIGNED; |
156 | 168 | for (i = 0; i < right_block->loop_count; i++)
157 | 144 | right_loop_idx[i] = UNASSIGNED; |
158 | 24 | if (left_start_idx == 1) |
159 | 0 | left_loop_idx[0] = 0; // Assign their index. |
160 | 24 | if (right_start_idx == 0) |
161 | 24 | right_loop_idx[0] = 0; // Assign their index. |
162 | 24 | const int end_idx = left_right_link[0] == ONE && right_left_link[0] == ONE ? left_block->loop_count - 1 : ccv_min(left_block->loop_count, right_block->loop_count);
163 | 24 | int pivot_idx = end_idx; |
164 | 24 | int k; |
165 | 168 | for (i = end_idx - 1; i >= 0; i--)
166 | 144 | { |
167 | 144 | if (left_block->loops[i + left_start_idx].statement_count > 0) |
168 | 24 | { |
169 | 24 | for (j = i + 1, k = i + 1; j < end_idx; j++)
170 | 0 | if (left_loop_idx[j + left_start_idx] == UNASSIGNED) |
171 | 0 | { |
172 | 0 | left_loop_idx[j + left_start_idx] = k + left_start_idx; |
173 | | // If the right one can be referenced past the previous pivot_idx, it is not valid.
174 | 0 | if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx) |
175 | 0 | return 0; |
176 | 0 | right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx; |
177 | 0 | ++k; |
178 | 0 | if (k > pivot_idx) |
179 | 0 | return 0; |
180 | 0 | } |
181 | 24 | assert(k == pivot_idx); |
182 | 24 | pivot_idx = i + 1; |
183 | 24 | } |
184 | 144 | if (right_block->loops[i + right_start_idx].statement_count > 0) |
185 | 27 | { |
186 | 34 | for (j = i + 1, k = i + 1; j < end_idx; j++)
187 | 7 | if (right_loop_idx[j + left_start_idx] == UNASSIGNED) |
188 | 7 | { |
189 | 7 | right_loop_idx[j + right_start_idx] = k + right_start_idx; |
190 | | // If the left one can be referenced past the previous pivot_idx, it is not valid.
191 | 7 | if (right_left_link[j + right_start_idx] >= pivot_idx + left_start_idx) |
192 | 0 | return 0; |
193 | 7 | left_loop_idx[right_left_link[j + right_start_idx]] = k + left_start_idx; |
194 | 7 | ++k; |
195 | 7 | if (k > pivot_idx) |
196 | 0 | return 0; |
197 | 7 | } |
198 | 27 | assert(k == pivot_idx); |
199 | 27 | pivot_idx = i + 1; |
200 | 27 | } |
201 | 144 | } |
202 | 24 | if (end_idx == 0) |
203 | 0 | return 1; |
204 | | // Finally, to distribute the rest. |
205 | 168 | for (j = 0, k = 0; j < end_idx; j++)
206 | 144 | { |
207 | 144 | if (left_loop_idx[j + left_start_idx] == UNASSIGNED) |
208 | 137 | { |
209 | 137 | left_loop_idx[j + left_start_idx] = k + left_start_idx; |
210 | | // If the right one can be referenced past the previous pivot_idx, it is not valid.
211 | 137 | if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx) |
212 | 0 | return 0; |
213 | 137 | right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx; |
214 | 137 | ++k; |
215 | 137 | if (k > pivot_idx) |
216 | 0 | return 0; |
217 | 137 | } |
218 | 144 | } |
219 | 24 | assert(k == pivot_idx); |
220 | 24 | return 1; |
221 | 24 | } |
222 | | |
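/* A worked example of the matching above, with hypothetical bounds: if the
 * left block is for i0: 0..<N { for i1: 0..<M { S1 } } and the right block is
 * for j0: 0..<M { for j1: 0..<N { S2 } }, comparing start / end indexes gives
 * left_right_link = {1, 0} and right_left_link = {1, 0}. Pivoting on the
 * statements S1 / S2, both nests can be permuted to the common order (N, M),
 * so left_loop_idx / right_loop_idx record that permutation and the function
 * returns 1. */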
223 | | static void _ccv_nnc_loop_order_by(ccv_nnc_micro_loop_block_t* const block, int* const loop_idx, ccv_nnc_micro_loop_t* const loops) |
224 | 48 | { |
225 | 48 | int i; |
226 | 336 | for (i = 0; i < block->loop_count; i++)
227 | 288 | if (loop_idx[i] >= 0) |
228 | 288 | loops[loop_idx[i]] = block->loops[i]; |
229 | 0 | else |
230 | 0 | loops[i] = block->loops[i]; |
231 | 336 | for (i = 0; i < block->loop_count; i++)
232 | 288 | { |
233 | | // Essentially, we don't need to move statements, loop-carried variables, just the loop id and the start / end index. |
234 | 288 | block->loops[i].id = loops[i].id; |
235 | 288 | block->loops[i].start_index = loops[i].start_index; |
236 | 288 | block->loops[i].end_index = loops[i].end_index; |
237 | 288 | } |
238 | 48 | } |
239 | | |
240 | | static void _ccv_nnc_expression_rename_carrieds(ccv_nnc_micro_loop_expression_t* const expression, const int start_idx) |
241 | 11 | { |
242 | 11 | switch (expression->type) |
243 | 11 | { |
244 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: |
245 | 0 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID); |
246 | 0 | expression->id.id += start_idx; |
247 | 0 | break; |
248 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: |
249 | 0 | _ccv_nnc_expression_rename_carrieds(expression->ternary.pivot, start_idx); |
250 | 0 | _ccv_nnc_expression_rename_carrieds(expression->ternary.left, start_idx); |
251 | 0 | _ccv_nnc_expression_rename_carrieds(expression->ternary.right, start_idx); |
252 | 0 | break; |
253 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: |
254 | 0 | _ccv_nnc_expression_rename_carrieds(expression->binary.left, start_idx); |
255 | 0 | _ccv_nnc_expression_rename_carrieds(expression->binary.right, start_idx); |
256 | 0 | break; |
257 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: |
258 | 0 | _ccv_nnc_expression_rename_carrieds(expression->unary.x, start_idx); |
259 | 0 | break; |
260 | | // We don't need to care about other expressions because loop-carried variables cannot participate in these operations.
261 | 11 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: |
262 | 11 | break; |
263 | 11 | } |
264 | 11 | } |
265 | | |
266 | | static void _ccv_nnc_loop_rename_carrieds(ccv_nnc_micro_loop_block_t* const block, const int start_idx) |
267 | 11 | { |
268 | 11 | int i, j; |
269 | 11 | const int loop_count = block->loop_count; |
270 | 11 | ccv_nnc_micro_loop_t* const loops = block->loops; |
271 | 76 | for (i = 0; i < loop_count; i++)
272 | 65 | {
273 | 65 | for (j = 0; j < loops[i].carried_count; j++)
274 | 0 | loops[i].carrieds[j].id.id += start_idx;
275 | 76 | for (j = 0; j < loops[i].statement_count; j++)
276 | 11 | switch (loops[i].statements[j].type) |
277 | 11 | { |
278 | 6 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: |
279 | 6 | _ccv_nnc_expression_rename_carrieds(&loops[i].statements[j].compound_assignment.rvalue, start_idx); |
280 | 6 | break; |
281 | 5 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: |
282 | 5 | if (loops[i].statements[j].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID) |
283 | 0 | { |
284 | 0 | assert(loops[i].statements[j].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID); |
285 | 0 | loops[i].statements[j].compound_assignment.lvalue.id.id += start_idx; |
286 | 0 | } |
287 | 5 | _ccv_nnc_expression_rename_carrieds(&loops[i].statements[j].compound_assignment.rvalue, start_idx); |
288 | 5 | break; |
289 | 11 | } |
290 | 65 | } |
291 | 11 | } |
292 | | |
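/* For example (hypothetical ids): if the left block already carries c0 and c1
 * (carried_count == 2) and the right block introduces its own c0, calling
 * _ccv_nnc_loop_rename_carrieds(right_block, 2) rewrites the right block's c0
 * to c2, so the two blocks can later share one loop nest without id clashes. */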
293 | | static int _ccv_nnc_only_var_in_expression(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_expression_t* const expression, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
294 | 335 | { |
295 | 335 | switch (expression->type) |
296 | 335 | { |
297 | 224 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: |
298 | 224 | if (expression->variable.id.type == CCV_NNC_MICRO_TENSOR_ID && expression->variable.id.id == id) |
299 | 27 | { |
300 | 27 | if (var.index_count != expression->variable.index_count) |
301 | 0 | return 2; |
302 | 27 | int i; |
303 | 180 | for (i = 0; i < var.index_count; i++)
304 | 153 | if (!_ccv_nnc_same_index_term(var.index[i], expression->variable.index[i], groups, axis_id_groups)) |
305 | 0 | return 2; |
306 | 27 | return 1; |
307 | 27 | } else |
308 | 197 | return 0; |
309 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { |
310 | 0 | const int pivot = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.pivot, groups, axis_id_groups); |
311 | 0 | const int left = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.left, groups, axis_id_groups); |
312 | 0 | const int right = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.right, groups, axis_id_groups); |
313 | 0 | if (pivot == 2 || left == 2 || right == 2) |
314 | 0 | return 2; |
315 | 0 | return (pivot || left || right); |
316 | 0 | } |
317 | 60 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { |
318 | 60 | const int left = _ccv_nnc_only_var_in_expression(id, var, expression->binary.left, groups, axis_id_groups); |
319 | 60 | const int right = _ccv_nnc_only_var_in_expression(id, var, expression->binary.right, groups, axis_id_groups); |
320 | 60 | if (left == 2 || right == 2) |
321 | 0 | return 2; |
322 | 60 | return (left || right);
323 | 60 | } |
324 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: |
325 | 0 | return _ccv_nnc_only_var_in_expression(id, var, expression->unary.x, groups, axis_id_groups); |
326 | 9 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: |
327 | 9 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID); |
328 | 9 | return 0; |
329 | 335 | } |
330 | 42 | return 0; |
331 | 335 | } |
332 | | |
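/* The tri-state return above reads as: 0 - the tensor `id` does not occur in
 * the expression; 1 - it occurs and every occurrence uses the same index as
 * `var`; 2 - it occurs with a different index or index count, which blocks
 * substitution. E.g., with var x[i] (hypothetical names), "x[i] + 1" yields 1
 * while "x[i + 1]" yields 2. */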
333 | | static int _ccv_nnc_only_var_in_rvalue(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_statement_t statement, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
334 | 215 | { |
335 | 215 | switch (statement.type) |
336 | 215 | { |
337 | 164 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: |
338 | 164 | return _ccv_nnc_only_var_in_expression(id, var, &statement.assignment.rvalue, groups, axis_id_groups); |
339 | 51 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: |
340 | 51 | return _ccv_nnc_only_var_in_expression(id, var, &statement.compound_assignment.rvalue, groups, axis_id_groups); |
341 | 215 | } |
342 | 0 | return 0; |
343 | 215 | } |
344 | | |
345 | | static ccv_nnc_micro_loop_expression_t _ccv_nnc_expression_deep_copy(const ccv_nnc_micro_loop_expression_t* const expression) |
346 | 3 | { |
347 | 3 | switch (expression->type) |
348 | 3 | { |
349 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { |
350 | 0 | ccv_nnc_micro_loop_expression_t copy = *expression; |
351 | 0 | copy.ternary.pivot = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
352 | 0 | *copy.ternary.pivot = _ccv_nnc_expression_deep_copy(expression->ternary.pivot); |
353 | 0 | copy.ternary.left = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
354 | 0 | *copy.ternary.left = _ccv_nnc_expression_deep_copy(expression->ternary.left); |
355 | 0 | copy.ternary.right = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
356 | 0 | *copy.ternary.right = _ccv_nnc_expression_deep_copy(expression->ternary.right); |
357 | 0 | return copy; |
358 | 0 | } |
359 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { |
360 | 0 | ccv_nnc_micro_loop_expression_t copy = *expression; |
361 | 0 | copy.binary.left = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
362 | 0 | *copy.binary.left = _ccv_nnc_expression_deep_copy(expression->binary.left); |
363 | 0 | copy.binary.right = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
364 | 0 | *copy.binary.right = _ccv_nnc_expression_deep_copy(expression->binary.right); |
365 | 0 | return copy; |
366 | 0 | } |
367 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: { |
368 | 0 | ccv_nnc_micro_loop_expression_t copy = *expression; |
369 | 0 | copy.unary.x = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); |
370 | 0 | *copy.unary.x = _ccv_nnc_expression_deep_copy(expression->unary.x); |
371 | 0 | return copy; |
372 | 0 | } |
373 | 3 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: { |
374 | 3 | ccv_nnc_micro_loop_expression_t copy = *expression; |
375 | 3 | int i; |
376 | 20 | for (i = 0; i < copy.variable.index_count; i++)
377 | 17 | copy.variable.index[i] = ccv_nnc_micro_loop_index_deep_copy(&copy.variable.index[i]);
378 | 3 | return copy; |
379 | 0 | } |
380 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: |
381 | 0 | return *expression; |
382 | 3 | } |
383 | 0 | return *expression; |
384 | 3 | } |
385 | | |
386 | | static void _ccv_nnc_replacing_id_in_expression(ccv_nnc_micro_loop_expression_t* const expression, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count) |
387 | 116 | { |
388 | 116 | switch (expression->type) |
389 | 116 | { |
390 | 90 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: |
391 | 90 | if (expression->variable.id.type == CCV_NNC_MICRO_TENSOR_ID && expression->variable.id.id == id) |
392 | 24 | { |
393 | 24 | ccv_nnc_micro_loop_variable_free(&expression->variable); |
394 | 24 | if (*count == 0) // First time, just assign to expression. |
395 | 21 | *expression = rvalue; |
396 | 3 | else // Otherwise, need to make deep copy of it. |
397 | 3 | *expression = _ccv_nnc_expression_deep_copy(&rvalue); |
398 | 24 | ++(*count); |
399 | 24 | } |
400 | 90 | break; |
401 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: |
402 | 0 | _ccv_nnc_replacing_id_in_expression(expression->ternary.pivot, id, rvalue, count); |
403 | 0 | _ccv_nnc_replacing_id_in_expression(expression->ternary.left, id, rvalue, count); |
404 | 0 | _ccv_nnc_replacing_id_in_expression(expression->ternary.right, id, rvalue, count); |
405 | 0 | break; |
406 | 26 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: |
407 | 26 | _ccv_nnc_replacing_id_in_expression(expression->binary.left, id, rvalue, count); |
408 | 26 | _ccv_nnc_replacing_id_in_expression(expression->binary.right, id, rvalue, count); |
409 | 26 | break; |
410 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: |
411 | 0 | _ccv_nnc_replacing_id_in_expression(expression->unary.x, id, rvalue, count); |
412 | 0 | break; |
413 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: |
414 | 0 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID); |
415 | 0 | break; |
416 | 116 | } |
417 | 116 | } |
418 | | |
419 | | static void _ccv_nnc_replacing_id_in_rvalue(ccv_nnc_micro_loop_statement_t* const statement, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count) |
420 | 64 | { |
421 | 64 | switch (statement->type) |
422 | 64 | { |
423 | 33 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: |
424 | 33 | _ccv_nnc_replacing_id_in_expression(&statement->assignment.rvalue, id, rvalue, count); |
425 | 33 | break; |
426 | 31 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: |
427 | | // Not going to be in lvalue (which is the carried variable only). |
428 | 31 | _ccv_nnc_replacing_id_in_expression(&statement->compound_assignment.rvalue, id, rvalue, count); |
429 | 31 | break; |
430 | 64 | } |
431 | 64 | } |
432 | | |
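/* Note the copy-on-second-use pattern above: the first substitution (*count ==
 * 0) moves `rvalue` into place, and every later occurrence gets a deep copy,
 * so each expression tree owns its nodes and can be freed exactly once. */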
433 | | typedef struct { |
434 | | int flag; |
435 | | int merge_to; |
436 | | ccv_array_t* writes; |
437 | | ccv_array_t* reads; |
438 | | } ccv_nnc_micro_loop_block_dependency_t; |
439 | | |
440 | | typedef struct { |
441 | | int flag; |
442 | | ccv_array_t* writes; |
443 | | ccv_array_t* reads; |
444 | | } ccv_nnc_micro_tensor_dependency_t; |
445 | | |
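/* A minimal sketch of how these records are filled, on a hypothetical
 * two-block program:
 *   block 0: t1[i] = t0[i] * 2
 *   block 1: t2[i] = t1[i] + 1
 * yields block_dependencies[0].writes = {1}, block_dependencies[1].reads = {1},
 * tensor_dependencies[1].writes = {0} and tensor_dependencies[1].reads = {1},
 * which is enough to answer "which block wrote what block 1 reads" during loop
 * merging below. */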
446 | | static void _ccv_nnc_micro_block_dependencies_from_rvalue(const ccv_nnc_micro_loop_expression_t* const rvalue, const int i, ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies) |
447 | 75 | { |
448 | 75 | switch (rvalue->type) |
449 | 75 | { |
450 | 51 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: |
451 | 51 | if (rvalue->variable.id.type == CCV_NNC_MICRO_TENSOR_ID) |
452 | 51 | { |
453 | 51 | if (!block_dependencies[i].reads) |
454 | 33 | block_dependencies[i].reads = ccv_array_new(sizeof(int), 1, 0); |
455 | 51 | const int id = rvalue->variable.id.id; |
456 | 51 | ccv_array_add_unique_int(block_dependencies[i].reads, id); |
457 | 51 | if (!tensor_dependencies[id].reads) |
458 | 42 | tensor_dependencies[id].reads = ccv_array_new(sizeof(int), 1, 0); |
459 | 51 | ccv_array_add_unique_int(tensor_dependencies[id].reads, i); |
460 | 51 | } |
461 | 51 | break; |
462 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: |
463 | 0 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.pivot, i, block_dependencies, tensor_dependencies); |
464 | 0 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.left, i, block_dependencies, tensor_dependencies); |
465 | 0 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.right, i, block_dependencies, tensor_dependencies); |
466 | 0 | break; |
467 | 12 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: |
468 | 12 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.left, i, block_dependencies, tensor_dependencies); |
469 | 12 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.right, i, block_dependencies, tensor_dependencies); |
470 | 12 | break; |
471 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: |
472 | 0 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->unary.x, i, block_dependencies, tensor_dependencies); |
473 | 0 | break; |
474 | 6 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: |
475 | 6 | assert(rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID); |
476 | 6 | break; |
477 | 75 | } |
478 | 75 | } |
479 | | |
480 | | static void _ccv_nnc_micro_block_dependencies(const ccv_nnc_micro_loop_block_t* const blocks, const int block_size, const int var_count, ccv_nnc_micro_loop_block_dependency_t** const block_dependencies_ref, ccv_nnc_micro_tensor_dependency_t** const tensor_dependencies_ref) |
481 | 6 | { |
482 | 6 | ccv_nnc_micro_loop_block_dependency_t* const block_dependencies = (ccv_nnc_micro_loop_block_dependency_t*)cccalloc(block_size, sizeof(ccv_nnc_micro_loop_block_dependency_t)); |
483 | 6 | ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies = (ccv_nnc_micro_tensor_dependency_t*)cccalloc(var_count, sizeof(ccv_nnc_micro_tensor_dependency_t)); |
484 | 6 | int i, j, k; |
485 | 51 | for (i = 0; i < block_size; i++)
486 | 45 | { |
487 | 45 | block_dependencies[i].merge_to = i; |
488 | 45 | const ccv_nnc_micro_loop_t* const loops = blocks[i].loops; |
489 | 45 | const int loop_count = blocks[i].loop_count; |
490 | 286 | for (j = 0; j < loop_count; j++)
491 | 241 | { |
492 | 241 | const ccv_nnc_micro_loop_statement_t* const statements = loops[j].statements; |
493 | 241 | const int statement_count = loops[j].statement_count; |
494 | 292 | for (k = 0; k < statement_count; k++)
495 | 51 | switch (statements[k].type) |
496 | 51 | { |
497 | 39 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { |
498 | 39 | assert(statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID); |
499 | 39 | const int id = statements[k].assignment.lvalue.id.id; |
500 | 39 | if (!block_dependencies[i].writes) |
501 | 39 | block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0); |
502 | 39 | ccv_array_add_unique_int(block_dependencies[i].writes, id); |
503 | 39 | if (!tensor_dependencies[id].writes) |
504 | 39 | tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0); |
505 | 39 | ccv_array_add_unique_int(tensor_dependencies[id].writes, i); |
506 | 39 | _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].assignment.rvalue, i, block_dependencies, tensor_dependencies); |
507 | 39 | break; |
508 | 39 | } |
509 | 12 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { |
510 | 12 | if (statements[k].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) |
511 | 6 | { |
512 | 6 | assert(statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID); |
513 | 6 | const int id = statements[k].compound_assignment.lvalue.id.id; |
514 | 6 | if (!block_dependencies[i].writes) |
515 | 6 | block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0); |
516 | 6 | ccv_array_add_unique_int(block_dependencies[i].writes, id); |
517 | 6 | if (!tensor_dependencies[id].writes) |
518 | 0 | tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0); |
519 | 6 | ccv_array_add_unique_int(tensor_dependencies[id].writes, i); |
520 | 6 | if (!block_dependencies[i].reads) |
521 | 6 | block_dependencies[i].reads = ccv_array_new(sizeof(int), 1, 0); |
522 | 6 | ccv_array_add_unique_int(block_dependencies[i].reads, id); |
523 | 6 | if (!tensor_dependencies[id].reads) |
524 | 6 | tensor_dependencies[id].reads = ccv_array_new(sizeof(int), 1, 0); |
525 | 6 | ccv_array_add_unique_int(tensor_dependencies[id].reads, i); |
526 | 6 | } |
527 | 12 | _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].compound_assignment.rvalue, i, block_dependencies, tensor_dependencies); |
528 | 12 | break; |
529 | 12 | } |
530 | 51 | } |
531 | 241 | } |
532 | 45 | } |
533 | 6 | *block_dependencies_ref = block_dependencies; |
534 | 6 | *tensor_dependencies_ref = tensor_dependencies; |
535 | 6 | } |
536 | | |
537 | | static void _ccv_nnc_micro_dependencies_free(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const int block_size, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int var_count) |
538 | 6 | { |
539 | 6 | int i; |
540 | 51 | for (i = 0; i < block_size; i++)
541 | 45 | { |
542 | 45 | if (block_dependencies[i].writes) |
543 | 45 | ccv_array_free(block_dependencies[i].writes); |
544 | 45 | if (block_dependencies[i].reads) |
545 | 39 | ccv_array_free(block_dependencies[i].reads); |
546 | 45 | } |
547 | 6 | ccfree(block_dependencies); |
548 | 60 | for (i = 0; i < var_count; i++)
549 | 54 | { |
550 | 54 | if (tensor_dependencies[i].writes) |
551 | 39 | ccv_array_free(tensor_dependencies[i].writes); |
552 | 54 | if (tensor_dependencies[i].reads) |
553 | 48 | ccv_array_free(tensor_dependencies[i].reads); |
554 | 54 | } |
555 | 6 | ccfree(tensor_dependencies); |
556 | 6 | } |
557 | | |
558 | | static int _ccv_nnc_tensor_reads_in_y_from_writes_after_x(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y) |
559 | 30 | { |
560 | 30 | int i, j; |
561 | 30 | int flag = 0; |
562 | 68 | for (i = 0; !flag && i < block_dependencies[y].reads->rnum; i++)
563 | 38 | { |
564 | 38 | const int read_idx = *(int*)ccv_array_get(block_dependencies[y].reads, i); |
565 | 38 | if (tensor_dependencies[read_idx].writes) |
566 | 68 | for (j = 0; !flag && j < tensor_dependencies[read_idx].writes->rnum; j++)
567 | 34 | { |
568 | 34 | int block_idx = *(int*)ccv_array_get(tensor_dependencies[read_idx].writes, j); |
569 | 47 | while (block_idx != block_dependencies[block_idx].merge_to) |
570 | 13 | block_idx = block_dependencies[block_idx].merge_to; |
571 | 34 | if (!block_dependencies[block_idx].flag) // Not in use, continue. |
572 | 0 | continue; |
573 | 34 | assert(block_idx <= y); |
574 | | // If block_idx falls strictly between x and y, we cannot merge.
575 | 34 | if (block_idx > x && block_idx != y)
576 | 15 | flag = block_idx; |
577 | 34 | } |
578 | 38 | } |
579 | 30 | return flag; |
580 | 30 | } |
581 | | |
582 | | static int _ccv_nnc_tensor_writes_in_x_reads_before_y(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y) |
583 | 15 | { |
584 | 15 | int i, j; |
585 | 15 | int flag = 0; |
586 | 55 | for (i = 0; !flag && i < block_dependencies[x].writes->rnum; i++)
587 | 40 | { |
588 | 40 | const int write_idx = *(int*)ccv_array_get(block_dependencies[x].writes, i); |
589 | 40 | if (tensor_dependencies[write_idx].reads) |
590 | 100 | for (j = 0; !flag && j < tensor_dependencies[write_idx].reads->rnum; j++)
591 | 60 | { |
592 | 60 | int block_idx = *(int*)ccv_array_get(tensor_dependencies[write_idx].reads, j); |
593 | 93 | while (block_idx != block_dependencies[block_idx].merge_to) |
594 | 33 | block_idx = block_dependencies[block_idx].merge_to; |
595 | 60 | if (!block_dependencies[block_idx].flag) // Not in use, continue. |
596 | 14 | continue; |
597 | 46 | assert(block_idx >= x); |
598 | | // If block_idx falls strictly between x and y, we cannot merge.
599 | 46 | if (block_idx < y && block_idx != x)
600 | 7 | flag = block_idx; |
601 | 46 | } |
602 | 40 | } |
603 | 15 | return flag; |
604 | 15 | } |
605 | | |
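/* A worked example of the two legality checks above, with hypothetical blocks:
 * if block 0 writes t0, block 1 reads t0 and writes t1, and block 2 reads t1,
 * then merging block 0 with block 2 is rejected by the first check, because
 * block 2 reads t1 and t1 is written by block 1, which sits strictly between
 * them (it returns block index 1). The second check catches the mirrored case:
 * block 0's write to t0 is read by the in-between block 1, so block 0 cannot
 * be moved down to block 2 either. */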
606 | | static void _ccv_nnc_tensor_remove_dead_store(const ccv_nnc_micro_tensor_dependency_t* const tensor_dependency, const int tensor_idx, ccv_array_t* const blocks) |
607 | 6 | { |
608 | 6 | int i, j, k, l;
609 | 6 | if (tensor_dependency->writes) |
610 | 12 | for (i = 0; i < tensor_dependency->writes->rnum; i++)
611 | 6 | { |
612 | 6 | const int write_idx = *(int*)ccv_array_get(tensor_dependency->writes, i); |
613 | 6 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, write_idx); |
614 | 6 | int flag = 0; |
615 | 6 | for (j = 0; j < block->loop_count; j++)
616 | 0 | { |
617 | 0 | ccv_nnc_micro_loop_statement_t* const statements = block->loops[j].statements; |
618 | 0 | for (k = 0, l = 0; k < block->loops[j].statement_count; k++) |
619 | 0 | { |
620 | | // It cannot be a compound assignment; in that case this tensor would be in the read set, and
621 | | // it would be in active use (anything "read" in an active block is marked as in use).
622 | 0 | assert(!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && |
623 | 0 | statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && |
624 | 0 | statements[k].compound_assignment.lvalue.id.id == tensor_idx)); |
625 | 0 | if (statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT && |
626 | 0 | statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && |
627 | 0 | statements[k].assignment.lvalue.id.id == tensor_idx) |
628 | 0 | { |
629 | | // This is a dead store, prepare to remove. |
630 | 0 | ccv_nnc_micro_loop_statement_free(&statements[k]); |
631 | 0 | } else { |
632 | 0 | statements[l] = statements[k]; |
633 | 0 | ++l; |
634 | 0 | } |
635 | 0 | } |
636 | 0 | if (l < block->loops[j].statement_count) |
637 | 0 | { |
638 | 0 | if (l > 0) |
639 | 0 | block->loops[j].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(block->loops[j].statements, sizeof(ccv_nnc_micro_loop_statement_t) * l); |
640 | 0 | else { |
641 | 0 | ccfree(block->loops[j].statements); |
642 | 0 | block->loops[j].statements = 0; |
643 | 0 | } |
644 | 0 | block->loops[j].statement_count = 0; |
645 | 0 | } |
646 | 0 | if (block->loops[j].statement_count > 0) |
647 | 0 | flag = 1; |
648 | 0 | } |
649 | 6 | if (!flag) // No statement for this block, remove this whole block. |
650 | 6 | { |
651 | 6 | ccv_nnc_micro_loops_free(block->loops, block->loop_count); |
652 | 6 | ccfree(block->loops); |
653 | 6 | block->loops = 0; |
654 | 6 | block->loop_count = 0; |
655 | 6 | } |
656 | 6 | } |
657 | 6 | } |
658 | | |
659 | | static void _ccv_nnc_loop_merging(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, ccv_array_t* const blocks, const int max_loop_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
660 | 6 | { |
661 | 6 | int i, j; |
662 | 6 | int left_loop_idx[max_loop_count]; |
663 | 6 | int right_loop_idx[max_loop_count]; |
664 | 6 | ccv_nnc_micro_loop_t loops[max_loop_count]; |
665 | | // Merge loops from blocks. |
666 | 45 | for (i = 0; i < blocks->rnum - 1; i++39 ) |
667 | 39 | { |
668 | 39 | ccv_nnc_micro_loop_block_t* const left_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i); |
669 | 39 | if (left_block->loop_count == 0) |
670 | 22 | continue; |
671 | 74 | for (j = i + 1; j < blocks->rnum; j++)
672 | 62 | { |
673 | | // We always merge from right block to left block. Thus, the right block will always be |
674 | | // in the original form. |
675 | 62 | ccv_nnc_micro_loop_block_t* const right_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j); |
676 | 62 | if (right_block->loop_count == 0) |
677 | 8 | continue; |
678 | 54 | int merge_to_right = 0; |
679 | | // First check whether between left and right, there are any blocks that the right block |
680 | | // depends on. If that is the case, we cannot merge the right block into the left block. |
681 | 54 | if (j > i + 1 && block_dependencies[j].reads)
682 | 30 | { |
683 | 30 | const int block_idx = _ccv_nnc_tensor_reads_in_y_from_writes_after_x(block_dependencies, tensor_dependencies, i, j); |
684 | | // Cannot merge because we have dependencies in between. Merging will violate that |
685 | | // dependency relationship. |
686 | 30 | if (block_idx) |
687 | 15 | { |
688 | | // Now check whether left can instead be merged into right. If so, we are lucky.
689 | 15 | if (_ccv_nnc_tensor_writes_in_x_reads_before_y(block_dependencies, tensor_dependencies, i, j)) |
690 | 7 | continue; |
691 | 8 | merge_to_right = 1; |
692 | 8 | } |
693 | 30 | } |
694 | | // This method not only compares whether the two blocks have the same loops, but also produces indexes
695 | | // telling where each loop should move so the start / end indexes match up. For example, if:
696 | | // left_loop_idx[2] = 3 |
697 | | // right_loop_idx[0] = 3 |
698 | | // That means right now, loop at index 2 on the left is the same as loop at index 0 on the right. |
699 | | // And to match exactly, they both need to move to index 3. |
700 | 47 | if (_ccv_nnc_same_loop(left_block, right_block, groups, axis_id_groups, left_loop_idx, right_loop_idx)) |
701 | 24 | { |
702 | | // Make sure if we have extra loops, they are on the left.
703 | 24 | if (right_block->loop_count > left_block->loop_count) |
704 | 0 | { |
705 | 0 | ccv_nnc_micro_loop_block_t t; |
706 | 0 | CCV_SWAP(*left_block, *right_block, t); |
707 | 0 | } |
708 | 24 | assert(left_block->loop_count == right_block->loop_count || left_block->loop_count == right_block->loop_count + 1); |
709 | 24 | _ccv_nnc_loop_order_by(left_block, left_loop_idx, loops); |
710 | 24 | _ccv_nnc_loop_order_by(right_block, right_loop_idx, loops); |
711 | 24 | const int left_start_idx = left_block->loop_count - right_block->loop_count; |
712 | 24 | if (left_block->carried_count > 0) |
713 | 11 | _ccv_nnc_loop_rename_carrieds(right_block, left_block->carried_count); |
714 | 24 | left_block->carried_count += right_block->carried_count; |
715 | 24 | int k; |
716 | 168 | for (k = 0; k < right_block->loop_count; k++) // Merge loops.
717 | 144 | { |
718 | 144 | const int left_idx = left_start_idx + k; |
719 | 144 | if (right_block->loops[k].carried_count > 0) |
720 | 3 | { |
721 | 3 | if (left_block->loops[left_idx].carried_count > 0) |
722 | 0 | { |
723 | 0 | left_block->loops[left_idx].carrieds = (ccv_nnc_micro_loop_carried_t*)ccrealloc(left_block->loops[left_idx].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * (left_block->loops[left_idx].carried_count + right_block->loops[k].carried_count)); |
724 | 0 | memcpy(left_block->loops[left_idx].carrieds + left_block->loops[left_idx].carried_count, right_block->loops[k].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * right_block->loops[k].carried_count); |
725 | 0 | ccfree(right_block->loops[k].carrieds); |
726 | 0 | } else |
727 | 3 | left_block->loops[left_idx].carrieds = right_block->loops[k].carrieds; |
728 | 3 | left_block->loops[left_idx].carried_count += right_block->loops[k].carried_count; |
729 | 3 | right_block->loops[k].carrieds = 0; |
730 | 3 | right_block->loops[k].carried_count = 0; |
731 | 3 | } |
732 | 144 | if (right_block->loops[k].statement_count > 0) |
733 | 27 | { |
734 | 27 | if (left_block->loops[left_idx].statement_count > 0) |
735 | 24 | { |
736 | 24 | left_block->loops[left_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(left_block->loops[left_idx].statements, sizeof(ccv_nnc_micro_loop_statement_t) * (left_block->loops[left_idx].statement_count + right_block->loops[k].statement_count)); |
737 | 24 | memcpy(left_block->loops[left_idx].statements + left_block->loops[left_idx].statement_count, right_block->loops[k].statements, sizeof(ccv_nnc_micro_loop_statement_t) * right_block->loops[k].statement_count); |
738 | 24 | ccfree(right_block->loops[k].statements); |
739 | 24 | } else |
740 | 3 | left_block->loops[left_idx].statements = right_block->loops[k].statements; |
741 | 27 | left_block->loops[left_idx].statement_count += right_block->loops[k].statement_count; |
742 | 27 | right_block->loops[k].statements = 0; |
743 | 27 | right_block->loops[k].statement_count = 0; |
744 | 27 | } |
745 | 144 | } |
746 | | // Once merged, free the loop. |
747 | 24 | ccfree(right_block->loops); |
748 | 24 | right_block->loops = 0; |
749 | 24 | right_block->loop_count = 0; |
750 | 24 | int x = i, y = j; |
751 | 24 | if (merge_to_right) // If this is merge to right. |
752 | 5 | { |
753 | 5 | ccv_nnc_micro_loop_block_t t; |
754 | 5 | CCV_SWAP(*left_block, *right_block, t); |
755 | 5 | x = j, y = i; |
756 | 5 | } |
757 | | // Merge all reads and writes tensors into block dependency. |
758 | 24 | if (block_dependencies[y].writes && block_dependencies[y].writes->rnum) |
759 | 24 | { |
760 | 24 | if (!block_dependencies[x].writes) |
761 | 0 | block_dependencies[x].writes = ccv_array_new(sizeof(int), 1, 0); |
762 | 69 | for (k = 0; k < block_dependencies[y].writes->rnum; k++)
763 | 45 | ccv_array_push(block_dependencies[x].writes, ccv_array_get(block_dependencies[y].writes, k)); |
764 | 24 | } |
765 | 24 | if (block_dependencies[y].reads && block_dependencies[y].reads->rnum) |
766 | 24 | { |
767 | 24 | if (!block_dependencies[x].reads) |
768 | 0 | block_dependencies[x].reads = ccv_array_new(sizeof(int), 1, 0); |
769 | 90 | for (k = 0; k < block_dependencies[y].reads->rnum; k++)
770 | 66 | ccv_array_push(block_dependencies[x].reads, ccv_array_get(block_dependencies[y].reads, k)); |
771 | 24 | } |
772 | | // Merged, mark the proper merging dependency. |
773 | 24 | block_dependencies[y].merge_to = x; |
774 | 24 | if (merge_to_right) // If this is merge to right, now left is empty, break. |
775 | 5 | break; |
776 | 24 | } |
777 | 47 | } |
778 | 17 | } |
779 | 6 | } |
780 | | |
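/* The net effect of a successful merge, sketched on a hypothetical pair of
 * blocks:
 *   before: for i: 0..<N { a[i] = x[i] + 1 }    <- left block
 *           for j: 0..<N { b[j] = a[j] * 2 }    <- right block
 *   after:  for i: 0..<N { a[i] = x[i] + 1; b[i] = a[i] * 2 }
 * Statements and loop-carried variables of the right block are appended to the
 * matching left loops, and the right block is emptied in place rather than
 * removed from the blocks array, so block indexes recorded in
 * block_dependencies stay valid. */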
781 | | static void _ccv_nnc_var_subst(ccv_nnc_micro_tensor_t* const vars, const int var_count, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, ccv_array_t* const blocks, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
782 | 6 | { |
783 | 6 | int i, j; |
784 | | // These are simple programs, so we are going to loop over all blocks to see whether a non-output,
785 | | // non-input var is only written / read in one loop. If that is the case, we are going to remove that var.
786 | | // We have to do this replacement from bottom to top, though.
787 | 60 | for (i = 0; i < var_count; i++)
788 | 54 | { |
789 | 54 | int flag = 0; |
790 | 186 | for (j = 0; !flag && j < input_size; j++)
791 | 132 | flag = (inputs[j]->id == i);
792 | 117 | for (j = 0; !flag && j < output_size; j++)
793 | 63 | flag = (outputs[j]->id == i); |
794 | 54 | if (flag) // This is in outputs or inputs. |
795 | 24 | continue; |
796 | 30 | int count_var = 0; |
797 | 30 | ccv_nnc_micro_loop_variable_t lvalue; |
798 | 30 | ccv_nnc_micro_loop_expression_t rvalue; |
799 | 30 | int block_idx, loop_idx, statement_idx; |
800 | 119 | for (j = 0; j < blocks->rnum; j++)
801 | 89 | { |
802 | 89 | const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j); |
803 | 89 | int k, l; |
804 | 89 | const int loop_count = block->loop_count; |
805 | 89 | const ccv_nnc_micro_loop_t* const loops = block->loops; |
806 | 89 | int var_per_block = 0; |
807 | 450 | for (k = 0; k < loop_count; k++)
808 | 361 | { |
809 | 361 | int flag = 0; |
810 | 361 | const int statement_count = loops[k].statement_count; |
811 | 361 | ccv_nnc_micro_loop_statement_t* const statements = loops[k].statements; |
812 | 552 | for (l = 0; l < statement_count; l++)
813 | 191 | if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT &&
814 | 191 | statements[l].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
815 | 191 | statements[l].assignment.lvalue.id.id == i)
816 | 24 | { |
817 | 24 | lvalue = statements[l].assignment.lvalue; |
818 | 24 | if (_ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups)) |
819 | 0 | flag = 2; |
820 | 24 | else { |
821 | | // If the variable doesn't show up on the right-hand side, we can continue.
822 | 24 | rvalue = statements[l].assignment.rvalue; |
823 | 24 | block_idx = j; |
824 | 24 | loop_idx = k; |
825 | 24 | statement_idx = l; |
826 | 24 | ++flag; |
827 | 24 | } |
828 | 167 | } else if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && |
829 | 167 | statements[l].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
830 | 167 | statements[l].compound_assignment.lvalue.id.id == i) {
831 | | // This is a compound assignment; automatically increase by 2.
832 | 0 | flag += 2; |
833 | 0 | } |
834 | 361 | if (flag > 1) // We have more than 1 assignment for this id, it is not good. We cannot remove it. |
835 | 0 | { |
836 | 0 | var_per_block += flag; |
837 | 0 | continue; |
838 | 0 | } |
839 | 552 | for (l = 0; l < statement_count; l++)
840 | 191 | flag = ccv_max(flag, _ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups)); |
841 | | // If flag == 2, it means the var was found with a different index. That is bad news.
842 | 361 | var_per_block += flag; |
843 | 361 | } |
844 | 89 | count_var += var_per_block; |
845 | 89 | } |
846 | | // If this is used in more than one place (written multiple times, used with a different index, or used in
847 | | // different blocks), we cannot get rid of it.
848 | 30 | if (count_var != 1) |
849 | 9 | continue; |
850 | | // Otherwise, now loop again and prepare to get rid of it. |
851 | 21 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, block_idx); |
852 | 21 | ccv_nnc_micro_loop_statement_t* statements = block->loops[loop_idx].statements; |
853 | 21 | ccv_nnc_micro_loop_statement_t statement = statements[statement_idx]; |
854 | | // First, remove the assignment. |
855 | 21 | if (statement_idx < block->loops[loop_idx].statement_count - 1) |
856 | 21 | memmove(statements + statement_idx, statements + statement_idx + 1, sizeof(ccv_nnc_micro_loop_statement_t) * (block->loops[loop_idx].statement_count - statement_idx - 1)); |
857 | 21 | --block->loops[loop_idx].statement_count; |
858 | 21 | const int statement_count = block->loops[loop_idx].statement_count; |
859 | 21 | statements = block->loops[loop_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(statements, sizeof(ccv_nnc_micro_loop_statement_t) * statement_count); |
860 | 21 | int k = 0; |
861 | 85 | for (j = 0; j < statement_count; j++)
862 | 64 | _ccv_nnc_replacing_id_in_rvalue(&statements[j], i, rvalue, &k); |
863 | 21 | if (k == 0) // If nothing to replace, free up everything. |
864 | 0 | ccv_nnc_micro_loop_statement_free(&statement); |
865 | 21 | else |
866 | 21 | ccv_nnc_micro_loop_statement_lvalue_free(&statement); |
867 | | // No need to allocate for this var. It is not used, only useful for shape computation. |
868 | 21 | vars[i].no_alloc = 1; |
869 | 21 | } |
870 | 6 | } |
871 | | |
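/* The substitution above is copy propagation at tensor granularity. On a
 * hypothetical block where t1 is neither an input nor an output:
 *   before: t1[i] = t0[i] * 2; t2[i] = t1[i] + 1
 *   after:  t2[i] = (t0[i] * 2) + 1, with vars[t1].no_alloc = 1
 * It only fires when the var is written once, read with the identical index,
 * and confined to a single block, which is what count_var == 1 certifies. */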
872 | | static int _ccv_nnc_index_binary_size(const ccv_nnc_micro_loop_index_term_t index) |
873 | 448 | { |
874 | 448 | switch (index.type) |
875 | 448 | { |
876 | 0 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: |
877 | 0 | return 0; |
878 | 48 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: |
879 | 352 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: |
880 | 352 | return 1; |
881 | 96 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: |
882 | 96 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS)
883 | 96 | return _ccv_nnc_index_binary_size(index.binary->left) + _ccv_nnc_index_binary_size(index.binary->right); |
884 | 0 | else |
885 | 0 | return 1; |
886 | 448 | } |
887 | 0 | return 0; |
888 | 448 | } |
889 | | |
890 | | typedef struct { |
891 | | int sign:7; |
892 | | int ignore:1; |
893 | | ccv_nnc_micro_loop_index_term_t term; |
894 | | } ccv_nnc_micro_loop_binary_term_t; |
895 | | |
896 | | static void _ccv_nnc_index_term_flatten(ccv_nnc_micro_loop_binary_term_t* const binary_terms, const ccv_nnc_micro_loop_index_term_t index, const int sign, int* const i) |
897 | 448 | { |
898 | 448 | switch (index.type) |
899 | 448 | { |
900 | 0 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: // No need to occupy. |
901 | 0 | break; |
902 | 48 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: |
903 | 352 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: |
904 | 352 | binary_terms[*i].term = index; |
905 | 352 | binary_terms[*i].sign = sign; |
906 | 352 | binary_terms[*i].ignore = 0; |
907 | 352 | ++(*i); |
908 | 352 | break; |
909 | 96 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: |
910 | 96 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS)
911 | 96 | { |
912 | 96 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->left, sign, i); |
913 | 96 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS) // Switch sign. |
914 | 48 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->right, sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? CCV_NNC_MICRO_BINARY_OP_MINUS : CCV_NNC_MICRO_BINARY_OP_PLUS, i);
915 | 48 | else |
916 | 48 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->right, sign, i); |
917 | 96 | } else { |
918 | 0 | binary_terms[*i].term = index; |
919 | 0 | binary_terms[*i].sign = sign; |
920 | 0 | binary_terms[*i].ignore = 0; |
921 | 0 | ++(*i); |
922 | 0 | } |
923 | 96 | break; |
924 | 448 | } |
925 | 448 | } |
926 | | |
927 | | // Returns 0 if we cannot decide, -1 for false, 1 for true.
928 | | static int _ccv_nnc_index_less_than_or_equal_to(const ccv_nnc_micro_loop_index_term_t left, const ccv_nnc_micro_loop_index_term_t right, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
929 | 256 | { |
930 | | // Special case 1. |
931 | 256 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL)
932 | 128 | return left.immediate_value <= right.immediate_value ? 1 : -1;
933 | | // Special case 2. |
934 | 128 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left.immediate_value == 0 && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && right.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID)
935 | 0 | return 1; |
936 | | // Special case 3. |
937 | 128 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && left.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right.immediate_value == 0)
938 | 0 | return -1; |
939 | | // Now we only have variables on both left and right; we need to flatten the binary tree (if possible) and reduce it to a constant if possible.
940 | | // We can only flatten + / - at the moment.
941 | 128 | const int left_binary_size = _ccv_nnc_index_binary_size(left); |
942 | 128 | assert(left_binary_size >= 1); |
943 | 128 | const int right_binary_size = _ccv_nnc_index_binary_size(right); |
944 | 128 | assert(right_binary_size >= 1); |
945 | 128 | ccv_nnc_micro_loop_binary_term_t* const left_binary_terms = (ccv_nnc_micro_loop_binary_term_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_binary_term_t) * (left_binary_size + right_binary_size)); |
946 | 128 | ccv_nnc_micro_loop_binary_term_t* const right_binary_terms = left_binary_terms + left_binary_size; |
947 | 128 | int i, j; |
948 | 128 | i = 0; |
949 | 128 | _ccv_nnc_index_term_flatten(left_binary_terms, left, CCV_NNC_MICRO_BINARY_OP_PLUS, &i); |
950 | 128 | assert(i == left_binary_size); |
951 | 128 | i = 0; |
952 | 128 | _ccv_nnc_index_term_flatten(right_binary_terms, right, CCV_NNC_MICRO_BINARY_OP_PLUS, &i); |
953 | 128 | assert(i == right_binary_size); |
954 | | // Matching signs in left terms. |
955 | 200 | for (i = 0; i < left_binary_size - 1; i++)
956 | 228 | for (j = i + 1; j < left_binary_size; j++)
957 | 156 | if (!left_binary_terms[i].ignore && !left_binary_terms[j].ignore &&
958 | 156 | _ccv_nnc_same_index_term(left_binary_terms[i].term, left_binary_terms[j].term, groups, axis_id_groups) &&
959 | 156 | left_binary_terms[i].sign != left_binary_terms[j].sign)
960 | 24 | { |
961 | 24 | left_binary_terms[i].ignore = -1; |
962 | 24 | left_binary_terms[j].ignore = -1; |
963 | 24 | } |
964 | | // Matching signs in right terms. |
965 | 152 | for (i = 0; i < right_binary_size - 1; i++)
966 | 60 | for (j = i + 1; j < right_binary_size; j++)
967 | 36 | if (!right_binary_terms[i].ignore && !right_binary_terms[j].ignore &&
968 | 36 | _ccv_nnc_same_index_term(right_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups) &&
969 | 36 | right_binary_terms[i].sign != right_binary_terms[j].sign)
970 | 0 | { |
971 | 0 | right_binary_terms[i].ignore = -1; |
972 | 0 | right_binary_terms[j].ignore = -1; |
973 | 0 | } |
974 | | // Matching left to right. |
975 | 328 | for (i = 0; i < left_binary_size; i++)
976 | 472 | for (j = 0; j < right_binary_size; j++)
977 | | // If they are the same, we can ignore now.
978 | 272 | if (!left_binary_terms[i].ignore && !right_binary_terms[j].ignore &&
979 | 272 | _ccv_nnc_same_index_term(left_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups) &&
980 | 272 | left_binary_terms[i].sign == right_binary_terms[j].sign)
981 | 140 | { |
982 | 140 | left_binary_terms[i].ignore = -1; |
983 | 140 | right_binary_terms[j].ignore = -1; |
984 | 140 | } |
985 | | // After reduced, we should only have immediate values left, otherwise we cannot progress. |
986 | 128 | int left_val = 0; |
987 | 316 | for (i = 0; i < left_binary_size; i++)
988 | 200 | if (!left_binary_terms[i].ignore) |
989 | 12 | { |
990 | 12 | if (left_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL) |
991 | 12 | { |
992 | 12 | ccfree(left_binary_terms); |
993 | 12 | return 0; |
994 | 12 | } else |
995 | 0 | left_val += left_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_binary_terms[i].term.immediate_value : -left_binary_terms[i].term.immediate_value; |
996 | 12 | } |
997 | 116 | int right_val = 0; |
998 | 256 | for (i = 0; i < right_binary_size; i++)
999 | 140 | if (!right_binary_terms[i].ignore) |
1000 | 0 | { |
1001 | 0 | if (right_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL) |
1002 | 0 | { |
1003 | 0 | ccfree(left_binary_terms); |
1004 | 0 | return 0; |
1005 | 0 | } else |
1006 | 0 | right_val += right_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? right_binary_terms[i].term.immediate_value : -right_binary_terms[i].term.immediate_value; |
1007 | 0 | } |
1008 | 116 | ccfree(left_binary_terms); |
1009 | 116 | return left_val <= right_val ? 1 : -1;
1010 | 116 | } |
1011 | | |
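/* A worked example of the flatten-and-cancel comparison above: to decide
 * whether (d0 + 2) - d0 <= 3 (hypothetical terms), the left side flattens to
 * {+d0, +2, -d0}; the +d0 / -d0 pair cancels within the left terms, nothing
 * matches across the two sides, and the remaining immediates reduce to 2 <= 3,
 * so the function returns 1. Any leftover non-immediate term makes the answer
 * "cannot decide" (0). */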
1012 | | // If this index term refers to an axis size that actually has an expression, refer to that instead (as for the reindex operation).
1013 | | static ccv_nnc_micro_loop_index_term_t _ccv_nnc_micro_index_shape_merging(const ccv_nnc_micro_loop_index_term_t index, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
1014 | 429 | { |
1015 | 429 | ccv_nnc_micro_loop_index_term_t result = index; |
1016 | 429 | for (;;) |
1017 | 648 | { |
1018 | 648 | if (!(result.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && result.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID))
1019 | 218 | return result; |
1020 | 430 | int root = groups[result.id.id]; |
1021 | 430 | while (groups[root] != root) |
1022 | 0 | root = groups[root]; |
1023 | 430 | if (vars[root].shape == 0) |
1024 | 211 | return result; |
1025 | 219 | assert(result.id.d >= 0 && result.id.d < vars[root].dimensions); |
1026 | 219 | result = vars[root].shape[result.id.d]; |
1027 | 219 | } |
1028 | 429 | } |
1029 | | |
1030 | | static int _ccv_nnc_micro_low_high_bound_from_index(const ccv_nnc_micro_loop_index_term_t index, ccv_nnc_micro_loop_index_term_t* const low_ref, ccv_nnc_micro_loop_index_term_t* const high_ref, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
1031 | 36 | { |
1032 | 36 | switch (index.type) |
1033 | 36 | { |
1034 | 0 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: |
1035 | 0 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ |
1036 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1037 | 0 | .immediate_value = 0 |
1038 | 0 | }; |
1039 | 0 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ |
1040 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1041 | 0 | .immediate_value = 0 |
1042 | 0 | }; |
1043 | 0 | return 1; |
1044 | 24 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: |
1045 | 24 | if (index.id.type == CCV_NNC_MICRO_LOOP_ID) |
1046 | 24 | { |
1047 | 24 | int loop_idx = -1; |
1048 | 24 | int i; |
1049 | 112 | for (i = 0; loop_idx < 0 && i < loop_count88 ; i++88 ) |
1050 | 88 | if (loops[i].id.id == index.id.id) |
1051 | 24 | loop_idx = i; |
1052 | 24 | assert(loop_idx >= 0); |
1053 | 24 | const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups); |
1054 | 24 | const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups); |
1055 | 24 | *low_ref = ccv_nnc_micro_loop_index_deep_copy(&start_index); |
1056 | 24 | *high_ref = ccv_nnc_micro_loop_index_deep_copy(&end_index); |
1057 | 24 | } else { |
1058 | 0 | *low_ref = index; |
1059 | 0 | *high_ref = index; |
1060 | 0 | } |
1061 | 24 | return 1; |
1062 | 0 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: |
1063 | 0 | *low_ref = index; |
1064 | 0 | *high_ref = index; |
1065 | 0 | return 1; |
1066 | 12 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { |
1067 | | // Get low, high from both left and right, and then construct new low / high. |
1068 | 12 | ccv_nnc_micro_loop_index_term_t left_low, left_high; |
1069 | 12 | if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->left, &left_low, &left_high, loops, loop_count, vars, var_count, groups, axis_id_groups)) |
1070 | 0 | return 0; |
1071 | 12 | ccv_nnc_micro_loop_index_term_t right_low, right_high; |
1072 | 12 | if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->right, &right_low, &right_high, loops, loop_count, vars, var_count, groups, axis_id_groups)) |
1073 | 0 | { |
1074 | 0 | ccv_nnc_micro_loop_index_free(&left_low); |
1075 | 0 | ccv_nnc_micro_loop_index_free(&left_high); |
1076 | 0 | return 0; |
1077 | 0 | } |
1078 | | // If either the left or the right side is not a range, it is simple: just copy over.
1079 | 12 | if (_ccv_nnc_same_index_term(left_low, left_high, groups, axis_id_groups) || _ccv_nnc_same_index_term(right_low, right_high, groups, axis_id_groups)) |
1080 | 0 | { |
1081 | 0 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ |
1082 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1083 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1084 | 0 | }; |
1085 | 0 | low_ref->binary->op = index.binary->op; |
1086 | 0 | low_ref->binary->left = left_low; |
1087 | 0 | low_ref->binary->right = right_low; |
1088 | 0 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ |
1089 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1090 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1091 | 0 | }; |
1092 | 0 | high_ref->binary->op = index.binary->op; |
1093 | 0 | high_ref->binary->left = left_high; |
1094 | 0 | high_ref->binary->right = right_high; |
1095 | 0 | return 1; |
1096 | 0 | } |
1097 | | // Cannot handle -, because the lower bound would go negative; similarly for /. Only + and * can be handled.
1098 | 12 | if (!(index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MUL0 ) || |
1099 | | // If the lower bound is not a non-negative integer, we cannot compute an interesting low / high bound; abort.
1100 | 12 | (left_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || left_low.immediate_value < 0) || |
1101 | 12 | (right_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || right_low.immediate_value < 0)) |
1102 | 0 | { |
1103 | 0 | ccv_nnc_micro_loop_index_free(&left_low); |
1104 | 0 | ccv_nnc_micro_loop_index_free(&left_high); |
1105 | 0 | ccv_nnc_micro_loop_index_free(&right_low); |
1106 | 0 | ccv_nnc_micro_loop_index_free(&right_high); |
1107 | 0 | return 0; |
1108 | 0 | } |
1109 | 12 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ |
1110 | 12 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1111 | 12 | .immediate_value = index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_low.immediate_value + right_low.immediate_value : left_low.immediate_value * right_low.immediate_value0 , |
1112 | 12 | }; |
1113 | | // The higher bound is not inclusive, hence we need to subtract an extra 1 here.
1114 | 12 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS) |
1115 | 12 | { |
1116 | | // (left - 1) + (right - 1) + 1 |
1117 | 12 | ccv_nnc_micro_loop_index_term_t sum = { |
1118 | 12 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1119 | 12 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1120 | 12 | }; |
1121 | 12 | sum.binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS; |
1122 | 12 | sum.binary->left = left_high; |
1123 | 12 | sum.binary->right = right_high; |
1124 | 12 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ |
1125 | 12 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1126 | 12 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1127 | 12 | }; |
1128 | 12 | high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; |
1129 | 12 | high_ref->binary->left = sum; |
1130 | 12 | high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){ |
1131 | 12 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1132 | 12 | .immediate_value = 1 |
1133 | 12 | }; |
1134 | 12 | } else { |
1135 | | // (left - 1) * (right - 1) + 1 |
1136 | 0 | ccv_nnc_micro_loop_index_term_t prod = { |
1137 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1138 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1139 | 0 | }; |
1140 | 0 | prod.binary->op = CCV_NNC_MICRO_BINARY_OP_MUL; |
1141 | 0 | prod.binary->left = left_high; |
1142 | 0 | prod.binary->right = right_high; |
1143 | 0 | ccv_nnc_micro_loop_index_term_t minus_left = { |
1144 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1145 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1146 | 0 | }; |
1147 | 0 | minus_left.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; |
1148 | 0 | minus_left.binary->left = prod; |
1149 | 0 | minus_left.binary->right = left_high; |
1150 | 0 | ccv_nnc_micro_loop_index_term_t minus_right = { |
1151 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1152 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1153 | 0 | }; |
1154 | 0 | minus_right.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; |
1155 | 0 | minus_right.binary->left = minus_left; |
1156 | 0 | minus_right.binary->right = right_high; |
1157 | 0 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ |
1158 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, |
1159 | 0 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) |
1160 | 0 | }; |
1161 | 0 | high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS; |
1162 | 0 | high_ref->binary->left = minus_right; |
1163 | 0 | high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){ |
1164 | 0 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1165 | 0 | .immediate_value = 2 |
1166 | 0 | }; |
1167 | 0 | } |
1168 | 12 | return 1; |
1169 | 12 | } |
1170 | 36 | } |
1171 | 0 | return 0; |
1172 | 36 | } |
1173 | | |
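For PLUS and MUL, the function above builds an exclusive upper bound from the two exclusive child bounds: each child peaks at its high minus 1, so the results are (a - 1) + (b - 1) + 1 and (a - 1) * (b - 1) + 1 respectively, which is exactly what the term trees constructed above encode. A small numeric check of those formulas, specialized to immediate values (high_bound_plus and high_bound_mul are hypothetical helper names):

#include <assert.h>

static int high_bound_plus(int a, int b)
{
	return (a + b) - 1; /* max is (a - 1) + (b - 1); the +1 makes it exclusive again */
}

static int high_bound_mul(int a, int b)
{
	return ((a * b) - a - b) + 2; /* (a - 1) * (b - 1) + 1 */
}

int main(void)
{
	/* i in [0, 4), j in [0, 5): i + j peaks at 3 + 4 = 7, so the exclusive bound is 8. */
	assert(high_bound_plus(4, 5) == 8);
	/* i * j peaks at 3 * 4 = 12, so the exclusive bound is 13. */
	assert(high_bound_mul(4, 5) == 13);
	return 0;
}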
1174 | | static void _ccv_nnc_micro_check_bound_for_variable(ccv_nnc_micro_loop_variable_t* const variable, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
1175 | 39 | { |
1176 | 39 | if (variable->id.type != CCV_NNC_MICRO_TENSOR_ID) |
1177 | 0 | return; |
1178 | 39 | int i, j; |
1179 | 39 | assert(variable->id.id >= 0 && variable->id.id < var_count); |
1180 | 39 | ccv_nnc_micro_loop_index_term_t index_zero = { |
1181 | 39 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, |
1182 | 39 | .immediate_value = 0 |
1183 | 39 | }; |
1184 | 188 | for (i = 0; i < variable->index_count; i++149 ) |
1185 | 149 | { |
1186 | 149 | const ccv_nnc_micro_loop_index_term_t shape = _ccv_nnc_micro_index_shape_merging((ccv_nnc_micro_loop_index_term_t){ |
1187 | 149 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID, |
1188 | 149 | .id = { |
1189 | 149 | .type = CCV_NNC_MICRO_AXIS_SIZE_ID, |
1190 | 149 | .id = variable->id.id, |
1191 | 149 | .d = i |
1192 | 149 | } |
1193 | 149 | }, vars, var_count, groups, axis_id_groups); |
1194 | 149 | switch (variable->index[i].type) |
1195 | 149 | { |
1196 | 116 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: |
1197 | | // For loop id, we can check the range to see if it is within the shape. |
1198 | 116 | if (variable->index[i].id.type == CCV_NNC_MICRO_LOOP_ID) |
1199 | 116 | { |
1200 | 116 | int loop_idx = -1; |
1201 | 482 | for (j = 0; loop_idx < 0 && j < loop_count366 ; j++366 ) |
1202 | 366 | if (loops[j].id.id == variable->index[i].id.id) |
1203 | 116 | loop_idx = j; |
1204 | 116 | assert(loop_idx >= 0); |
1205 | 116 | const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups); |
1206 | 116 | const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups); |
1207 | 116 | if (_ccv_nnc_index_less_than_or_equal_to(index_zero, start_index, vars, var_count, groups, axis_id_groups) == 1 && |
1208 | 116 | _ccv_nnc_index_less_than_or_equal_to(end_index, shape, vars, var_count, groups, axis_id_groups) == 1) |
1209 | 104 | variable->no_check_bound[i] = 1; |
1210 | 12 | else |
1211 | 12 | variable->no_check_bound[i] = 0; |
1212 | 116 | } else // If it is anything other than loop id, we have to check the bound. |
1213 | 0 | variable->no_check_bound[i] = 0; |
1214 | 116 | break; |
1215 | 116 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { |
1216 | | // Compute upper / lower bounds along the expression.
1217 | 12 | ccv_nnc_micro_loop_index_term_t low, high; |
1218 | | // Cannot find high / low bounds; mark no_check_bound[i] = 0.
1219 | 12 | if (!_ccv_nnc_micro_low_high_bound_from_index(variable->index[i], &low, &high, loops, loop_count, vars, var_count, groups, axis_id_groups)) |
1220 | 0 | { |
1221 | 0 | variable->no_check_bound[i] = 0; |
1222 | 0 | break; |
1223 | 0 | } |
1224 | 12 | if (_ccv_nnc_index_less_than_or_equal_to(index_zero, low, vars, var_count, groups, axis_id_groups) == 1 && |
1225 | 12 | _ccv_nnc_index_less_than_or_equal_to(high, shape, vars, var_count, groups, axis_id_groups) == 1) |
1226 | 12 | variable->no_check_bound[i] = 1; |
1227 | 0 | else |
1228 | 0 | variable->no_check_bound[i] = 0; |
1229 | 12 | ccv_nnc_micro_loop_index_free(&low); |
1230 | 12 | ccv_nnc_micro_loop_index_free(&high); |
1231 | 12 | break; |
1232 | 12 | } |
1233 | 21 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: |
1234 | | // If the index is an integer bigger than 0, we need to check the bound (there is no assertion that the end index is larger than anything other than 0).
1235 | 21 | if (variable->index[i].immediate_value == 0) |
1236 | 21 | variable->no_check_bound[i] = 1; |
1237 | 0 | else |
1238 | 0 | variable->no_check_bound[i] = 0; |
1239 | 21 | break; |
1240 | 149 | } |
1241 | 149 | } |
1242 | 39 | } |
1243 | | |
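For the common case of a plain loop index into an axis, the test above reduces to interval containment: a loop over [start, end) stays inside an axis of extent shape exactly when 0 <= start and end <= shape (the end index is exclusive). A sketch of that decision with immediate values; can_skip_bound_check is a hypothetical helper, not the library API:

#include <assert.h>

static int can_skip_bound_check(const int start, const int end, const int shape)
{
	/* The guard can be elided only when the whole iteration range is provably in bounds. */
	return 0 <= start && end <= shape;
}

int main(void)
{
	assert(can_skip_bound_check(0, 10, 10));  /* i in [0, 10) into extent 10: skip the check */
	assert(!can_skip_bound_check(1, 11, 10)); /* shifted range [1, 11) may overrun: keep it */
	return 0;
}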
1244 | | static void _ccv_nnc_micro_check_bound_for_expression(ccv_nnc_micro_loop_expression_t* const expression, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
1245 | 39 | { |
1246 | 39 | switch (expression->type) |
1247 | 39 | { |
1248 | 21 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: |
1249 | 21 | _ccv_nnc_micro_check_bound_for_variable(&expression->variable, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1250 | 21 | break; |
1251 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: |
1252 | 0 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.pivot, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1253 | 0 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.left, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1254 | 0 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.right, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1255 | 0 | break; |
1256 | 9 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: |
1257 | 9 | _ccv_nnc_micro_check_bound_for_expression(expression->binary.left, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1258 | 9 | _ccv_nnc_micro_check_bound_for_expression(expression->binary.right, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1259 | 9 | break; |
1260 | 0 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: |
1261 | 0 | _ccv_nnc_micro_check_bound_for_expression(expression->unary.x, loops, loop_count, vars, var_count, groups, axis_id_groups); |
1262 | 0 | break; |
1263 | 39 | } |
1264 | 39 | } |
1265 | | |
1266 | | static void _ccv_nnc_micro_check_bound_for_block(ccv_nnc_micro_loop_block_t* const block, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups) |
1267 | 15 | { |
1268 | 15 | int i, j; |
1269 | 78 | for (i = 0; i < block->loop_count; i++63 ) |
1270 | 63 | { |
1271 | 63 | const int statement_count = block->loops[i].statement_count; |
1272 | 63 | ccv_nnc_micro_loop_statement_t* const statements = block->loops[i].statements; |
1273 | 84 | for (j = 0; j < statement_count; j++21 ) |
1274 | 21 | { |
1275 | 21 | switch (statements[j].type) |
1276 | 21 | { |
1277 | 12 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: |
1278 | 12 | _ccv_nnc_micro_check_bound_for_variable(&statements[j].assignment.lvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); |
1279 | 12 | _ccv_nnc_micro_check_bound_for_expression(&statements[j].assignment.rvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); |
1280 | 12 | break; |
1281 | 9 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: |
1282 | 9 | if (statements[j].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) |
1283 | 6 | _ccv_nnc_micro_check_bound_for_variable(&statements[j].compound_assignment.lvalue.variable, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); |
1284 | 9 | _ccv_nnc_micro_check_bound_for_expression(&statements[j].compound_assignment.rvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); |
1285 | 9 | break; |
1286 | 21 | } |
1287 | 21 | } |
1288 | 63 | } |
1289 | 15 | } |
1290 | | |
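Before the entry point below, it is worth spelling out the two-level canonicalization it sets up (and that the index-term comparison consumes): an id is first chased through the union-find groups table, then the packed (id, d) key is followed through the axis_id_groups hash until it falls off the chain. A standalone sketch, re-declaring a private khash map so it compiles on its own; resolve_axis_root and axis_group are illustrative names:

#include "3rdparty/khash/khash.h"

KHASH_MAP_INIT_INT(axis_group, int)

static int resolve_axis_root(int id, const int d, const int* const groups, khash_t(axis_group)* const axis_id_groups)
{
	/* Level 1: union-find over variable ids. */
	while (groups[id] != id)
		id = groups[id];
	/* Level 2: follow the hash chain keyed by the packed (id, d) pair. */
	int root = (id << 8) | d; /* the same packing as MICRO_ID_TO_INT */
	for (;;) {
		khiter_t k = kh_get(axis_group, axis_id_groups, root);
		if (k == kh_end(axis_id_groups))
			return root;
		root = kh_val(axis_id_groups, k);
	}
}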
1291 | | void ccv_nnc_micro_program_simplify(ccv_nnc_micro_program_t* const program, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_array_t* const equal_assertions) |
1292 | 6 | { |
1293 | | // Nothing to simplify.
1294 | 6 | if (program->function_count < 1) |
1295 | 0 | return; |
1296 | | // Only one block; nothing to simplify.
1297 | 6 | if (program->function_count == 1 && program->functions[0].block_count == 10 ) |
1298 | 0 | return; |
1299 | 6 | if (input_size == 0 || output_size == 0) |
1300 | 0 | return; |
1301 | | // Union-find to group all variables with the same shape. |
1302 | 6 | ccv_nnc_micro_tensor_t* const vars = program->vars; |
1303 | 6 | const int var_count = program->var_count; |
1304 | 6 | int* const groups = (int*)ccmalloc(sizeof(int) * var_count); |
1305 | 6 | int i, j; |
1306 | 60 | for (i = 0; i < var_count; i++54 ) |
1307 | 54 | groups[i] = i; |
1308 | | // If there is no shape, they should match their input.
1309 | 60 | for (i = 0; i < var_count; i++54 ) |
1310 | 54 | if (vars[i].input >= 0 && !vars[i].shape42 ) |
1311 | 24 | { |
1312 | 24 | int root = vars[i].input; |
1313 | 27 | while (groups[root] != root) |
1314 | 3 | root = groups[root]; |
1315 | 24 | groups[i] = root; |
1316 | 24 | } |
1317 | 60 | for (i = 0; i < var_count; i++54 ) |
1318 | 54 | { |
1319 | | // If this is an input (no other tensor as its input), we skip.
1320 | 54 | if (vars[i].input < 0) |
1321 | 12 | continue; |
1322 | 42 | int root = i; |
1323 | 66 | while (groups[root] != root) |
1324 | 24 | root = groups[root]; |
1325 | | // If the sibling exists and we haven't visited it yet, mark it as being in the same group as us.
1326 | 42 | if (vars[i].sibling >= 0 && vars[i].sibling < i24 && groups[vars[i].sibling] < 024 ) |
1327 | 0 | groups[vars[i].sibling] = root; |
1328 | 42 | } |
1329 | 54 | for (i = var_count - 1; i > 0; i--48 ) |
1330 | 48 | { |
1331 | | // Now matching the shape. |
1332 | 48 | if (vars[i].input < 0 || !vars[i].shape42 ) |
1333 | 30 | continue; |
1334 | 18 | int root = i; |
1335 | 22 | while (groups[root] != root) |
1336 | 4 | root = groups[root]; |
1337 | 78 | for (j = i - 1; j >= 0; j--60 ) |
1338 | 60 | if (vars[j].shape && vars[j].dimensions == vars[i].dimensions18 && |
1339 | 60 | _ccv_nnc_same_shape(vars[j].shape, vars[i].shape, vars[i].dimensions)18 ) |
1340 | 4 | groups[j] = root; |
1341 | 18 | } |
1342 | | // Group equal assertions on axes together.
1343 | 6 | khash_t(ccv_nnc_axis_id_group)* const axis_id_groups = kh_init(ccv_nnc_axis_id_group); |
1344 | 14 | for (i = 0; i < equal_assertions->rnum; i++8 ) |
1345 | 8 | { |
1346 | 8 | const ccv_nnc_micro_id_equal_assertion_t* const equal_assertion = (ccv_nnc_micro_id_equal_assertion_t*)ccv_array_get(equal_assertions, i); |
1347 | 8 | ccv_nnc_micro_id_t left = equal_assertion->left; |
1348 | 8 | while (groups[left.id] != left.id) |
1349 | 0 | left.id = groups[left.id]; |
1350 | 8 | int left_root = MICRO_ID_TO_INT(left); |
1351 | 8 | khiter_t k; |
1352 | 10 | for (;;) { |
1353 | 10 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, left_root); |
1354 | 10 | if (k == kh_end(axis_id_groups)) |
1355 | 8 | break; |
1356 | 2 | left_root = kh_val(axis_id_groups, k); |
1357 | 2 | } |
1358 | 8 | ccv_nnc_micro_id_t right = equal_assertion->right; |
1359 | 8 | while (groups[right.id] != right.id) |
1360 | 0 | right.id = groups[right.id];
1361 | 8 | int right_root = MICRO_ID_TO_INT(right);
1362 | 10 | for (;;) { |
1363 | 10 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, right_root); |
1364 | 10 | if (k == kh_end(axis_id_groups)) |
1365 | 8 | break; |
1366 | 2 | right_root = kh_val(axis_id_groups, k); |
1367 | 2 | } |
1368 | 8 | if (left_root != right_root) // k is the right root at the moment. |
1369 | 4 | { |
1370 | 4 | int ret; |
1371 | 4 | k = kh_put(ccv_nnc_axis_id_group, axis_id_groups, right_root, &ret); |
1372 | 4 | assert(ret != 0); |
1373 | 4 | kh_val(axis_id_groups, k) = left_root; |
1374 | 4 | } |
1375 | 8 | } |
1376 | | // First, flatten all functions into blocks.
1377 | 6 | ccv_array_t* const blocks = ccv_array_new(sizeof(ccv_nnc_micro_loop_block_t), 0, 0); |
1378 | 6 | ccv_nnc_micro_function_t* const functions = program->functions; |
1379 | 6 | const int function_count = program->function_count; |
1380 | 6 | int max_loop_count = 0; |
1381 | 42 | for (i = 0; i < function_count; i++36 ) |
1382 | 36 | { |
1383 | 36 | const int block_count = functions[i].block_count; |
1384 | 36 | ccv_nnc_micro_loop_block_t* const function_blocks = block_count == 1 ? &functions[i].one_block27 : functions[i].blocks9 ; |
1385 | 81 | for (j = 0; j < block_count; j++45 ) |
1386 | 45 | { |
1387 | 45 | max_loop_count = ccv_max(function_blocks[j].loop_count, max_loop_count); |
1388 | 45 | ccv_array_push(blocks, &function_blocks[j]); |
1389 | 45 | } |
1390 | 36 | } |
1391 | | // Next, find dependencies between these function blocks and mark the ones that the final outputs depend on.
1392 | | // We need to build our connections between blocks <-> r/w vars. |
1393 | 6 | ccv_nnc_micro_loop_block_dependency_t* block_dependencies; |
1394 | 6 | ccv_nnc_micro_tensor_dependency_t* tensor_dependencies; |
1395 | 6 | const int block_size = blocks->rnum; |
1396 | 6 | _ccv_nnc_micro_block_dependencies((ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0), block_size, var_count, &block_dependencies, &tensor_dependencies); |
1397 | 6 | ccv_array_t* const in_use = ccv_array_new(sizeof(int), output_size, 0); |
1398 | | // Use the dependencies to mark blocks / vars that are in use. |
1399 | 15 | for (i = 0; i < output_size; i++9 ) |
1400 | 9 | { |
1401 | 9 | tensor_dependencies[outputs[i]->id].flag = 1; // Mark them as in use. |
1402 | 9 | ccv_array_push(in_use, &outputs[i]->id); |
1403 | 9 | } |
1404 | 21 | for (i = 0; i < input_size; i++15 ) |
1405 | 15 | tensor_dependencies[inputs[i]->id].flag = 1; // Mark inputs as in use so we don't go past them.
1406 | 39 | for (i = 0; i < in_use->rnum; i++33 ) |
1407 | 33 | { |
1408 | 33 | const int tensor_idx = *(int*)ccv_array_get(in_use, i); |
1409 | 33 | if (tensor_dependencies[tensor_idx].writes) |
1410 | 72 | for (j = 0; 33 j < tensor_dependencies[tensor_idx].writes->rnum; j++39 ) |
1411 | 39 | { |
1412 | 39 | const int block_idx = *(int*)ccv_array_get(tensor_dependencies[tensor_idx].writes, j); |
1413 | 39 | block_dependencies[block_idx].flag = 1; |
1414 | 39 | int k; |
1415 | 39 | if (block_dependencies[block_idx].reads) |
1416 | 81 | for (k = 0; 33 k < block_dependencies[block_idx].reads->rnum; k++48 ) |
1417 | 48 | { |
1418 | 48 | const int read_idx = *(int*)ccv_array_get(block_dependencies[block_idx].reads, k); |
1419 | 48 | if (!tensor_dependencies[read_idx].flag) |
1420 | 24 | { |
1421 | 24 | tensor_dependencies[read_idx].flag = 1; |
1422 | 24 | ccv_array_push(in_use, &read_idx); |
1423 | 24 | } |
1424 | 48 | } |
1425 | 39 | } |
1426 | 33 | } |
1427 | 6 | ccv_array_free(in_use); |
1428 | 51 | for (i = 0; i < block_size; i++45 ) |
1429 | 45 | if (!block_dependencies[i].flag) |
1430 | 6 | { |
1431 | 6 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i); |
1432 | 6 | ccv_nnc_micro_loops_free(block->loops, block->loop_count); |
1433 | 6 | ccfree(block->loops); |
1434 | 6 | block->loops = 0; |
1435 | 6 | block->loop_count = 0; |
1436 | 6 | } |
1437 | 60 | for (i = 0; i < var_count; i++54 ) |
1438 | 54 | if (!tensor_dependencies[i].flag) // If this tensor is not visited, there is no need to alloc. |
1439 | 6 | { |
1440 | 6 | _ccv_nnc_tensor_remove_dead_store(&tensor_dependencies[i], i, blocks); |
1441 | 6 | vars[i].no_alloc = 1; |
1442 | 6 | } |
1443 | 6 | _ccv_nnc_loop_merging(block_dependencies, tensor_dependencies, blocks, max_loop_count, groups, axis_id_groups); |
1444 | 6 | _ccv_nnc_micro_dependencies_free(block_dependencies, block_size, tensor_dependencies, var_count); |
1445 | | // Cull out empty blocks.
1446 | 51 | for (i = 0, j = 0; i < blocks->rnum; i++45 ) |
1447 | 45 | { |
1448 | 45 | const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i); |
1449 | 45 | if (block->loop_count > 0) |
1450 | 15 | { |
1451 | 15 | *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j) = *block; |
1452 | 15 | ++j; |
1453 | 15 | } |
1454 | 45 | } |
1455 | | // Now that we have moved everything, set the proper block count.
1456 | 6 | ccv_array_resize(blocks, j); |
1457 | | // Substitute variables. |
1458 | 6 | _ccv_nnc_var_subst(vars, var_count, inputs, input_size, outputs, output_size, blocks, groups, axis_id_groups); |
1459 | | // Mark whether we need to check bounds for a particular variable or not.
1460 | 21 | for (i = 0; i < blocks->rnum; i++15 ) |
1461 | 15 | { |
1462 | 15 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i); |
1463 | 15 | _ccv_nnc_micro_check_bound_for_block(block, vars, var_count, groups, axis_id_groups); |
1464 | 15 | } |
1465 | 6 | ccfree(groups);
1466 | 6 | kh_destroy(ccv_nnc_axis_id_group, axis_id_groups); |
1467 | | // Reallocate the functions array down to a single function.
1468 | 42 | for (i = 0; i < function_count; i++36 ) |
1469 | 36 | if (functions[i].block_count > 1) |
1470 | 9 | ccfree(functions[i].blocks); |
1471 | 6 | program->functions = (ccv_nnc_micro_function_t*)ccrealloc(program->functions, sizeof(ccv_nnc_micro_function_t)); |
1472 | 6 | program->functions[0].block_count = blocks->rnum; |
1473 | 6 | if (blocks->rnum > 1) |
1474 | 4 | { |
1475 | 4 | program->functions[0].blocks = (ccv_nnc_micro_loop_block_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum); |
1476 | 4 | memcpy(program->functions[0].blocks, ccv_array_get(blocks, 0), sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum); |
1477 | 4 | } else |
1478 | 2 | program->functions[0].one_block = *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0); |
1479 | 6 | program->function_count = 1; |
1480 | 6 | ccv_array_free(blocks); |
1481 | 6 | } |
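The dead-code elimination in the middle of the function above is a plain worklist pass: seed the live set with the outputs, and for each live tensor mark every block that writes it and enqueue every tensor those blocks read. The same idea in miniature, with dense illustrative arrays in place of the block/tensor dependency structures; mark_live and MAX_N are hypothetical, and the real pass additionally pre-marks the inputs so the walk stops there:

#define MAX_N 64 /* assumed upper bound on both block and tensor counts */

/* writes[b][t] / reads[b][t] are nonzero when block b writes / reads tensor t. */
static void mark_live(const int writes[][MAX_N], const int reads[][MAX_N], const int block_count, const int tensor_count, const int* const outputs, const int output_size, int* const tensor_live, int* const block_live)
{
	int worklist[MAX_N];
	int head = 0, tail = 0, i, b, t;
	for (i = 0; i < output_size; i++) {
		tensor_live[outputs[i]] = 1;
		worklist[tail++] = outputs[i];
	}
	/* Each tensor enters the worklist at most once, so tail never exceeds tensor_count. */
	while (head < tail) {
		const int cur = worklist[head++];
		for (b = 0; b < block_count; b++)
			if (writes[b][cur]) {
				block_live[b] = 1;
				for (t = 0; t < tensor_count; t++)
					if (reads[b][t] && !tensor_live[t]) {
						tensor_live[t] = 1;
						worklist[tail++] = t;
					}
			}
	}
}

Blocks and tensors left unmarked after this walk are exactly the ones the coverage above shows being freed (empty loop arrays) or flagged no_alloc.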