Coverage Report

Created: 2024-08-18 16:21

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_micro_simplify.c
Line
Count
Source (jump to first uncovered line)
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_nnc_micro.h"
6
#include "3rdparty/khash/khash.h"
7
8
1.33k
// Pack a ccv_nnc_micro_id_t into a single int key for the axis-id hash map: `id` shifted left by 8 bits, OR-ed with `d`.
// NOTE(review): keys are only collision-free while 0 <= d < 256, and id must be non-negative (left-shifting a
// negative int is undefined behavior) — confirm these ranges hold for all callers.
#define MICRO_ID_TO_INT(x) (((x).id << 8) | ((x).d))
9
// int -> int hash map keyed by MICRO_ID_TO_INT-packed ids. _ccv_nnc_same_index_term follows key -> value links
// until it reaches a key with no entry, and uses that key as the representative of an axis-size group.
KHASH_MAP_INIT_INT(ccv_nnc_axis_id_group, int)
10
11
// Walk the parent chain in `groups` from `idx` until reaching an entry that points at itself.
static int _ccv_nnc_same_index_term_group_root(const int* const groups, int idx)
{
	while (groups[idx] != idx)
		idx = groups[idx];
	return idx;
}

// Follow key -> value links in `axis_id_groups` starting at `key` until a key has no entry; return that last key.
static int _ccv_nnc_same_index_term_axis_root(khash_t(ccv_nnc_axis_id_group)* const axis_id_groups, int key)
{
	khiter_t it = kh_get(ccv_nnc_axis_id_group, axis_id_groups, key);
	while (it != kh_end(axis_id_groups))
	{
		key = kh_val(axis_id_groups, it);
		it = kh_get(ccv_nnc_axis_id_group, axis_id_groups, key);
	}
	return key;
}

/* Return non-zero when two loop index terms are considered the same.
 *
 * - Immediate values compare by value.
 * - Binary terms compare their operator and, recursively, both operands.
 * - Id terms must have the same id type. For axis-size ids with an `axis_id_groups` map, the ids are first reduced
 *   to their `groups` root (when `groups` is given) and then chased through `axis_id_groups`; equal results match.
 *   Otherwise, when `groups` is given, tensor / axis-size ids match if their `d` agrees and their ids share a
 *   `groups` root; every other id matches only on identical `d` and id.
 * - Any other term type never matches.
 */
static int _ccv_nnc_same_index_term(const ccv_nnc_micro_loop_index_term_t a_index, const ccv_nnc_micro_loop_index_term_t b_index, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
	if (a_index.type != b_index.type)
		return 0;
	if (a_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL)
		return a_index.immediate_value == b_index.immediate_value;
	if (a_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY)
		return a_index.binary->op == b_index.binary->op &&
			_ccv_nnc_same_index_term(a_index.binary->left, b_index.binary->left, groups, axis_id_groups) &&
			_ccv_nnc_same_index_term(a_index.binary->right, b_index.binary->right, groups, axis_id_groups);
	if (a_index.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID)
		return 0;
	const ccv_nnc_micro_id_t a_id = a_index.id;
	const ccv_nnc_micro_id_t b_id = b_index.id;
	if (a_id.type != b_id.type)
		return 0;
	// Axis sizes may be declared equal through axis_id_groups; a match there settles it, a miss falls through
	// to the group-table comparison below.
	if (axis_id_groups && a_id.type == CCV_NNC_MICRO_AXIS_SIZE_ID)
	{
		ccv_nnc_micro_id_t a_rep = a_id;
		ccv_nnc_micro_id_t b_rep = b_id;
		if (groups)
		{
			a_rep.id = _ccv_nnc_same_index_term_group_root(groups, a_rep.id);
			b_rep.id = _ccv_nnc_same_index_term_group_root(groups, b_rep.id);
		}
		if (_ccv_nnc_same_index_term_axis_root(axis_id_groups, MICRO_ID_TO_INT(a_rep)) == _ccv_nnc_same_index_term_axis_root(axis_id_groups, MICRO_ID_TO_INT(b_rep)))
			return 1;
	}
	const int use_groups = groups && (a_id.type == CCV_NNC_MICRO_AXIS_SIZE_ID || a_id.type == CCV_NNC_MICRO_TENSOR_ID);
	if (!use_groups)
		return a_id.d == b_id.d && a_id.id == b_id.id;
	if (a_id.d != b_id.d)
		return 0;
	return _ccv_nnc_same_index_term_group_root(groups, a_id.id) == _ccv_nnc_same_index_term_group_root(groups, b_id.id);
}
77
78
/* Return 1 when the first `dimensions` index terms of both shapes are pairwise identical, 0 otherwise.
 * Terms are compared literally: no group table and no axis-id equivalences are consulted.
 * A non-positive `dimensions` trivially yields 1.
 */
static int _ccv_nnc_same_shape(const ccv_nnc_micro_loop_index_term_t* const a_shape, const ccv_nnc_micro_loop_index_term_t* const b_shape, const int dimensions)
{
	int d = 0;
	while (d < dimensions && _ccv_nnc_same_index_term(a_shape[d], b_shape[d], 0, 0))
		++d;
	// Reaching the end means no dimension mismatched.
	return d >= dimensions;
}
86
87
/* Decide whether two loop blocks iterate over the same loops (possibly nested in a different order).
 *
 * Loops are paired by comparing their start / end index terms with _ccv_nnc_same_index_term (using `groups` and
 * `axis_id_groups` to recognize equivalent ids). A loop running 0..<1 is a constant loop ("ONE") and is never paired;
 * it is only tolerated as the outer-most loop of a block.
 *
 * On success (returns 1), left_loop_idx[i] / right_loop_idx[i] receive the new position of loop i in its block so
 * that the two reordered blocks line up, while loops that carry statements act as pivots: loops that were inside a
 * statement-carrying loop stay inside it. Returns 0 when the loops cannot be paired or the ordering constraints
 * cannot be satisfied.
 *
 * left_loop_idx / right_loop_idx must hold at least left_block->loop_count / right_block->loop_count entries.
 */
static int _ccv_nnc_same_loop(const ccv_nnc_micro_loop_block_t* const left_block, const ccv_nnc_micro_loop_block_t* const right_block, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups, int* const left_loop_idx, int* const right_loop_idx)
{
	assert(left_block->loop_count > 0);
	assert(right_block->loop_count > 0);
	int i, j;
	// left_right_link[i]: index of the right loop paired with left loop i, or ONE / UNASSIGNED.
	// right_left_link[j]: index of the left loop paired with right loop j, or ONE / UNASSIGNED.
	int left_right_link[left_block->loop_count];
	int right_left_link[right_block->loop_count];
	enum {
		ONE = -2, // Constant loop (0..<1), never paired.
		UNASSIGNED = -1, // Not paired / positioned yet.
	};
	// Mark constant loops on both sides; everything else starts unpaired.
	for (i = 0; i < left_block->loop_count; i++)
		if (left_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].start_index.immediate_value == 0 &&
			left_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].end_index.immediate_value == 1)
			left_right_link[i] = ONE;
		else
			left_right_link[i] = UNASSIGNED;
	for (i = 0; i < right_block->loop_count; i++)
		if (right_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].start_index.immediate_value == 0 &&
			right_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].end_index.immediate_value == 1)
			right_left_link[i] = ONE;
		else
			right_left_link[i] = UNASSIGNED;
	// Greedily pair each non-constant left loop with the first unpaired right loop that has identical bounds.
	for (i = 0; i < left_block->loop_count; i++)
	{
		if (left_right_link[i] != UNASSIGNED)
			continue;
		int flag = UNASSIGNED;
		for (j = 0; j < right_block->loop_count && flag == UNASSIGNED; j++)
		{
			if (right_left_link[j] != UNASSIGNED)
				continue;
			if (_ccv_nnc_same_index_term(left_block->loops[i].start_index, right_block->loops[j].start_index, groups, axis_id_groups) &&
				_ccv_nnc_same_index_term(left_block->loops[i].end_index, right_block->loops[j].end_index, groups, axis_id_groups))
				flag = j;
		}
		if (flag != UNASSIGNED)
		{
			left_right_link[i] = flag;
			right_left_link[flag] = i;
		}
	}
	// If still have unmatched, they don't share the same loop.
	for (i = 0; i < left_block->loop_count; i++)
		if (left_right_link[i] == UNASSIGNED)
			return 0;
	for (i = 0; i < right_block->loop_count; i++)
		if (right_left_link[i] == UNASSIGNED)
			return 0;
	// I don't want to deal with constant loop, hence, if other than the outer-most is a constant loop (0..<1),
	// we cannot merge.
	for (i = 1; i < left_block->loop_count; i++)
		if (left_right_link[i] == ONE)
			return 0;
	for (i = 1; i < right_block->loop_count; i++)
		if (right_left_link[i] == ONE)
			return 0;
	assert((left_block->loop_count == right_block->loop_count) ||
		(left_block->loop_count == right_block->loop_count + 1) ||
		(left_block->loop_count + 1 == right_block->loop_count));
	// The loop matches, but the ordering probably doesn't. We reorder loop based on statements.
	// Hence, two loops can only merge if using the statements as a pivot point and they can still
	// match things before / after the statement.
	// If both have statements, check if order preserving within the statement loop (we can be fancier
	// and recursively call this while using statement as pivoting point, but that is too much to my taste).
	// A constant outer-most loop is skipped: positions of the remaining loops are offset by *_start_idx.
	const int left_start_idx = left_right_link[0] == ONE ? 1 : 0;
	const int right_start_idx = right_left_link[0] == ONE ? 1 : 0;
	for (i = 0; i < left_block->loop_count; i++)
		left_loop_idx[i] = UNASSIGNED;
	for (i = 0; i < right_block->loop_count; i++)
		right_loop_idx[i] = UNASSIGNED;
	// A constant outer-most loop keeps position 0 on either side.
	if (left_start_idx == 1)
		left_loop_idx[0] = 0;
	// Mirrors the left side. (This used to test `right_start_idx == 0`, which pre-assigned a slot that the
	// distribution below always overwrites, and left a constant right loop unpositioned.)
	if (right_start_idx == 1)
		right_loop_idx[0] = 0;
	const int end_idx = left_right_link[0] == ONE && right_left_link[0] == ONE ? left_block->loop_count - 1 : ccv_min(left_block->loop_count, right_block->loop_count);
	int pivot_idx = end_idx;
	int k;
	// Walk from the inner-most loop outwards. Whenever a loop carries statements, every not-yet-positioned loop
	// inside it must be placed between it and the previous pivot, on both sides.
	for (i = end_idx - 1; i >= 0; i--)
	{
		if (left_block->loops[i + left_start_idx].statement_count > 0)
		{
			for (j = i + 1, k = i + 1; j < end_idx; j++)
				if (left_loop_idx[j + left_start_idx] == UNASSIGNED)
				{
					left_loop_idx[j + left_start_idx] = k + left_start_idx;
					// If the right one can be referenced pass previous pivot_idx, it is not right.
					if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx)
						return 0;
					right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx;
					++k;
					if (k > pivot_idx)
						return 0;
				}
			assert(k == pivot_idx);
			pivot_idx = i + 1;
		}
		if (right_block->loops[i + right_start_idx].statement_count > 0)
		{
			for (j = i + 1, k = i + 1; j < end_idx; j++)
				// Right-block arrays are offset by right_start_idx. (This used to use left_start_idx, which checked
				// the wrong loop and could index one past right_block->loop_count when only the left block has a
				// constant outer-most loop.)
				if (right_loop_idx[j + right_start_idx] == UNASSIGNED)
				{
					right_loop_idx[j + right_start_idx] = k + right_start_idx;
					// If the left one can be referenced pass previous pivot_idx, it is not right.
					if (right_left_link[j + right_start_idx] >= pivot_idx + left_start_idx)
						return 0;
					left_loop_idx[right_left_link[j + right_start_idx]] = k + left_start_idx;
					++k;
					if (k > pivot_idx)
						return 0;
				}
			assert(k == pivot_idx);
			pivot_idx = i + 1;
		}
	}
	if (end_idx == 0)
		return 1;
	// Finally, to distribute the rest: loops outside every pivot keep their relative left-block order.
	for (j = 0, k = 0; j < end_idx; j++)
	{
		if (left_loop_idx[j + left_start_idx] == UNASSIGNED)
		{
			left_loop_idx[j + left_start_idx] = k + left_start_idx;
			// If the right one can be referenced pass previous pivot_idx, it is not right.
			if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx)
				return 0;
			right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx;
			++k;
			if (k > pivot_idx)
				return 0;
		}
	}
	assert(k == pivot_idx);
	return 1;
}
222
223
/* Reorder the loops of `block` according to `loop_idx`, using `loops` (at least block->loop_count entries) as
 * scratch space. Loop n moves to position loop_idx[n]; a negative entry keeps the loop at position n.
 * Only each loop's id and start / end index are rewritten in place. Statements and loop-carried variables stay
 * attached to their original nesting level.
 */
static void _ccv_nnc_loop_order_by(ccv_nnc_micro_loop_block_t* const block, int* const loop_idx, ccv_nnc_micro_loop_t* const loops)
{
	const int count = block->loop_count;
	int n;
	// Scatter every loop into its destination slot of the scratch buffer.
	for (n = 0; n < count; n++)
	{
		const int target = loop_idx[n] >= 0 ? loop_idx[n] : n;
		loops[target] = block->loops[n];
	}
	// Copy back just the loop header.
	for (n = 0; n < count; n++)
	{
		ccv_nnc_micro_loop_t* const dest = block->loops + n;
		dest->id = loops[n].id;
		dest->start_index = loops[n].start_index;
		dest->end_index = loops[n].end_index;
	}
}
239
240
/* Recursively add `start_idx` to every loop-carried id referenced by `expression`.
 * Variable references and any other expression types are left untouched, because loop-carried variables
 * cannot appear inside them.
 */
static void _ccv_nnc_expression_rename_carrieds(ccv_nnc_micro_loop_expression_t* const expression, const int start_idx)
{
	const int type = expression->type;
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID)
	{
		assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID);
		expression->id.id += start_idx;
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY) {
		_ccv_nnc_expression_rename_carrieds(expression->ternary.pivot, start_idx);
		_ccv_nnc_expression_rename_carrieds(expression->ternary.left, start_idx);
		_ccv_nnc_expression_rename_carrieds(expression->ternary.right, start_idx);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY) {
		_ccv_nnc_expression_rename_carrieds(expression->binary.left, start_idx);
		_ccv_nnc_expression_rename_carrieds(expression->binary.right, start_idx);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY) {
		_ccv_nnc_expression_rename_carrieds(expression->unary.x, start_idx);
	}
}
265
266
/* Shift every loop-carried variable id in `block` by `start_idx`.
 * This covers the carried declarations of each loop, carried ids referenced in statement rvalues, and a carried
 * lvalue of a compound assignment.
 */
static void _ccv_nnc_loop_rename_carrieds(ccv_nnc_micro_loop_block_t* const block, const int start_idx)
{
	int i, j;
	const int loop_count = block->loop_count;
	ccv_nnc_micro_loop_t* const loops = block->loops;
	for (i = 0; i < loop_count; i++)
	{
		for (j = 0; j < loops[i].carried_count; j++)
			loops[i].carrieds[j].id.id += start_idx;
		for (j = 0; j < loops[i].statement_count; j++)
		{
			ccv_nnc_micro_loop_statement_t* const statement = loops[i].statements + j;
			switch (statement->type)
			{
				case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT:
					// Plain assignments keep their rvalue in the `assignment` union member, as everywhere else in this
					// file. (This used to read `compound_assignment.rvalue`; that member's lvalue is an expression
					// rather than a variable, so its rvalue need not live at the same offset.)
					_ccv_nnc_expression_rename_carrieds(&statement->assignment.rvalue, start_idx);
					break;
				case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT:
					// A compound assignment may accumulate into a loop-carried variable.
					if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID)
					{
						assert(statement->compound_assignment.lvalue.id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID);
						statement->compound_assignment.lvalue.id.id += start_idx;
					}
					_ccv_nnc_expression_rename_carrieds(&statement->compound_assignment.rvalue, start_idx);
					break;
			}
		}
	}
}
292
293
/* Check how tensor `id` is referenced inside `expression`.
 *
 * Returns 0 if the tensor is not referenced. Returns 1 if every reference uses exactly the same index terms as
 * `var`, compared with _ccv_nnc_same_index_term under `groups` / `axis_id_groups`. Returns 2 if at least one
 * reference uses a different number of indices or a different index term.
 * Every child of a ternary / binary expression is always visited.
 */
static int _ccv_nnc_only_var_in_expression(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_expression_t* const expression, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
	const int type = expression->type;
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
	{
		const int is_target = expression->variable.id.type == CCV_NNC_MICRO_TENSOR_ID && expression->variable.id.id == id;
		if (!is_target)
			return 0;
		if (expression->variable.index_count != var.index_count)
			return 2;
		int i;
		for (i = 0; i < var.index_count; i++)
			if (!_ccv_nnc_same_index_term(var.index[i], expression->variable.index[i], groups, axis_id_groups))
				return 2;
		return 1;
	}
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY)
		return _ccv_nnc_only_var_in_expression(id, var, expression->unary.x, groups, axis_id_groups);
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID)
	{
		// Only loop-carried ids appear here; they are never tensors.
		assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID);
		return 0;
	}
	if (type != CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY && type != CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY)
		return 0;
	// Collect every child's verdict (unused slots stay 0), then combine: any 2 wins, otherwise any hit gives 1.
	int verdicts[3] = { 0, 0, 0 };
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY)
	{
		verdicts[0] = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.pivot, groups, axis_id_groups);
		verdicts[1] = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.left, groups, axis_id_groups);
		verdicts[2] = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.right, groups, axis_id_groups);
	} else {
		verdicts[0] = _ccv_nnc_only_var_in_expression(id, var, expression->binary.left, groups, axis_id_groups);
		verdicts[1] = _ccv_nnc_only_var_in_expression(id, var, expression->binary.right, groups, axis_id_groups);
	}
	if (verdicts[0] == 2 || verdicts[1] == 2 || verdicts[2] == 2)
		return 2;
	return verdicts[0] || verdicts[1] || verdicts[2];
}
332
333
/* Apply _ccv_nnc_only_var_in_expression to the rvalue of `statement`, for plain or compound assignments.
 * Any other statement type yields 0.
 */
static int _ccv_nnc_only_var_in_rvalue(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_statement_t statement, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
	const ccv_nnc_micro_loop_expression_t* rvalue;
	if (statement.type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT)
		rvalue = &statement.assignment.rvalue;
	else if (statement.type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT)
		rvalue = &statement.compound_assignment.rvalue;
	else
		return 0;
	return _ccv_nnc_only_var_in_expression(id, var, rvalue, groups, axis_id_groups);
}
344
345
static ccv_nnc_micro_loop_expression_t _ccv_nnc_expression_deep_copy(const ccv_nnc_micro_loop_expression_t* const expression)
346
3
{
347
3
  switch (expression->type)
348
3
  {
349
0
    case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: {
350
0
      ccv_nnc_micro_loop_expression_t copy = *expression;
351
0
      copy.ternary.pivot = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
352
0
      *copy.ternary.pivot = _ccv_nnc_expression_deep_copy(expression->ternary.pivot);
353
0
      copy.ternary.left = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
354
0
      *copy.ternary.left = _ccv_nnc_expression_deep_copy(expression->ternary.left);
355
0
      copy.ternary.right = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
356
0
      *copy.ternary.right = _ccv_nnc_expression_deep_copy(expression->ternary.right);
357
0
      return copy;
358
0
    }
359
0
    case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: {
360
0
      ccv_nnc_micro_loop_expression_t copy = *expression;
361
0
      copy.binary.left = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
362
0
      *copy.binary.left = _ccv_nnc_expression_deep_copy(expression->binary.left);
363
0
      copy.binary.right = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
364
0
      *copy.binary.right = _ccv_nnc_expression_deep_copy(expression->binary.right);
365
0
      return copy;
366
0
    }
367
0
    case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: {
368
0
      ccv_nnc_micro_loop_expression_t copy = *expression;
369
0
      copy.unary.x = (ccv_nnc_micro_loop_expression_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_expression_t));
370
0
      *copy.unary.x = _ccv_nnc_expression_deep_copy(expression->unary.x);
371
0
      return copy;
372
0
    }
373
3
    case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: {
374
3
      ccv_nnc_micro_loop_expression_t copy = *expression;
375
3
      int i;
376
20
      for (i = 0; i < copy.variable.index_count; 
i++17
)
377
17
        copy.variable.index[i] = ccv_nnc_micro_loop_index_deep_copy(&copy.variable.index[i]);
378
3
      return copy;
379
0
    }
380
0
    case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID:
381
0
      return *expression;
382
3
  }
383
0
  return *expression;
384
3
}
385
386
/* Replace every reference to tensor `id` inside `expression` with `rvalue`, incrementing *count per replacement.
 * The replaced variable is released with ccv_nnc_micro_loop_variable_free. When *count is 0, `rvalue` is stored
 * as-is, so its sub-expression pointers are taken over by the tree. Later replacements store a deep copy instead.
 */
static void _ccv_nnc_replacing_id_in_expression(ccv_nnc_micro_loop_expression_t* const expression, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count)
{
	const int type = expression->type;
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
	{
		if (expression->variable.id.type != CCV_NNC_MICRO_TENSOR_ID || expression->variable.id.id != id)
			return;
		ccv_nnc_micro_loop_variable_free(&expression->variable);
		*expression = (*count == 0) ? rvalue : _ccv_nnc_expression_deep_copy(&rvalue);
		++(*count);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY) {
		_ccv_nnc_replacing_id_in_expression(expression->ternary.pivot, id, rvalue, count);
		_ccv_nnc_replacing_id_in_expression(expression->ternary.left, id, rvalue, count);
		_ccv_nnc_replacing_id_in_expression(expression->ternary.right, id, rvalue, count);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY) {
		_ccv_nnc_replacing_id_in_expression(expression->binary.left, id, rvalue, count);
		_ccv_nnc_replacing_id_in_expression(expression->binary.right, id, rvalue, count);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY) {
		_ccv_nnc_replacing_id_in_expression(expression->unary.x, id, rvalue, count);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID) {
		// Only loop-carried ids appear here, never the tensor being replaced.
		assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID);
	}
}
418
419
/* Run _ccv_nnc_replacing_id_in_expression on the rvalue of a plain or compound assignment statement.
 * The lvalue is not visited: for compound assignments it is the carried variable only. Other statement types
 * are left unchanged.
 */
static void _ccv_nnc_replacing_id_in_rvalue(ccv_nnc_micro_loop_statement_t* const statement, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count)
{
	ccv_nnc_micro_loop_expression_t* target = 0;
	if (statement->type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT)
		target = &statement->assignment.rvalue;
	else if (statement->type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT)
		target = &statement->compound_assignment.rvalue;
	if (target)
		_ccv_nnc_replacing_id_in_expression(target, id, rvalue, count);
}
432
433
// Per-block dependency record, filled by _ccv_nnc_micro_block_dependencies.
typedef struct {
	int flag; // Non-zero while the block is in use; the merge-conflict checks skip blocks whose flag is 0.
	int merge_to; // Index of the block this one was merged into; initialized to the block's own index.
	ccv_array_t* writes; // int array of tensor ids written by this block; NULL until the first write is recorded.
	ccv_array_t* reads; // int array of tensor ids read by this block; NULL until the first read is recorded.
} ccv_nnc_micro_loop_block_dependency_t;
439
440
// Per-tensor dependency record, filled by _ccv_nnc_micro_block_dependencies.
typedef struct {
	int flag; // NOTE(review): not read or written in the code visible here; confirm its meaning at its use sites.
	ccv_array_t* writes; // int array of block indices that write this tensor; NULL when none.
	ccv_array_t* reads; // int array of block indices that read this tensor; NULL when none.
} ccv_nnc_micro_tensor_dependency_t;
445
446
/* Record every tensor read by expression `rvalue` as a dependency of block `i`.
 * Each tensor id is added (uniquely) to block i's reads, and i is added to that tensor's reads. Arrays are
 * created on first use. Loop-carried ids and other expression types record nothing.
 */
static void _ccv_nnc_micro_block_dependencies_from_rvalue(const ccv_nnc_micro_loop_expression_t* const rvalue, const int i, ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies)
{
	const int type = rvalue->type;
	if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
	{
		if (rvalue->variable.id.type != CCV_NNC_MICRO_TENSOR_ID)
			return;
		const int tensor = rvalue->variable.id.id;
		ccv_nnc_micro_loop_block_dependency_t* const block = block_dependencies + i;
		ccv_nnc_micro_tensor_dependency_t* const tensor_dependency = tensor_dependencies + tensor;
		if (!block->reads)
			block->reads = ccv_array_new(sizeof(int), 1, 0);
		ccv_array_add_unique_int(block->reads, tensor);
		if (!tensor_dependency->reads)
			tensor_dependency->reads = ccv_array_new(sizeof(int), 1, 0);
		ccv_array_add_unique_int(tensor_dependency->reads, i);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY) {
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.pivot, i, block_dependencies, tensor_dependencies);
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.left, i, block_dependencies, tensor_dependencies);
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.right, i, block_dependencies, tensor_dependencies);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY) {
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.left, i, block_dependencies, tensor_dependencies);
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.right, i, block_dependencies, tensor_dependencies);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY) {
		_ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->unary.x, i, block_dependencies, tensor_dependencies);
	} else if (type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID) {
		assert(rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID);
	}
}
479
480
static void _ccv_nnc_micro_block_dependencies(const ccv_nnc_micro_loop_block_t* const blocks, const int block_size, const int var_count, ccv_nnc_micro_loop_block_dependency_t** const block_dependencies_ref, ccv_nnc_micro_tensor_dependency_t** const tensor_dependencies_ref)
481
6
{
482
6
  ccv_nnc_micro_loop_block_dependency_t* const block_dependencies = (ccv_nnc_micro_loop_block_dependency_t*)cccalloc(block_size, sizeof(ccv_nnc_micro_loop_block_dependency_t));
483
6
  ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies = (ccv_nnc_micro_tensor_dependency_t*)cccalloc(var_count, sizeof(ccv_nnc_micro_tensor_dependency_t));
484
6
  int i, j, k;
485
51
  for (i = 0; i < block_size; 
i++45
)
486
45
  {
487
45
    block_dependencies[i].merge_to = i;
488
45
    const ccv_nnc_micro_loop_t* const loops = blocks[i].loops;
489
45
    const int loop_count = blocks[i].loop_count;
490
286
    for (j = 0; j < loop_count; 
j++241
)
491
241
    {
492
241
      const ccv_nnc_micro_loop_statement_t* const statements = loops[j].statements;
493
241
      const int statement_count = loops[j].statement_count;
494
292
      for (k = 0; k < statement_count; 
k++51
)
495
51
        switch (statements[k].type)
496
51
        {
497
39
          case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: {
498
39
            assert(statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID);
499
39
            const int id = statements[k].assignment.lvalue.id.id;
500
39
            if (!block_dependencies[i].writes)
501
39
              block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0);
502
39
            ccv_array_add_unique_int(block_dependencies[i].writes, id);
503
39
            if (!tensor_dependencies[id].writes)
504
39
              tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0);
505
39
            ccv_array_add_unique_int(tensor_dependencies[id].writes, i);
506
39
            _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].assignment.rvalue, i, block_dependencies, tensor_dependencies);
507
39
            break;
508
39
          }
509
12
          case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: {
510
12
            if (statements[k].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
511
6
            {
512
6
              assert(statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID);
513
6
              const int id = statements[k].compound_assignment.lvalue.id.id;
514
6
              if (!block_dependencies[i].writes)
515
6
                block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0);
516
6
              ccv_array_add_unique_int(block_dependencies[i].writes, id);
517
6
              if (!tensor_dependencies[id].writes)
518
0
                tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0);
519
6
              ccv_array_add_unique_int(tensor_dependencies[id].writes, i);
520
6
              if (!block_dependencies[i].reads)
521
6
                block_dependencies[i].reads = ccv_array_new(sizeof(int), 1, 0);
522
6
              ccv_array_add_unique_int(block_dependencies[i].reads, id);
523
6
              if (!tensor_dependencies[id].reads)
524
6
                tensor_dependencies[id].reads = ccv_array_new(sizeof(int), 1, 0);
525
6
              ccv_array_add_unique_int(tensor_dependencies[id].reads, i);
526
6
            }
527
12
            _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].compound_assignment.rvalue, i, block_dependencies, tensor_dependencies);
528
12
            break;
529
12
          }
530
51
        }
531
241
    }
532
45
  }
533
6
  *block_dependencies_ref = block_dependencies;
534
6
  *tensor_dependencies_ref = tensor_dependencies;
535
6
}
536
537
/* Release the tables produced by _ccv_nnc_micro_block_dependencies: every non-NULL writes / reads array of the
 * `block_size` block records and the `var_count` tensor records, then the two record arrays themselves.
 */
static void _ccv_nnc_micro_dependencies_free(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const int block_size, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int var_count)
{
	int n;
	for (n = 0; n < block_size; n++)
	{
		ccv_array_t* const writes = block_dependencies[n].writes;
		ccv_array_t* const reads = block_dependencies[n].reads;
		if (writes)
			ccv_array_free(writes);
		if (reads)
			ccv_array_free(reads);
	}
	ccfree(block_dependencies);
	for (n = 0; n < var_count; n++)
	{
		ccv_array_t* const writes = tensor_dependencies[n].writes;
		ccv_array_t* const reads = tensor_dependencies[n].reads;
		if (writes)
			ccv_array_free(writes);
		if (reads)
			ccv_array_free(reads);
	}
	ccfree(tensor_dependencies);
}
557
558
static int _ccv_nnc_tensor_reads_in_y_from_writes_after_x(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y)
559
30
{
560
30
  int i, j;
561
30
  int flag = 0;
562
68
  for (i = 0; !flag && 
i < block_dependencies[y].reads->rnum53
;
i++38
)
563
38
  {
564
38
    const int read_idx = *(int*)ccv_array_get(block_dependencies[y].reads, i);
565
38
    if (tensor_dependencies[read_idx].writes)
566
68
      
for (j = 0; 34
!flag &&
j < tensor_dependencies[read_idx].writes->rnum53
;
j++34
)
567
34
      {
568
34
        int block_idx = *(int*)ccv_array_get(tensor_dependencies[read_idx].writes, j);
569
47
        while (block_idx != block_dependencies[block_idx].merge_to)
570
13
          block_idx = block_dependencies[block_idx].merge_to;
571
34
        if (!block_dependencies[block_idx].flag) // Not in use, continue.
572
0
          continue;
573
34
        assert(block_idx <= y);
574
        // If the block_idx is between i and j (and not neither). We cannot merge.
575
34
        if (block_idx > x && 
block_idx != y15
)
576
15
          flag = block_idx;
577
34
      }
578
38
  }
579
30
  return flag;
580
30
}
581
582
static int _ccv_nnc_tensor_writes_in_x_reads_before_y(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y)
583
15
{
584
15
  int i, j;
585
15
  int flag = 0;
586
55
  for (i = 0; !flag && 
i < block_dependencies[x].writes->rnum48
;
i++40
)
587
40
  {
588
40
    const int write_idx = *(int*)ccv_array_get(block_dependencies[x].writes, i);
589
40
    if (tensor_dependencies[write_idx].reads)
590
100
      
for (j = 0; 40
!flag &&
j < tensor_dependencies[write_idx].reads->rnum93
;
j++60
)
591
60
      {
592
60
        int block_idx = *(int*)ccv_array_get(tensor_dependencies[write_idx].reads, j);
593
93
        while (block_idx != block_dependencies[block_idx].merge_to)
594
33
          block_idx = block_dependencies[block_idx].merge_to;
595
60
        if (!block_dependencies[block_idx].flag) // Not in use, continue.
596
14
          continue;
597
46
        assert(block_idx >= x);
598
        // If the block_idx is between i and j (and not neither). We cannot merge.
599
46
        if (block_idx < y && 
block_idx != x35
)
600
7
          flag = block_idx;
601
46
      }
602
40
  }
603
15
  return flag;
604
15
}
605
606
/* Remove every plain assignment to tensor `tensor_idx` from the blocks that write it.
 *
 * tensor_dependency: only ->writes (a list of block indexes into `blocks`) is used; if it
 *   is NULL there is nothing to do.
 * tensor_idx: id of the tensor whose stores are dead.
 * blocks: array of ccv_nnc_micro_loop_block_t, modified in place.
 *
 * Each loop's statement array is compacted in place (dead stores are freed) and shrunk to
 * the surviving statements. A block in which no loop keeps any statement has all its loops
 * freed and loop_count reset to 0.
 */
static void _ccv_nnc_tensor_remove_dead_store(const ccv_nnc_micro_tensor_dependency_t* const tensor_dependency, const int tensor_idx, ccv_array_t* const blocks)
{
  int i, j, k, l;
  if (!tensor_dependency->writes)
    return;
  for (i = 0; i < tensor_dependency->writes->rnum; i++)
  {
    const int write_idx = *(int*)ccv_array_get(tensor_dependency->writes, i);
    ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, write_idx);
    int flag = 0; // Set once any loop of this block still has statements.
    for (j = 0; j < block->loop_count; j++)
    {
      ccv_nnc_micro_loop_statement_t* const statements = block->loops[j].statements;
      // Compact in place: k scans every statement, l is the slot for the next kept one.
      for (k = 0, l = 0; k < block->loops[j].statement_count; k++)
      {
        // It cannot be compound assignment, in this case, this tensor will be in read, and
        // it will be in active use (anything "read" in an active block will be marked as in use).
        assert(!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT &&
          statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
          statements[k].compound_assignment.lvalue.id.id == tensor_idx));
        if (statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT &&
          statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
          statements[k].assignment.lvalue.id.id == tensor_idx)
        {
          // This is a dead store, prepare to remove.
          ccv_nnc_micro_loop_statement_free(&statements[k]);
        } else {
          statements[l] = statements[k];
          ++l;
        }
      }
      if (l < block->loops[j].statement_count)
      {
        if (l > 0)
          block->loops[j].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(block->loops[j].statements, sizeof(ccv_nnc_micro_loop_statement_t) * l);
        else {
          ccfree(block->loops[j].statements);
          block->loops[j].statements = 0;
        }
        // Keep the l surviving statements. Resetting the count to 0 here would drop (and
        // leak) live statements and could cause the whole block to be freed below.
        block->loops[j].statement_count = l;
      }
      if (block->loops[j].statement_count > 0)
        flag = 1;
    }
    if (!flag) // No statement for this block, remove this whole block.
    {
      ccv_nnc_micro_loops_free(block->loops, block->loop_count);
      ccfree(block->loops);
      block->loops = 0;
      block->loop_count = 0;
    }
  }
}
658
659
/* Fuse the loop nests of different blocks when _ccv_nnc_same_loop reports that they match.
 *
 * For every non-empty block i (left) and every later non-empty block j (right):
 * - If j > i + 1 and _ccv_nnc_tensor_reads_in_y_from_writes_after_x finds a block between
 *   them that j depends on, j cannot be moved up into i. In that case i is merged down into
 *   j instead (merge_to_right), unless _ccv_nnc_tensor_writes_in_x_reads_before_y finds an
 *   in-use block between them reading what i writes; then the pair is skipped.
 * - Otherwise both blocks' loops are reordered with _ccv_nnc_loop_order_by. The right
 *   block's carried values are renamed past the left block's. Its carrieds and statements
 *   are appended to the matching left loops, and the right block is emptied.
 * - The absorbed block y's dependency reads / writes are appended to the surviving block x,
 *   and block_dependencies[y].merge_to is set to x.
 *
 * max_loop_count sizes the scratch VLAs passed to _ccv_nnc_same_loop / _ccv_nnc_loop_order_by;
 * presumably it is an upper bound on every block's loop_count (confirm with the caller).
 */
static void _ccv_nnc_loop_merging(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, ccv_array_t* const blocks, const int max_loop_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  int i, j;
  int left_loop_idx[max_loop_count];
  int right_loop_idx[max_loop_count];
  ccv_nnc_micro_loop_t loops[max_loop_count];
  // Merge loops from blocks.
  for (i = 0; i < blocks->rnum - 1; i++)
  {
    ccv_nnc_micro_loop_block_t* const left_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i);
    if (left_block->loop_count == 0) // Empty (already removed or merged away).
      continue;
    for (j = i + 1; j < blocks->rnum; j++)
    {
      // We always merge from right block to left block. Thus, the right block will always be
      // in the original form.
      ccv_nnc_micro_loop_block_t* const right_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j);
      if (right_block->loop_count == 0)
        continue;
      int merge_to_right = 0;
      // First check whether between left and right, there are any blocks that the right block
      // depends on. If that is the case, we cannot merge the right block into the left block.
      if (j > i + 1 && block_dependencies[j].reads)
      {
        const int block_idx = _ccv_nnc_tensor_reads_in_y_from_writes_after_x(block_dependencies, tensor_dependencies, i, j);
        // Cannot merge because we have dependencies in between. Merging will violate that
        // dependency relationship.
        if (block_idx)
        {
          // Now check to see if left can be merged into right? If so, we are lucky.
          if (_ccv_nnc_tensor_writes_in_x_reads_before_y(block_dependencies, tensor_dependencies, i, j))
            continue;
          merge_to_right = 1;
        }
      }
      // This method not only compares whether they have the same loop or not, but also gives indexes that
      // to match the loop start / end index, where they should move to. For example, if:
      // left_loop_idx[2] = 3
      // right_loop_idx[0] = 3
      // That means right now, loop at index 2 on the left is the same as loop at index 0 on the right.
      // And to match exactly, they both need to move to index 3.
      if (_ccv_nnc_same_loop(left_block, right_block, groups, axis_id_groups, left_loop_idx, right_loop_idx))
      {
        // Make sure if we have extra loop, they are on the left.
        // NOTE(review): only the block contents are swapped here; left_loop_idx / right_loop_idx
        // still describe the pre-swap blocks when passed to _ccv_nnc_loop_order_by below. Confirm
        // _ccv_nnc_same_loop never succeeds with the right block having more loops.
        if (right_block->loop_count > left_block->loop_count)
        {
          ccv_nnc_micro_loop_block_t t;
          CCV_SWAP(*left_block, *right_block, t);
        }
        assert(left_block->loop_count == right_block->loop_count || left_block->loop_count == right_block->loop_count + 1);
        _ccv_nnc_loop_order_by(left_block, left_loop_idx, loops);
        _ccv_nnc_loop_order_by(right_block, right_loop_idx, loops);
        // Right loops line up with the innermost (last) left loops.
        const int left_start_idx = left_block->loop_count - right_block->loop_count;
        // Shift the right block's carried ids past the left block's so they do not collide.
        if (left_block->carried_count > 0)
          _ccv_nnc_loop_rename_carrieds(right_block, left_block->carried_count);
        left_block->carried_count += right_block->carried_count;
        int k;
        for (k = 0; k < right_block->loop_count; k++) // Merge loops.
        {
          const int left_idx = left_start_idx + k;
          // Move carrieds: append to the left array, or take over the right array if left has none.
          if (right_block->loops[k].carried_count > 0)
          {
            if (left_block->loops[left_idx].carried_count > 0)
            {
              left_block->loops[left_idx].carrieds = (ccv_nnc_micro_loop_carried_t*)ccrealloc(left_block->loops[left_idx].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * (left_block->loops[left_idx].carried_count + right_block->loops[k].carried_count));
              memcpy(left_block->loops[left_idx].carrieds + left_block->loops[left_idx].carried_count, right_block->loops[k].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * right_block->loops[k].carried_count);
              ccfree(right_block->loops[k].carrieds);
            } else
              left_block->loops[left_idx].carrieds = right_block->loops[k].carrieds;
            left_block->loops[left_idx].carried_count += right_block->loops[k].carried_count;
            right_block->loops[k].carrieds = 0;
            right_block->loops[k].carried_count = 0;
          }
          // Move statements the same way; right statements run after the left ones.
          if (right_block->loops[k].statement_count > 0)
          {
            if (left_block->loops[left_idx].statement_count > 0)
            {
              left_block->loops[left_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(left_block->loops[left_idx].statements, sizeof(ccv_nnc_micro_loop_statement_t) * (left_block->loops[left_idx].statement_count + right_block->loops[k].statement_count));
              memcpy(left_block->loops[left_idx].statements + left_block->loops[left_idx].statement_count, right_block->loops[k].statements, sizeof(ccv_nnc_micro_loop_statement_t) * right_block->loops[k].statement_count);
              ccfree(right_block->loops[k].statements);
            } else
              left_block->loops[left_idx].statements = right_block->loops[k].statements;
            left_block->loops[left_idx].statement_count += right_block->loops[k].statement_count;
            right_block->loops[k].statements = 0;
            right_block->loops[k].statement_count = 0;
          }
        }
        // Once merged, free the loop.
        // NOTE(review): only the loops array is freed; if start_index / end_index of the right
        // loops own heap memory it is not released here. Confirm against ccv_nnc_micro_loops_free.
        ccfree(right_block->loops);
        right_block->loops = 0;
        right_block->loop_count = 0;
        // x is the block that survives, y the one absorbed into it.
        int x = i, y = j;
        if (merge_to_right) // If this is merge to right.
        {
          ccv_nnc_micro_loop_block_t t;
          CCV_SWAP(*left_block, *right_block, t);
          x = j, y = i;
        }
        // Merge all reads and writes tensors into block dependency.
        if (block_dependencies[y].writes && block_dependencies[y].writes->rnum)
        {
          if (!block_dependencies[x].writes)
            block_dependencies[x].writes = ccv_array_new(sizeof(int), 1, 0);
          for (k = 0; k < block_dependencies[y].writes->rnum; k++)
            ccv_array_push(block_dependencies[x].writes, ccv_array_get(block_dependencies[y].writes, k));
        }
        if (block_dependencies[y].reads && block_dependencies[y].reads->rnum)
        {
          if (!block_dependencies[x].reads)
            block_dependencies[x].reads = ccv_array_new(sizeof(int), 1, 0);
          for (k = 0; k < block_dependencies[y].reads->rnum; k++)
            ccv_array_push(block_dependencies[x].reads, ccv_array_get(block_dependencies[y].reads, k));
        }
        // Merged, mark the proper merging dependency.
        block_dependencies[y].merge_to = x;
        if (merge_to_right) // If this is merge to right, now left is empty, break.
          break;
      }
    }
  }
}
780
781
/* Eliminate single-use temporary tensors by substituting their defining expression.
 *
 * For every var i that appears in neither `inputs` nor `outputs`, each loop of each block
 * gets a score:
 * - A plain assignment to i whose rvalue does not reference i adds 1.
 * - An assignment whose rvalue references i sets the score to 2.
 * - A compound assignment to i adds 2.
 * - If the score is still <= 1, it is raised to the maximum of _ccv_nnc_only_var_in_rvalue()
 *   over the loop's statements. Per the in-body comment, 2 means i is read with an index
 *   different from the one it was written with.
 *
 * If the total over all loops is exactly 1 and a substitutable assignment was found, i is
 * written once and only read in that same loop. The assignment is then removed, its rvalue
 * is substituted into the loop's remaining statements by _ccv_nnc_replacing_id_in_rvalue,
 * and vars[i].no_alloc is set.
 */
static void _ccv_nnc_var_subst(ccv_nnc_micro_tensor_t* const vars, const int var_count, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, ccv_array_t* const blocks, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  int i, j;
  // These are simple programs, so we are going to loop over all blocks to see whether a non-output-input
  // var only write / read in one loop. If that is the case, we are going to remove that var.
  // We have to do this replacement from bottom to top though.
  for (i = 0; i < var_count; i++)
  {
    int flag = 0;
    for (j = 0; !flag && j < input_size; j++)
      flag = (inputs[j]->id == i);
    for (j = 0; !flag && j < output_size; j++)
      flag = (outputs[j]->id == i);
    if (flag) // This is in outputs or inputs.
      continue;
    int count_var = 0;
    // Zero-initialized because lvalue is handed to _ccv_nnc_only_var_in_rvalue for every loop,
    // including loops scanned before any assignment to var i has been seen.
    ccv_nnc_micro_loop_variable_t lvalue = {0};
    ccv_nnc_micro_loop_expression_t rvalue = {0};
    // block_idx stays -1 unless a substitutable assignment is recorded.
    int block_idx = -1, loop_idx = -1, statement_idx = -1;
    for (j = 0; j < blocks->rnum; j++)
    {
      const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j);
      int k, l;
      const int loop_count = block->loop_count;
      const ccv_nnc_micro_loop_t* const loops = block->loops;
      int var_per_block = 0;
      for (k = 0; k < loop_count; k++)
      {
        int loop_flag = 0;
        const int statement_count = loops[k].statement_count;
        ccv_nnc_micro_loop_statement_t* const statements = loops[k].statements;
        for (l = 0; l < statement_count; l++)
          if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT &&
            statements[l].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
            statements[l].assignment.lvalue.id.id == i)
          {
            lvalue = statements[l].assignment.lvalue;
            if (_ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups))
              loop_flag = 2;
            else {
              // If the variable not showing up on the right-side, we can continue.
              rvalue = statements[l].assignment.rvalue;
              block_idx = j;
              loop_idx = k;
              statement_idx = l;
              ++loop_flag;
            }
          } else if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT &&
            statements[l].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&
            statements[l].compound_assignment.lvalue.id.id == i) {
            // This is compound assignment, automatically increase by 2.
            loop_flag += 2;
          }
        if (loop_flag > 1) // We have more than 1 assignment for this id, it is not good. We cannot remove it.
        {
          var_per_block += loop_flag;
          continue;
        }
        for (l = 0; l < statement_count; l++)
        {
          // Evaluate once (ccv_max is a macro and may evaluate its arguments twice).
          const int use = _ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups);
          if (use > loop_flag)
            loop_flag = use;
        }
        // If loop_flag == 2, meaning it found a var with a different index. This is a bad news.
        var_per_block += loop_flag;
      }
      count_var += var_per_block;
    }
    // If this is used more than one place (write multiple times, have different index, or used in different blocks),
    // I cannot get rid of it. Also bail out when no substitutable assignment was recorded (e.g. the var
    // is only read): block_idx / loop_idx / statement_idx / rvalue would be meaningless then.
    if (count_var != 1 || block_idx < 0)
      continue;
    // Otherwise, now loop again and prepare to get rid of it.
    ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, block_idx);
    ccv_nnc_micro_loop_statement_t* statements = block->loops[loop_idx].statements;
    ccv_nnc_micro_loop_statement_t statement = statements[statement_idx];
    // First, remove the assignment.
    if (statement_idx < block->loops[loop_idx].statement_count - 1)
      memmove(statements + statement_idx, statements + statement_idx + 1, sizeof(ccv_nnc_micro_loop_statement_t) * (block->loops[loop_idx].statement_count - statement_idx - 1));
    --block->loops[loop_idx].statement_count;
    const int statement_count = block->loops[loop_idx].statement_count;
    if (statement_count > 0)
      statements = block->loops[loop_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccrealloc(statements, sizeof(ccv_nnc_micro_loop_statement_t) * statement_count);
    else {
      // realloc() with size 0 is implementation-defined (it may or may not free and return NULL);
      // release the array explicitly so the loop consistently ends up with statements == 0.
      ccfree(statements);
      statements = block->loops[loop_idx].statements = 0;
    }
    // k is passed to _ccv_nnc_replacing_id_in_rvalue, which presumably counts substitutions made.
    int k = 0;
    for (j = 0; j < statement_count; j++)
      _ccv_nnc_replacing_id_in_rvalue(&statements[j], i, rvalue, &k);
    if (k == 0) // If nothing to replace, free up everything.
      ccv_nnc_micro_loop_statement_free(&statement);
    else // The rvalue now lives on inside the rewritten statements; only free the lvalue.
      ccv_nnc_micro_loop_statement_lvalue_free(&statement);
    // No need to allocate for this var. It is not used, only useful for shape computation.
    vars[i].no_alloc = 1;
  }
}
871
872
static int _ccv_nnc_index_binary_size(const ccv_nnc_micro_loop_index_term_t index)
873
448
{
874
448
  switch (index.type)
875
448
  {
876
0
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE:
877
0
      return 0;
878
48
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL:
879
352
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID:
880
352
      return 1;
881
96
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY:
882
96
      if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || 
index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS48
)
883
96
        return _ccv_nnc_index_binary_size(index.binary->left) + _ccv_nnc_index_binary_size(index.binary->right);
884
0
      else
885
0
        return 1;
886
448
  }
887
0
  return 0;
888
448
}
889
890
/* One addend produced by flattening a tree of + / - index expressions
 * (filled in by _ccv_nnc_index_term_flatten). */
typedef struct {
  int sign:7; // CCV_NNC_MICRO_BINARY_OP_PLUS or CCV_NNC_MICRO_BINARY_OP_MINUS: the sign this addend carries.
  int ignore:1; // Nonzero once the addend has been cancelled out. Code stores -1 here; because the
                // signedness of a plain int bit-field is implementation-defined, only test it for nonzero.
  ccv_nnc_micro_loop_index_term_t term; // The addend: an immediate, an id, or a non-additive binary node.
} ccv_nnc_micro_loop_binary_term_t;
895
896
/* Flatten a tree of + / - index expressions into a list of signed addends.
 *
 * binary_terms: output array, sized via _ccv_nnc_index_binary_size.
 * index: expression to flatten.
 * sign: sign applied to the whole of `index` (CCV_NNC_MICRO_BINARY_OP_PLUS or _MINUS).
 * i: in/out cursor into binary_terms, advanced once per appended addend.
 *
 * The right operand of a "-" node has its sign flipped. Immediate values, ids and
 * binary nodes with other operators are appended as-is with ignore cleared. NONE terms
 * are skipped.
 */
static void _ccv_nnc_index_term_flatten(ccv_nnc_micro_loop_binary_term_t* const binary_terms, const ccv_nnc_micro_loop_index_term_t index, const int sign, int* const i)
{
  if (index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY &&
    (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS))
  {
    const int flipped = (sign == CCV_NNC_MICRO_BINARY_OP_PLUS) ? CCV_NNC_MICRO_BINARY_OP_MINUS : CCV_NNC_MICRO_BINARY_OP_PLUS;
    const int right_sign = (index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS) ? flipped : sign;
    _ccv_nnc_index_term_flatten(binary_terms, index.binary->left, sign, i);
    _ccv_nnc_index_term_flatten(binary_terms, index.binary->right, right_sign, i);
    return;
  }
  const int is_leaf = index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL ||
    index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID ||
    index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY;
  if (!is_leaf) // NONE occupies no slot.
    return;
  ccv_nnc_micro_loop_binary_term_t* const slot = binary_terms + *i;
  slot->term = index;
  slot->sign = sign;
  slot->ignore = 0;
  *i += 1;
}
926
927
// 0 is we don't understand, -1 is false, 1 is true.
928
static int _ccv_nnc_index_less_than_or_equal_to(const ccv_nnc_micro_loop_index_term_t left, const ccv_nnc_micro_loop_index_term_t right, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
929
256
{
930
  // Special case 1.
931
256
  if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && 
right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL128
)
932
128
    return left.immediate_value <= right.immediate_value ? 1 : 
-10
;
933
  // Special case 2.
934
128
  if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && 
left.immediate_value == 00
&&
right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID0
&&
right.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID0
)
935
0
    return 1;
936
  // Special case 3.
937
128
  if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && 
left.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID104
&&
right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL92
&&
right.immediate_value == 00
)
938
0
    return -1;
939
  // Now, we only have one variable in both left and right, need to flat the binary tree (if possible) and reduce it to constant if possible.
940
  // We can only flatten if it is + / - at the moment.
941
128
  const int left_binary_size = _ccv_nnc_index_binary_size(left);
942
128
  assert(left_binary_size >= 1);
943
128
  const int right_binary_size = _ccv_nnc_index_binary_size(right);
944
128
  assert(right_binary_size >= 1);
945
128
  ccv_nnc_micro_loop_binary_term_t* const left_binary_terms = (ccv_nnc_micro_loop_binary_term_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_binary_term_t) * (left_binary_size + right_binary_size));
946
128
  ccv_nnc_micro_loop_binary_term_t* const right_binary_terms = left_binary_terms + left_binary_size;
947
128
  int i, j;
948
128
  i = 0;
949
128
  _ccv_nnc_index_term_flatten(left_binary_terms, left, CCV_NNC_MICRO_BINARY_OP_PLUS, &i);
950
128
  assert(i == left_binary_size);
951
128
  i = 0;
952
128
  _ccv_nnc_index_term_flatten(right_binary_terms, right, CCV_NNC_MICRO_BINARY_OP_PLUS, &i);
953
128
  assert(i == right_binary_size);
954
  // Matching signs in left terms.
955
200
  
for (i = 0; 128
i < left_binary_size - 1;
i++72
)
956
228
    
for (j = i + 1; 72
j < left_binary_size;
j++156
)
957
156
      if (!left_binary_terms[i].ignore && 
!left_binary_terms[j].ignore132
&&
958
156
        
_ccv_nnc_same_index_term(left_binary_terms[i].term, left_binary_terms[j].term, groups, axis_id_groups)120
&&
959
156
        
left_binary_terms[i].sign != left_binary_terms[j].sign24
)
960
24
      {
961
24
        left_binary_terms[i].ignore = -1;
962
24
        left_binary_terms[j].ignore = -1;
963
24
      }
964
  // Matching signs in right terms.
965
152
  for (i = 0; i < right_binary_size - 1; 
i++24
)
966
60
    
for (j = i + 1; 24
j < right_binary_size;
j++36
)
967
36
      if (!right_binary_terms[i].ignore && !right_binary_terms[j].ignore &&
968
36
        _ccv_nnc_same_index_term(right_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups) &&
969
36
        
right_binary_terms[i].sign != right_binary_terms[j].sign0
)
970
0
      {
971
0
        right_binary_terms[i].ignore = -1;
972
0
        right_binary_terms[j].ignore = -1;
973
0
      }
974
  // Matching left to right.
975
328
  for (i = 0; i < left_binary_size; 
i++200
)
976
472
    
for (j = 0; 200
j < right_binary_size;
j++272
)
977
      // If they are the same, we can ignore now.
978
272
      if (!left_binary_terms[i].ignore && 
!right_binary_terms[j].ignore188
&&
979
272
        
_ccv_nnc_same_index_term(left_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups)152
&&
980
272
        
left_binary_terms[i].sign == right_binary_terms[j].sign140
)
981
140
      {
982
140
        left_binary_terms[i].ignore = -1;
983
140
        right_binary_terms[j].ignore = -1;
984
140
      }
985
  // After reduced, we should only have immediate values left, otherwise we cannot progress.
986
128
  int left_val = 0;
987
316
  for (i = 0; i < left_binary_size; 
i++188
)
988
200
    if (!left_binary_terms[i].ignore)
989
12
    {
990
12
      if (left_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL)
991
12
      {
992
12
        ccfree(left_binary_terms);
993
12
        return 0;
994
12
      } else
995
0
        left_val += left_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_binary_terms[i].term.immediate_value : -left_binary_terms[i].term.immediate_value;
996
12
    }
997
116
  int right_val = 0;
998
256
  for (i = 0; i < right_binary_size; 
i++140
)
999
140
    if (!right_binary_terms[i].ignore)
1000
0
    {
1001
0
      if (right_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL)
1002
0
      {
1003
0
        ccfree(left_binary_terms);
1004
0
        return 0;
1005
0
      } else
1006
0
        right_val += right_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? right_binary_terms[i].term.immediate_value : -right_binary_terms[i].term.immediate_value;
1007
0
    }
1008
116
  ccfree(left_binary_terms);
1009
116
  return left_val <= right_val ? 1 : 
-10
;
1010
116
}
1011
1012
// If this index term refers to an axis size that actually has a expression, refer to that instead (like for reindex operation).
1013
static ccv_nnc_micro_loop_index_term_t _ccv_nnc_micro_index_shape_merging(const ccv_nnc_micro_loop_index_term_t index, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
1014
429
{
1015
429
  ccv_nnc_micro_loop_index_term_t result = index;
1016
429
  for (;;)
1017
648
  {
1018
648
    if (!(result.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && 
result.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID451
))
1019
218
      return result;
1020
430
    int root = groups[result.id.id];
1021
430
    while (groups[root] != root)
1022
0
      root = groups[root];
1023
430
    if (vars[root].shape == 0)
1024
211
      return result;
1025
219
    assert(result.id.d >= 0 && result.id.d < vars[root].dimensions);
1026
219
    result = vars[root].shape[result.id.d];
1027
219
  }
1028
429
}
1029
1030
/* Derive a symbolic lower bound (*low_ref) and upper bound (*high_ref) for `index` within
 * the given loop nest.
 *
 * Returns 1 on success. Returns 0 when no bound can be derived; *low_ref / *high_ref are
 * then left unset.
 * - NONE: both bounds are the immediate 0.
 * - Loop id: deep copies of that loop's start_index / end_index, after axis sizes are
 *   resolved through _ccv_nnc_micro_index_shape_merging.
 * - Other ids and immediate values: both bounds are `index` itself.
 * - Binary: bounds of both operands are combined.
 *   - If either operand is a single value (its low equals its high), the operator is applied
 *     to the two lows and to the two highs.
 *   - Otherwise only + and * whose operands have non-negative immediate lower bounds are
 *     handled. The upper bound is treated as not inclusive, so it becomes
 *     (left + right) - 1 for + and (left - 1) * (right - 1) + 1 for *.
 * Every node of the returned trees is owned exactly once, so each result can be released
 * with ccv_nnc_micro_loop_index_free.
 */
static int _ccv_nnc_micro_low_high_bound_from_index(const ccv_nnc_micro_loop_index_term_t index, ccv_nnc_micro_loop_index_term_t* const low_ref, ccv_nnc_micro_loop_index_term_t* const high_ref, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  switch (index.type)
  {
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE:
      *low_ref = (ccv_nnc_micro_loop_index_term_t){
        .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
        .immediate_value = 0
      };
      *high_ref = (ccv_nnc_micro_loop_index_term_t){
        .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
        .immediate_value = 0
      };
      return 1;
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID:
      if (index.id.type == CCV_NNC_MICRO_LOOP_ID)
      {
        int loop_idx = -1;
        int i;
        for (i = 0; loop_idx < 0 && i < loop_count; i++)
          if (loops[i].id.id == index.id.id)
            loop_idx = i;
        assert(loop_idx >= 0);
        const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups);
        const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups);
        // Shape merging may return terms shared with vars[].shape; copy so the caller owns them.
        *low_ref = ccv_nnc_micro_loop_index_deep_copy(&start_index);
        *high_ref = ccv_nnc_micro_loop_index_deep_copy(&end_index);
      } else {
        *low_ref = index;
        *high_ref = index;
      }
      return 1;
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL:
      *low_ref = index;
      *high_ref = index;
      return 1;
    case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: {
      // Get low, high from both left and right, and then construct new low / high.
      ccv_nnc_micro_loop_index_term_t left_low, left_high;
      if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->left, &left_low, &left_high, loops, loop_count, vars, var_count, groups, axis_id_groups))
        return 0;
      ccv_nnc_micro_loop_index_term_t right_low, right_high;
      if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->right, &right_low, &right_high, loops, loop_count, vars, var_count, groups, axis_id_groups))
      {
        ccv_nnc_micro_loop_index_free(&left_low);
        ccv_nnc_micro_loop_index_free(&left_high);
        return 0;
      }
      // If left is not a range, or right is not a range, it is simple, just copy over.
      if (_ccv_nnc_same_index_term(left_low, left_high, groups, axis_id_groups) || _ccv_nnc_same_index_term(right_low, right_high, groups, axis_id_groups))
      {
        *low_ref = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        low_ref->binary->op = index.binary->op;
        low_ref->binary->left = left_low;
        low_ref->binary->right = right_low;
        *high_ref = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        high_ref->binary->op = index.binary->op;
        high_ref->binary->left = left_high;
        high_ref->binary->right = right_high;
        return 1;
      }
      // Cannot handle -, because lower bound will go to negative, similar for /. Only can handle + and *.
      if (!(index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MUL) ||
        // If lower bound is not a non-negative integer, we cannot compute interesting low / high bound, abort.
        (left_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || left_low.immediate_value < 0) ||
        (right_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || right_low.immediate_value < 0))
      {
        ccv_nnc_micro_loop_index_free(&left_low);
        ccv_nnc_micro_loop_index_free(&left_high);
        ccv_nnc_micro_loop_index_free(&right_low);
        ccv_nnc_micro_loop_index_free(&right_high);
        return 0;
      }
      *low_ref = (ccv_nnc_micro_loop_index_term_t){
        .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
        .immediate_value = index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_low.immediate_value + right_low.immediate_value : left_low.immediate_value * right_low.immediate_value,
      };
      // higher bound is not inclusive, hence, we need to minus extra 1 for this.
      if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS)
      {
        // (left - 1) + (right - 1) + 1
        ccv_nnc_micro_loop_index_term_t sum = {
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        sum.binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS;
        sum.binary->left = left_high;
        sum.binary->right = right_high;
        *high_ref = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS;
        high_ref->binary->left = sum;
        high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
          .immediate_value = 1
        };
      } else {
        // (left - 1) * (right - 1) + 1, built as left * right - left - right + 2.
        // left_high / right_high each appear twice in that tree. The second use gets a deep
        // copy so no heap node is shared; otherwise freeing *high_ref would double-free it
        // whenever an operand's bound is itself a binary expression.
        ccv_nnc_micro_loop_index_term_t prod = {
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        prod.binary->op = CCV_NNC_MICRO_BINARY_OP_MUL;
        prod.binary->left = left_high;
        prod.binary->right = right_high;
        ccv_nnc_micro_loop_index_term_t minus_left = {
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        minus_left.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS;
        minus_left.binary->left = prod;
        minus_left.binary->right = ccv_nnc_micro_loop_index_deep_copy(&left_high);
        ccv_nnc_micro_loop_index_term_t minus_right = {
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        minus_right.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS;
        minus_right.binary->left = minus_left;
        minus_right.binary->right = ccv_nnc_micro_loop_index_deep_copy(&right_high);
        *high_ref = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY,
          .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t))
        };
        high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS;
        high_ref->binary->left = minus_right;
        high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){
          .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
          .immediate_value = 2
        };
      }
      return 1;
    }
  }
  return 0;
}
1173
1174
// Decide, for each index dimension of a tensor variable, whether the generated code may skip the
// bound check (variable->no_check_bound[i] = 1) or must keep it (variable->no_check_bound[i] = 0).
//
// For dimension i the axis size is the term returned by _ccv_nnc_micro_index_shape_merging for
// AXIS_SIZE(variable->id.id, i). The check is skipped only when _ccv_nnc_index_less_than_or_equal_to
// returns 1 for both comparisons below:
//   - loop-id index:   0 <= loop start_index   and   loop end_index <= axis size;
//   - binary index:    0 <= low                and   high <= axis size, where low / high come from
//                      _ccv_nnc_micro_low_high_bound_from_index (the check is kept if that fails);
//   - immediate value: only the literal 0 skips the check;
//   - any other index kind (including a non-loop id) keeps the check.
// Variables whose id is not a tensor id are left untouched. Every loop-id index must name one of
// loops[0 .. loop_count) (asserted).
static void _ccv_nnc_micro_check_bound_for_variable(ccv_nnc_micro_loop_variable_t* const variable, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  if (variable->id.type != CCV_NNC_MICRO_TENSOR_ID)
    return;
  int i, j;
  assert(variable->id.id >= 0 && variable->id.id < var_count);
  ccv_nnc_micro_loop_index_term_t index_zero = {
    .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL,
    .immediate_value = 0
  };
  for (i = 0; i < variable->index_count; i++)
  {
    // Size of axis i of the referenced tensor, after applying the shape / axis groups.
    const ccv_nnc_micro_loop_index_term_t shape = _ccv_nnc_micro_index_shape_merging((ccv_nnc_micro_loop_index_term_t){
      .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID,
      .id = {
        .type = CCV_NNC_MICRO_AXIS_SIZE_ID,
        .id = variable->id.id,
        .d = i
      }
    }, vars, var_count, groups, axis_id_groups);
    switch (variable->index[i].type)
    {
      case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID:
        // For loop id, we can check the range to see if it is within the shape.
        if (variable->index[i].id.type == CCV_NNC_MICRO_LOOP_ID)
        {
          int loop_idx = -1;
          for (j = 0; loop_idx < 0 && j < loop_count; j++)
            if (loops[j].id.id == variable->index[i].id.id)
              loop_idx = j;
          assert(loop_idx >= 0);
          const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups);
          const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups);
          if (_ccv_nnc_index_less_than_or_equal_to(index_zero, start_index, vars, var_count, groups, axis_id_groups) == 1 &&
            _ccv_nnc_index_less_than_or_equal_to(end_index, shape, vars, var_count, groups, axis_id_groups) == 1)
            variable->no_check_bound[i] = 1;
          else
            variable->no_check_bound[i] = 0;
        } else // If it is anything other than loop id, we have to check the bound.
          variable->no_check_bound[i] = 0;
        break;
      case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: {
        // Compute higher / lower bounds along the expression.
        ccv_nnc_micro_loop_index_term_t low, high;
        // Cannot find high low, mark no_check_bound[i] = 0
        if (!_ccv_nnc_micro_low_high_bound_from_index(variable->index[i], &low, &high, loops, loop_count, vars, var_count, groups, axis_id_groups))
        {
          variable->no_check_bound[i] = 0;
          break;
        }
        if (_ccv_nnc_index_less_than_or_equal_to(index_zero, low, vars, var_count, groups, axis_id_groups) == 1 &&
          _ccv_nnc_index_less_than_or_equal_to(high, shape, vars, var_count, groups, axis_id_groups) == 1)
          variable->no_check_bound[i] = 1;
        else
          variable->no_check_bound[i] = 0;
        // low / high are owned by this function once the bound helper succeeds.
        ccv_nnc_micro_loop_index_free(&low);
        ccv_nnc_micro_loop_index_free(&high);
        break;
      }
      case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL:
        // If the index is an integer, and it is bigger than 0, we need to check bound (there is no assertion the end index is larger than anything other than 0).
        if (variable->index[i].immediate_value == 0)
          variable->no_check_bound[i] = 1;
        else
          variable->no_check_bound[i] = 0;
        break;
      default:
        // Any other index kind cannot be proven in range here. Keep the bound check explicitly
        // instead of leaving whatever value no_check_bound[i] held before.
        variable->no_check_bound[i] = 0;
        break;
    }
  }
}
1243
1244
// Recursively visit an expression tree and run the bound-check analysis on every tensor variable
// it references. A VAR node is analyzed directly. Ternary nodes recurse into pivot, left and right
// (in that order), binary nodes into left then right, and unary nodes into x. Every other
// expression kind references no variable and is ignored.
static void _ccv_nnc_micro_check_bound_for_expression(ccv_nnc_micro_loop_expression_t* const expression, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  const int kind = expression->type;
  if (kind == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
  {
    _ccv_nnc_micro_check_bound_for_variable(&expression->variable, loops, loop_count, vars, var_count, groups, axis_id_groups);
    return;
  }
  // Collect child expressions in visiting order, then recurse over them uniformly.
  ccv_nnc_micro_loop_expression_t* children[3];
  int child_count = 0;
  if (kind == CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY)
  {
    children[child_count++] = expression->ternary.pivot;
    children[child_count++] = expression->ternary.left;
    children[child_count++] = expression->ternary.right;
  } else if (kind == CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY) {
    children[child_count++] = expression->binary.left;
    children[child_count++] = expression->binary.right;
  } else if (kind == CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY) {
    children[child_count++] = expression->unary.x;
  }
  int c;
  for (c = 0; c < child_count; c++)
    _ccv_nnc_micro_check_bound_for_expression(children[c], loops, loop_count, vars, var_count, groups, axis_id_groups);
}
1265
1266
// Run the bound-check analysis over every statement of every loop in a block. Each statement is
// analyzed against the block's complete loop list (block->loops / block->loop_count):
//   - assignment: the lvalue variable, then the rvalue expression;
//   - compound assignment: the lvalue (only when it is a plain variable), then the rvalue.
// Other statement kinds are left alone.
static void _ccv_nnc_micro_check_bound_for_block(ccv_nnc_micro_loop_block_t* const block, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)* const axis_id_groups)
{
  const ccv_nnc_micro_loop_t* const block_loops = block->loops;
  const int block_loop_count = block->loop_count;
  int l, s;
  for (l = 0; l < block_loop_count; l++)
  {
    ccv_nnc_micro_loop_statement_t* const body = block->loops[l].statements;
    const int body_size = block->loops[l].statement_count;
    for (s = 0; s < body_size; s++)
    {
      ccv_nnc_micro_loop_statement_t* const stmt = &body[s];
      if (stmt->type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT)
      {
        _ccv_nnc_micro_check_bound_for_variable(&stmt->assignment.lvalue, block_loops, block_loop_count, vars, var_count, groups, axis_id_groups);
        _ccv_nnc_micro_check_bound_for_expression(&stmt->assignment.rvalue, block_loops, block_loop_count, vars, var_count, groups, axis_id_groups);
      } else if (stmt->type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT) {
        // A compound lvalue is an expression; only a plain variable has indices to check.
        if (stmt->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR)
          _ccv_nnc_micro_check_bound_for_variable(&stmt->compound_assignment.lvalue.variable, block_loops, block_loop_count, vars, var_count, groups, axis_id_groups);
        _ccv_nnc_micro_check_bound_for_expression(&stmt->compound_assignment.rvalue, block_loops, block_loop_count, vars, var_count, groups, axis_id_groups);
      }
    }
  }
}
1290
1291
// Simplify a micro program in place.
//
//  1. Build a union-find over tensors (`groups`) so that tensors known to share a shape point to
//     a common root, and record the axis equalities from `equal_assertions` as chains in
//     `axis_id_groups` (keys / values are MICRO_ID_TO_INT of group-resolved axis ids).
//  2. Flatten every function's blocks into one list and compute block <-> tensor dependencies.
//     Starting from the outputs, mark every block / tensor they transitively depend on (inputs
//     are pre-marked so the walk stops there). Loops of unmarked blocks are freed, and unmarked
//     tensors go through _ccv_nnc_tensor_remove_dead_store and get no_alloc = 1.
//  3. Run _ccv_nnc_loop_merging, drop blocks left with no loops, run _ccv_nnc_var_subst, then
//     decide per variable index whether bound checks can be skipped.
//  4. Collapse the program into a single function that owns the surviving blocks.
//
// Returns without changes when there are no functions, only one function with one block, or no
// inputs / outputs. equal_assertions holds ccv_nnc_micro_id_equal_assertion_t entries; NULL is
// treated as an empty list.
void ccv_nnc_micro_program_simplify(ccv_nnc_micro_program_t* const program, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_array_t* const equal_assertions)
{
  // Nothing to simplify for.
  if (program->function_count < 1)
    return;
  // Only one block, nothing to simplify for.
  if (program->function_count == 1 && program->functions[0].block_count == 1)
    return;
  if (input_size == 0 || output_size == 0)
    return;
  // Union-find to group all variables with the same shape.
  ccv_nnc_micro_tensor_t* const vars = program->vars;
  const int var_count = program->var_count;
  int* const groups = (int*)ccmalloc(sizeof(int) * var_count);
  int i, j;
  for (i = 0; i < var_count; i++)
    groups[i] = i;
  // If no shape, they should match these input.
  for (i = 0; i < var_count; i++)
    if (vars[i].input >= 0 && !vars[i].shape)
    {
      int root = vars[i].input;
      while (groups[root] != root)
        root = groups[root];
      groups[i] = root;
    }
  for (i = 0; i < var_count; i++)
  {
    // If this is input (no other tensor as the input), we skip.
    if (vars[i].input < 0)
      continue;
    int root = i;
    while (groups[root] != root)
      root = groups[root];
    // If the sibling exists and we haven't visited yet, mark them has the same group as us.
    // NOTE(review): groups[] only ever holds valid indices (>= 0), so `groups[...] < 0` is never
    // true and siblings are never grouped here. Left unchanged pending confirmation of the intent.
    if (vars[i].sibling >= 0 && vars[i].sibling < i && groups[vars[i].sibling] < 0)
      groups[vars[i].sibling] = root;
  }
  for (i = var_count - 1; i > 0; i--)
  {
    // Now matching the shape.
    if (vars[i].input < 0 || !vars[i].shape)
      continue;
    int root = i;
    while (groups[root] != root)
      root = groups[root];
    for (j = i - 1; j >= 0; j--)
      if (vars[j].shape && vars[j].dimensions == vars[i].dimensions &&
        _ccv_nnc_same_shape(vars[j].shape, vars[i].shape, vars[i].dimensions))
        groups[j] = root;
  }
  // Group equal assertions on axis together. Both sides are resolved the same way
  // _ccv_nnc_same_index_term resolves ids at lookup time: first to their tensor group's root,
  // then along the chain already recorded in axis_id_groups.
  khash_t(ccv_nnc_axis_id_group)* const axis_id_groups = kh_init(ccv_nnc_axis_id_group);
  const int assertion_count = equal_assertions ? equal_assertions->rnum : 0;
  for (i = 0; i < assertion_count; i++)
  {
    const ccv_nnc_micro_id_equal_assertion_t* const equal_assertion = (ccv_nnc_micro_id_equal_assertion_t*)ccv_array_get(equal_assertions, i);
    ccv_nnc_micro_id_t left = equal_assertion->left;
    while (groups[left.id] != left.id)
      left.id = groups[left.id];
    int left_root = MICRO_ID_TO_INT(left);
    khiter_t k;
    for (;;) {
      k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, left_root);
      if (k == kh_end(axis_id_groups))
        break;
      left_root = kh_val(axis_id_groups, k);
    }
    ccv_nnc_micro_id_t right = equal_assertion->right;
    // Advance right (not left) along its own group chain, otherwise the loop never terminates.
    while (groups[right.id] != right.id)
      right.id = groups[right.id];
    // Use the group-resolved id so the key matches what lookups compute.
    int right_root = MICRO_ID_TO_INT(right);
    for (;;) {
      k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, right_root);
      if (k == kh_end(axis_id_groups))
        break;
      right_root = kh_val(axis_id_groups, k);
    }
    // Neither root is a key in axis_id_groups at this point (each walk stopped at kh_end), so
    // linking right_root -> left_root cannot create a cycle.
    if (left_root != right_root)
    {
      int ret;
      k = kh_put(ccv_nnc_axis_id_group, axis_id_groups, right_root, &ret);
      assert(ret != 0);
      kh_val(axis_id_groups, k) = left_root;
    }
  }
  // First, flat out all functions into blocks.
  ccv_array_t* const blocks = ccv_array_new(sizeof(ccv_nnc_micro_loop_block_t), 0, 0);
  ccv_nnc_micro_function_t* const functions = program->functions;
  const int function_count = program->function_count;
  int max_loop_count = 0;
  for (i = 0; i < function_count; i++)
  {
    const int block_count = functions[i].block_count;
    ccv_nnc_micro_loop_block_t* const function_blocks = block_count == 1 ? &functions[i].one_block : functions[i].blocks;
    for (j = 0; j < block_count; j++)
    {
      max_loop_count = ccv_max(function_blocks[j].loop_count, max_loop_count);
      ccv_array_push(blocks, &function_blocks[j]);
    }
  }
  // Next, find dependencies between these function blocks and marking these that are dependencies for the final outputs.
  // We need to build our connections between blocks <-> r/w vars.
  ccv_nnc_micro_loop_block_dependency_t* block_dependencies;
  ccv_nnc_micro_tensor_dependency_t* tensor_dependencies;
  const int block_size = blocks->rnum;
  _ccv_nnc_micro_block_dependencies((ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0), block_size, var_count, &block_dependencies, &tensor_dependencies);
  ccv_array_t* const in_use = ccv_array_new(sizeof(int), output_size, 0);
  // Use the dependencies to mark blocks / vars that are in use.
  for (i = 0; i < output_size; i++)
  {
    tensor_dependencies[outputs[i]->id].flag = 1; // Mark them as in use.
    ccv_array_push(in_use, &outputs[i]->id);
  }
  for (i = 0; i < input_size; i++)
    tensor_dependencies[inputs[i]->id].flag = 1; // Mark inputs as in use so we don't go pass them.
  // in_use doubles as a work queue: newly reached tensors are appended while iterating.
  for (i = 0; i < in_use->rnum; i++)
  {
    const int tensor_idx = *(int*)ccv_array_get(in_use, i);
    if (tensor_dependencies[tensor_idx].writes)
      for (j = 0; j < tensor_dependencies[tensor_idx].writes->rnum; j++)
      {
        const int block_idx = *(int*)ccv_array_get(tensor_dependencies[tensor_idx].writes, j);
        block_dependencies[block_idx].flag = 1;
        int k;
        if (block_dependencies[block_idx].reads)
          for (k = 0; k < block_dependencies[block_idx].reads->rnum; k++)
          {
            const int read_idx = *(int*)ccv_array_get(block_dependencies[block_idx].reads, k);
            if (!tensor_dependencies[read_idx].flag)
            {
              tensor_dependencies[read_idx].flag = 1;
              ccv_array_push(in_use, &read_idx);
            }
          }
      }
  }
  ccv_array_free(in_use);
  for (i = 0; i < block_size; i++)
    if (!block_dependencies[i].flag)
    {
      ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i);
      ccv_nnc_micro_loops_free(block->loops, block->loop_count);
      ccfree(block->loops);
      block->loops = 0;
      block->loop_count = 0;
    }
  for (i = 0; i < var_count; i++)
    if (!tensor_dependencies[i].flag) // If this tensor is not visited, there is no need to alloc.
    {
      _ccv_nnc_tensor_remove_dead_store(&tensor_dependencies[i], i, blocks);
      vars[i].no_alloc = 1;
    }
  _ccv_nnc_loop_merging(block_dependencies, tensor_dependencies, blocks, max_loop_count, groups, axis_id_groups);
  _ccv_nnc_micro_dependencies_free(block_dependencies, block_size, tensor_dependencies, var_count);
  // Culling out empty blocks.
  for (i = 0, j = 0; i < blocks->rnum; i++)
  {
    const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i);
    if (block->loop_count > 0)
    {
      *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j) = *block;
      ++j;
    }
  }
  // Now we moved everything, set the proper block size.
  ccv_array_resize(blocks, j);
  // Substitute variables.
  _ccv_nnc_var_subst(vars, var_count, inputs, input_size, outputs, output_size, blocks, groups, axis_id_groups);
  // Mark whether we need to check bound for a particular variable or not.
  for (i = 0; i < blocks->rnum; i++)
  {
    ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i);
    _ccv_nnc_micro_check_bound_for_block(block, vars, var_count, groups, axis_id_groups);
  }
  // groups came from ccmalloc, so release it with the matching ccfree.
  ccfree(groups);
  kh_destroy(ccv_nnc_axis_id_group, axis_id_groups);
  // Reallocate function to be 1.
  for (i = 0; i < function_count; i++)
    if (functions[i].block_count > 1)
      ccfree(functions[i].blocks);
  program->functions = (ccv_nnc_micro_function_t*)ccrealloc(program->functions, sizeof(ccv_nnc_micro_function_t));
  program->functions[0].block_count = blocks->rnum;
  if (blocks->rnum > 1)
  {
    program->functions[0].blocks = (ccv_nnc_micro_loop_block_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum);
    memcpy(program->functions[0].blocks, ccv_array_get(blocks, 0), sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum);
  } else
    program->functions[0].one_block = *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0);
  program->function_count = 1;
  ccv_array_free(blocks);
}
}