Coverage Report

Created: 2025-02-24 17:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph_simplify.c
Line
Count
Source
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_nnc_symbolic_graph.h"
6
#include "3rdparty/siphash/siphash24.h"
7
8
// MARK - Level-3.5 API
9
10
// 16-byte SipHash key used for CSE hashing. The string literal is exactly 16
// characters, so it fills the array with no NUL terminator — legal in C
// (would be an error in C++); the key is used as raw bytes, not as a C string.
static uint8_t key_siphash[16] = "graphcsekvlibnnc";
11
12
// Working state for one simplification run over a symbolic graph: cached
// symbol info snapshots, a topological visit, and per-symbol "dead" bitmaps
// that the individual passes mark and _ccv_nnc_symbolic_graph_simplify_apply
// later commits back onto the graph.
typedef struct {
	int tensor_symbol_info_size; // Number of tensor symbols (== graph->tensor_symbol_info->rnum at creation).
	int exec_symbol_info_size; // Number of exec symbols (== graph->exec_symbol_info->rnum at creation).
	ccv_nnc_symbolic_graph_t* graph; // The graph being simplified (not owned).
	ccv_nnc_graph_visit_t* visit; // Topologically-ordered visit from sources to destinations.
	ccv_nnc_tensor_symbol_info_t* tensor_symbol_info; // Inferred per-tensor info snapshot (owned).
	ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info; // Inferred per-exec info snapshot (owned).
	uint32_t* exec_dead; // Mark a exec is dead and need to be cleared, each bit represent a exec.
	uint32_t* tensor_dead; // Mark a tensor is dead and need to be cleared, each bit represent a tensor. Points into the exec_dead allocation.
	int* output_execs; // Mapping from tensor to the exec that generates this tensor.
} ccv_nnc_symbolic_graph_simplify_t;
23
24
// Rebuild output_execs: for each tensor symbol, record the index of the exec
// that writes it (or -1 if none), ignoring execs already marked dead.
static void _ccv_nnc_symbolic_graph_simplify_update_output_execs(ccv_nnc_symbolic_graph_simplify_t* const simplify)
{
	int k;
	// Reset every entry to -1, meaning "no known producer".
	for (k = 0; k < simplify->tensor_symbol_info_size; k++)
		simplify->output_execs[k] = -1;
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		const uint32_t dead = simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f));
		if (!dead)
			for (k = 0; k < node->output_size; k++)
			{
				const int d = node->outputs[k];
				if (d >= 0)
					simplify->output_execs[d] = idx; // A tensor can only be written once.
			}
	} ccv_nnc_graph_visit_endfor
}
37
38
// Allocate and populate the simplify working state for the region of `graph`
// reachable from `sources` to `destinations`. The caller owns the result and
// must release it with _ccv_nnc_symbolic_graph_simplify_free.
static ccv_nnc_symbolic_graph_simplify_t* _ccv_nnc_symbolic_graph_simplify_new(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
{
	ccv_nnc_symbolic_graph_simplify_t* const simplify = (ccv_nnc_symbolic_graph_simplify_t*)ccmalloc(sizeof(ccv_nnc_symbolic_graph_simplify_t));
	simplify->graph = graph;
	simplify->visit = ccv_nnc_graph_visit_new(graph, (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0), graph->exec_symbol_info->rnum, sources, source_size, destinations, destination_size, 0);
	simplify->tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_tensor_symbol_info_t) * graph->tensor_symbol_info->rnum);
	simplify->exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_info_t) * graph->exec_symbol_info->rnum);
	// Fill the two snapshot arrays from the graph along the visit order.
	ccv_nnc_symbolic_graph_symbol_infer(graph, simplify->visit, sources, source_size, destinations, destination_size, 0, 0, simplify->tensor_symbol_info, simplify->exec_symbol_info);
	simplify->tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
	simplify->exec_symbol_info_size = graph->exec_symbol_info->rnum;
	// One zeroed allocation backs both bitmaps: exec_dead first, tensor_dead
	// immediately after it. Only exec_dead is passed to ccfree at teardown.
	simplify->exec_dead = cccalloc(((simplify->exec_symbol_info_size + 31) >> 5) + ((simplify->tensor_symbol_info_size + 31) >> 5), sizeof(uint32_t));
	simplify->tensor_dead = simplify->exec_dead + ((simplify->exec_symbol_info_size + 31) >> 5);
	simplify->output_execs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
	return simplify;
}
53
54
// Commit the dead bitmaps back onto the symbolic graph: free every exec marked
// dead; for each surviving exec, clear the dead bit of its outputs (a live
// producer keeps its tensors alive); finally free the tensors still marked dead.
static void _ccv_nnc_symbolic_graph_simplify_apply(ccv_nnc_symbolic_graph_simplify_t* const simplify)
{
	int e, t, k;
	for (e = 0; e < simplify->exec_symbol_info_size; e++)
	{
		if (simplify->exec_dead[e >> 5] & (1u << (e & 0x1f)))
		{
			const ccv_nnc_graph_exec_symbol_t exec_symbol = {
				.d = e,
				.graph = simplify->graph,
			};
			ccv_nnc_graph_exec_symbol_free(simplify->graph, exec_symbol);
			continue;
		}
		// Not marked as dead: unmark every tensor this exec still produces.
		for (k = 0; k < simplify->exec_symbol_info[e].output_size; k++)
		{
			const int d = simplify->exec_symbol_info[e].outputs[k];
			if (d >= 0)
				simplify->tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
		}
	}
	for (t = 0; t < simplify->tensor_symbol_info_size; t++)
		if (simplify->tensor_dead[t >> 5] & (1u << (t & 0x1f)))
		{
			const ccv_nnc_tensor_symbol_t tensor_symbol = {
				.d = t,
				.graph = simplify->graph,
			};
			ccv_nnc_tensor_symbol_free(simplify->graph, tensor_symbol);
		}
}
77
78
// Release everything owned by the simplify state, then the state itself.
static void _ccv_nnc_symbolic_graph_simplify_free(ccv_nnc_symbolic_graph_simplify_t* const simplify)
{
	ccv_nnc_graph_visit_free(simplify->visit);
	ccfree(simplify->tensor_symbol_info);
	ccfree(simplify->exec_symbol_info);
	// tensor_dead is NOT freed separately: it points into the exec_dead
	// allocation (see _ccv_nnc_symbolic_graph_simplify_new).
	ccfree(simplify->exec_dead);
	ccfree(simplify->output_execs);
	ccfree(simplify);
}
87
88
typedef struct {
89
  int d;
90
  int ifbit;
91
  uint64_t hash;
92
} ccv_nnc_cse_hash_t;
93
94
static int _ccv_nnc_cse_hash_find(ccv_nnc_cse_hash_t* const hash_map, const uint64_t hash, const int map_size)
95
6.28k
{
96
6.28k
  assert(hash > 0);
97
6.28k
  int i, j;
98
6.28k
  i = (hash - 1) % map_size;
99
6.28k
  for (j = 0; ; 
j++, i++5.78k
)
100
12.0k
  {
101
12.0k
    if (i >= map_size)
102
1
      i = 0;
103
12.0k
    if (j > hash_map[i].ifbit)
104
5.65k
      return -1;
105
6.41k
    if (hash_map[i].hash == hash)
106
628
      return hash_map[i].d;
107
6.41k
  }
108
6.28k
}
109
110
static void _ccv_nnc_cse_hash_add(ccv_nnc_cse_hash_t* const hash_map, uint64_t hash, int d, const int map_size)
111
3.28k
{
112
3.28k
  assert(hash > 0);
113
3.28k
  int i, j;
114
3.28k
  i = (hash - 1) % map_size;
115
3.28k
  for (j = 0; ; 
j++, i++261
)
116
3.54k
  {
117
3.54k
    if (i >= map_size)
118
1
      i = 0;
119
3.54k
    if (hash_map[i].hash == hash) // Already exists, do nothing.
120
6
      return;
121
3.54k
    if (hash_map[i].hash == 0)
122
3.28k
    {
123
      // Insert.
124
3.28k
      hash_map[i].d = d;
125
3.28k
      hash_map[i].ifbit = j;
126
3.28k
      hash_map[i].hash = hash;
127
3.28k
      return;
128
3.28k
    }
129
261
    if (j > hash_map[i].ifbit)
130
50
    {
131
50
      const ccv_nnc_cse_hash_t old_hash = hash_map[i];
132
      // Swap, and continue, until find an empty slot.
133
50
      hash_map[i].d = d;
134
50
      hash_map[i].ifbit = j;
135
50
      hash_map[i].hash = hash;
136
50
      d = old_hash.d;
137
50
      j = old_hash.ifbit;
138
50
      hash = old_hash.hash;
139
50
    }
140
261
  }
141
3.28k
}
142
143
// Rewrite tensor references across the graph. `refs[i] >= 0` means tensor i
// should be replaced by tensor refs[i]: i is marked dead (unless it is one of
// `outputs`), its sub-graph bookkeeping (s_ref / p_ref) is merged into the
// replacement, and every exec input/output mentioning i is rewritten. When
// `output_exec_ref_dead` is non-zero, execs whose outputs are all replaced are
// marked dead as well. Returns non-zero if any input/output was rewritten.
static int _ccv_nnc_symbolic_graph_update_refs(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const int* const refs, const int output_exec_ref_dead)
{
	int i, j;
	// Go over refs, if a tensor is an alias, mark it reference to the new one.
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		if (refs[i] >= 0)
			// Mark this tensor as dead.
			simplify->tensor_dead[i >> 5] |= (1u << (i & 0x1f));
		else if (simplify->tensor_symbol_info[i].alias_ref && refs[simplify->tensor_symbol_info[i].alias_ref - 1] >= 0) {
			// This tensor aliases a replaced tensor: retarget the alias to the
			// replacement, both in the snapshot and in the live graph.
			const int alias_ref = simplify->tensor_symbol_info[i].alias_ref - 1;
			simplify->tensor_symbol_info[i].alias_ref = refs[alias_ref] + 1;
			((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->alias_ref = refs[alias_ref] + 1;
		}
	for (i = 0; i < output_size; i++)
		// If the output is an alias, that's fine, because if the alias is re-binded, we are good.
		simplify->tensor_dead[outputs[i].d >> 5] &= ~(1u << (outputs[i].d & 0x1f)); // Undead for output tensor symbols.
	// Merge s_refs if the tensor is dead.
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))))
		{
			const int ref = refs[i];
			if (simplify->tensor_symbol_info[i].s_ref && simplify->tensor_symbol_info[i].s_ref->rnum)
			{
				if (!simplify->tensor_symbol_info[ref].s_ref) // If there is no s_ref, simple, just assign the pointer and set the old one to nil.
				{
					simplify->tensor_symbol_info[ref].s_ref = simplify->tensor_symbol_info[i].s_ref;
					simplify->tensor_symbol_info[i].s_ref = 0;
					((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref = 0;
				} else {
					// Both have s_ref arrays: merge element-wise.
					ccv_array_t* const ref_s_ref = simplify->tensor_symbol_info[ref].s_ref;
					ccv_array_t* const i_s_ref = simplify->tensor_symbol_info[i].s_ref;
					const int ref_s_ref_rnum = ref_s_ref->rnum;
					int flag = 0;
					// Detect conflict, if there is, undead.
					for (j = 0; !flag && j < ccv_min(ref_s_ref_rnum, i_s_ref->rnum); j++)
					{
						const int ref_s_ref_k = *(int*)ccv_array_get(ref_s_ref, j);
						const int i_s_ref_k = *(int*)ccv_array_get(i_s_ref, j);
						// If for the same sub-graph, they have different tensors linked, we cannot merge these two.
						flag = (ref_s_ref_k > 0 && i_s_ref_k > 0 && ref_s_ref_k != i_s_ref_k);
					}
					if (flag)
					{
						simplify->tensor_dead[i >> 5] &= ~(1u << (i & 0x1f)); // Undead
						continue;
					}
					// Grow ref's s_ref to cover i's longer tail, if any.
					if (ref_s_ref_rnum < i_s_ref->rnum)
					{
						ccv_array_resize(ref_s_ref, i_s_ref->rnum);
						memcpy(ccv_array_get(ref_s_ref, ref_s_ref_rnum), ccv_array_get(i_s_ref, ref_s_ref_rnum), sizeof(int) * (i_s_ref->rnum - ref_s_ref_rnum));
					}
					// In the overlapping region, at most one side is non-zero
					// (conflicts were rejected above); take the non-zero one.
					for (j = 0; j < ccv_min(ref_s_ref_rnum, i_s_ref->rnum); j++)
					{
						const int ref_s_ref_k = *(int*)ccv_array_get(ref_s_ref, j);
						const int i_s_ref_k = *(int*)ccv_array_get(i_s_ref, j);
						assert(ref_s_ref_k == 0 || i_s_ref_k == 0);
						if (i_s_ref_k)
							*(int*)ccv_array_get(ref_s_ref, j) = i_s_ref_k;
					}
					ccv_array_free(simplify->tensor_symbol_info[i].s_ref);
					simplify->tensor_symbol_info[i].s_ref = 0;
					((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref = 0;
					// Point every referenced sub-graph tensor back at ref.
					for (j = 0; j < ref_s_ref->rnum; j++)
					{
						const int ref_k = *(int*)ccv_array_get(ref_s_ref, j) - 1;
						if (ref_k >= 0)
						{
							ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(simplify->graph->sub_graphs, j);
							assert(sub_graph);
							// Update its p_ref.
							((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, ref_k))->p_ref = ref + 1;
						}
					}
				}
				// Snapshot and live graph must agree after the merge.
				assert(simplify->tensor_symbol_info[i].s_ref == ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref);
				assert(simplify->tensor_symbol_info[ref].s_ref == ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, ref))->s_ref);
			}
		}
	// Going through refs that we are updating, going through its p_ref to make sure both are updated.
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))) && simplify->tensor_symbol_info[i].p_ref)
		{
			const int ref = refs[i];
			const int p_ref = simplify->tensor_symbol_info[i].p_ref - 1;
			assert(p_ref >= 0);
			assert(simplify->graph->p);
			ccv_array_t* const s_ref = ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->p->tensor_symbol_info, p_ref))->s_ref;
			const int s_idx = simplify->graph->p_idx - 1;
			assert(s_idx >= 0);
			assert(s_ref && s_ref->rnum > s_idx);
			*(int*)ccv_array_get(s_ref, s_idx) = ref + 1; // Update so it references to the new s_ref.
			assert(!simplify->tensor_symbol_info[ref].p_ref);
			simplify->tensor_symbol_info[ref].p_ref = p_ref + 1;
			((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, ref))->p_ref = p_ref + 1;
		}
	// Now go over exec to mark them as dead because we don't need these to generate refs.
	if (output_exec_ref_dead)
		for (i = 0; i < simplify->tensor_symbol_info_size; i++)
			if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))))
			{
				const int output_exec = simplify->output_execs[i];
				assert(output_exec >= 0);
				const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = simplify->exec_symbol_info + output_exec;
				int flag = 0;
				for (j = 0; !flag && j < symbol_info->output_size; j++)
				{
					const int d = symbol_info->outputs[j];
					if (d >= 0)
						flag = (!(simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f)))); // If some of the output is not dead, we cannot proceed.
				}
				if (!flag) // If all outputs are dead, mark the exec as dead.
					simplify->exec_dead[output_exec >> 5] |= (1u << (output_exec & 0x1f));
			}
	int updated_refs = 0;
	// Go over replace inputs / outputs.
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		for (i = 0; i < node->input_size; i++)
		{
			const int d = node->inputs[i];
			if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
			{
				node->inputs[i] = refs[d]; // It can be replaced.
				updated_refs = 1;
			}
		}
		for (i = 0; i < node->output_size; i++)
		{
			const int d = node->outputs[i];
			if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
			{
				node->outputs[i] = refs[d]; // It can be replaced.
				updated_refs = 1;
			}
		}
		// The snapshot shares the inputs/outputs buffers with the live graph,
		// so rewriting node->inputs/outputs above updated the graph in place.
		assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->inputs == node->inputs);
		assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->outputs == node->outputs);
	} ccv_nnc_graph_visit_endfor
	const ccv_nnc_graph_exec_symbol_info_t* const p_node_info = simplify->graph->p ? (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->p->exec_symbol_info, simplify->graph->exec_idx - 1) : 0;
	if (p_node_info && (p_node_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE))
		// Go over the while inputs as well.
		for (i = 0; i < p_node_info->p_while.input_size; i++)
		{
			const int d = p_node_info->p_while.inputs[i];
			if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
			{
				p_node_info->p_while.inputs[i] = refs[d];
				updated_refs = 1;
			}
		}
	return updated_refs;
}
294
295
// This is a simple common sub-expression elimination implementation, particularly, we only replace the later computed output
// with the identical earlier computed output, and let the "elimination" part to the graph pruning.
// Two passes over the visit: first assign every tensor a SipHash digest of
// (cmd, hint, hashed inputs, output position, output info); then use the
// robin-hood hash map to find an earlier tensor with the same digest and
// record the replacement in `refs`, finally applied via
// _ccv_nnc_symbolic_graph_update_refs.
static void _ccv_nnc_symbolic_graph_common_subexpression_elimination(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
{
	_ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
	// tensor_hash starts with 0s, and it is either marked with the tensor index + 1, or the hash of the computations.
	uint64_t* const tensor_hash = (uint64_t*)cccalloc(simplify->tensor_symbol_info_size, sizeof(uint64_t));
	int i;
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		// If already marked as dead, skip.
		if (simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
			continue;
		for (i = 0; i < node->input_size; i++)
		{
			assert(node->inputs[i] < simplify->tensor_symbol_info_size);
			// If no hash for the input, use the index + 1 as the hash.
			if (node->inputs[i] >= 0 && tensor_hash[node->inputs[i]] == 0)
				tensor_hash[node->inputs[i]] = node->inputs[i] + 1;
		}
		// We cannot support graph / custom command (we cannot model them properly).
		if (node->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
			node->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
			node->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
			node->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD ||
			// No need to process a opt disabled node.
			(node->flags & CCV_NNC_GRAPH_EXEC_DISABLE_OPT))
			continue;
		uint64_t hashout, hashin[3];
		siphash((uint8_t*)&hashin[0], (const uint8_t*)&node->cmd.info, sizeof(node->cmd.info), key_siphash);
		siphash((uint8_t*)&hashin[1], (const uint8_t*)&node->hint, sizeof(node->hint), key_siphash);
		hashin[2] = node->cmd.cmd; // Now actually hash the cmd name.
		siphash((uint8_t*)&hashout, (const uint8_t*)hashin, sizeof(hashin), key_siphash);
		// First, hash the cmd and the hints with the cmd.
		// Note on alias, we cannot really generate proper hash for alias (yet). Thus, just treat alias as normal tensors.
		for (i = 0; i < node->input_size; i++)
		{
			assert(node->inputs[i] < simplify->tensor_symbol_info_size);
			if (node->inputs[i] >= 0)
			{
				// Hash using the tensor hash.
				hashin[0] = hashout;
				hashin[1] = i; // Encode the positional information.
				hashin[2] = tensor_hash[node->inputs[i]];
			} else {
				// Hash using the input integer (could be special integer).
				hashin[0] = hashout;
				hashin[1] = i; // Encode the positional information.
				hashin[2] = node->inputs[i];
			}
			siphash((uint8_t*)&hashout, (const uint8_t*)hashin, sizeof(hashin), key_siphash);
		}
		for (i = 0; i < node->output_size; i++)
			if (node->outputs[i] >= 0)
			{
				assert(node->outputs[i] < simplify->tensor_symbol_info_size);
				// Assigning once, especially now we don't consider aliases.
				assert(tensor_hash[node->outputs[i]] == 0);
				hashin[0] = hashout;
				hashin[1] = i; // Positional information.
				siphash((uint8_t*)&hashin[2], (const uint8_t*)&simplify->tensor_symbol_info[node->outputs[i]].info,
					sizeof(simplify->tensor_symbol_info[node->outputs[i]].info), key_siphash);
				// Generate hash for the output.
				siphash((uint8_t*)&tensor_hash[node->outputs[i]], (const uint8_t*)hashin, sizeof(hashin), key_siphash);
			}
	} ccv_nnc_graph_visit_endfor
	// Allocate 3 / 2 as much space, for the simple robin-hood open address hash map.
	const int map_size = (simplify->tensor_symbol_info_size * 3 + 1) / 2;
	int* const refs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		refs[i] = -1; // -1 means "no replacement found".
	ccv_nnc_cse_hash_t* const hash_map = (ccv_nnc_cse_hash_t*)cccalloc(map_size, sizeof(ccv_nnc_cse_hash_t));
	// Now, all tensors are hashed, identify tensors with the same hash code, replace the ones that accessed later.
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		// If already marked as dead, skip.
		if (simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
			continue;
		// No need to visit a opt disabled node.
		if (node->flags & CCV_NNC_GRAPH_EXEC_DISABLE_OPT)
			continue;
		for (i = 0; i < node->input_size; i++)
			if (node->inputs[i] >= 0)
			{
				const int d = node->inputs[i];
				assert(tensor_hash[d]);
				const int new_d = _ccv_nnc_cse_hash_find(hash_map, tensor_hash[d], map_size);
				if (new_d >= 0 && new_d != d)
				{
					// Check whether this can be replaced.
					assert(refs[d] == -1 || refs[d] == new_d);
					assert(!simplify->tensor_symbol_info[d].assign_ref);
					assert(!simplify->tensor_symbol_info[d].r_assign_ref);
					assert(!simplify->tensor_symbol_info[d].bypass_ref);
					assert(!simplify->tensor_symbol_info[new_d].assign_ref);
					assert(!simplify->tensor_symbol_info[new_d].r_assign_ref);
					assert(!simplify->tensor_symbol_info[new_d].bypass_ref);
					// Ignore if there is a pair_ref (again, pair_ref has side effect that is deeper (using tape))
					if (simplify->tensor_symbol_info[d].pair_ref)
						continue;
					// If both have p_ref, we cannot merge.
					if (simplify->tensor_symbol_info[d].p_ref && simplify->tensor_symbol_info[new_d].p_ref)
						continue;
					// Merge s_refs from ref[d] later.
					if (refs[d] != new_d)
						refs[d] = new_d;
					assert(simplify->output_execs[new_d] >= 0);
					// Establish new dependency.
					ccv_nnc_graph_exec_symbol_concat(simplify->graph, (ccv_nnc_graph_exec_symbol_t){
						.d = simplify->output_execs[new_d],
						.graph = simplify->graph,
					}, (ccv_nnc_graph_exec_symbol_t){
						.d = idx,
						.graph = simplify->graph,
					});
				}
			}
		// We can reuse the input, but we cannot do that for output of these commands.
		if (node->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
			node->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
			node->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
			node->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD)
			continue;
		for (i = 0; i < node->output_size; i++)
			if (node->outputs[i] >= 0) // This tensor can be reused by others.
				_ccv_nnc_cse_hash_add(hash_map, tensor_hash[node->outputs[i]], node->outputs[i], map_size);
	} ccv_nnc_graph_visit_endfor
	_ccv_nnc_symbolic_graph_update_refs(simplify, outputs, output_size, refs, 1 /* For these exec that generates refs, we don't need them any more. */);
	ccfree(tensor_hash);
	ccfree(hash_map);
	ccfree(refs);
}
425
426
// Collapse redundant data-transfer-like ops (data transfer, same-dtype
// conversion, same-format transform): when source and destination tensors can
// be proven interchangeable, record a replacement in `refs`, apply it via
// _ccv_nnc_symbolic_graph_update_refs, and iterate to a fixed point. Tensors
// listed in `binds` or `outputs` are pinned and never replaced away. Finally,
// drop transfer pairs whose input now equals their output, and kill transfer
// execs left with no outputs.
static void _ccv_nnc_symbolic_graph_data_transfer_opt(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const binds, const int bind_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
{
	_ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
	uint32_t* const exec_dead = simplify->exec_dead;
	const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = simplify->tensor_symbol_info;
	int i;
	// One allocation for two bitmaps: has_alias first, has_binds after it.
	uint32_t* const has_alias = cccalloc(2 * ((simplify->tensor_symbol_info_size + 31) >> 5), sizeof(uint32_t));
	uint32_t* const has_binds = has_alias + ((simplify->tensor_symbol_info_size + 31) >> 5);
	for (i = 0; i < bind_size; i++)
		has_binds[binds[i].d >> 5] |= (1u << (binds[i].d & 0x1f));
	for (i = 0; i < output_size; i++)
		has_binds[outputs[i].d >> 5] |= (1u << (outputs[i].d & 0x1f));
	int* const refs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
	int updated_refs;
	do {
		memset(has_alias, 0, sizeof(uint32_t) * ((simplify->tensor_symbol_info_size + 31) >> 5));
		// Go through until no updates is possible. This won't result an infinite loop because every time,
		// a tensor is eliminated.
		for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		{
			refs[i] = -1;
			if (tensor_symbol_info[i].alias_ref)
			{
				// Record which tensors are aliased by someone.
				const int alias_ref = tensor_symbol_info[i].alias_ref - 1;
				has_alias[alias_ref >> 5] |= (1u << (alias_ref & 0x1f));
			}
		}
		ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
			// If already marked as dead, skip.
			if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
				continue;
			if (node->cmd.cmd != CCV_NNC_DATA_TRANSFER_FORWARD &&
				node->cmd.cmd != CCV_NNC_DATA_TRANSFER_BACKWARD &&
				// Conversion, if the datatype is the same, it is the data transfer op.
				node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_FORWARD &&
				node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_BACKWARD &&
				// Format transform, if the format is the same, it is the data transfer op.
				node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_FORWARD &&
				node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_BACKWARD)
				continue;
			if (node->flags & CCV_NNC_GRAPH_EXEC_DISABLE_OPT) // If optimization pass disabled, skip.
				continue;
			for (i = 0; i < node->output_size; i++) // For data transfer, we only respect output size.
				if (node->inputs[i] >= 0 && node->outputs[i] >= 0)
				{
					assert(node->inputs[i] < simplify->tensor_symbol_info_size);
					assert(node->outputs[i] < simplify->tensor_symbol_info_size);
					// Chase any replacements already scheduled this round.
					int input_ref = node->inputs[i];
					while (refs[input_ref] >= 0)
						input_ref = refs[input_ref];
					int output_ref = node->outputs[i];
					while (refs[output_ref] >= 0)
						output_ref = refs[output_ref];
					if (input_ref == output_ref)
						continue;
					const ccv_nnc_tensor_symbol_info_t* const input = tensor_symbol_info + input_ref;
					const ccv_nnc_tensor_symbol_info_t* const output = tensor_symbol_info + output_ref;
					// If they are not the same data type, skip. (Likely data conversion op).
					if (input->info.datatype != output->info.datatype)
						continue;
					// If they are not the same format, skip. (Likely format transform op).
					if (input->info.format != output->info.format)
						continue;
					// If they are not on the same device (even for NUMA), skip.
					if (input->info.type != output->info.type)
						continue;
					// If both are alias, we cannot consolidate this.
					if (input->alias_ref && output->alias_ref)
						continue;
					// If input is alias, and output has alias reference to it, output cannot be the same as input.
					if (input->alias_ref && (has_alias[output_ref >> 5] & (1u << (output_ref & 0x1f))))
						continue;
					// If output is alias, and input has alias reference to it, input cannot be the same as output.
					if (output->alias_ref && (has_alias[input_ref >> 5] & (1u << (input_ref & 0x1f))))
						continue;
					// If either are carry overs (for while), we cannot do anything.
					if (input->assign_ref || output->assign_ref ||
						input->r_assign_ref || output->r_assign_ref)
						continue;
					// If either are bypasses (for case..of), we cannot do anything.
					if (input->bypass_ref || output->bypass_ref ||
						input->r_bypass_ref || output->r_bypass_ref)
						continue;
					// If either are inputs / outputs connecting the parent graph, we cannot do anything.
					if (input->p_ref || output->p_ref)
						continue;
					// If the type is the same, check which one is the alias.
					// We always prefer alias.
					if (output->alias_ref)
					{
						// Input cannot be binds.
						if (has_binds[input_ref >> 5] & (1u << (input_ref & 0x1f)))
							continue;
						refs[input_ref] = output_ref;
					} else { // if (input->alias_ref), else
						// Output cannot be binds.
						if (has_binds[output_ref >> 5] & (1u << (output_ref & 0x1f)))
							continue;
						refs[output_ref] = input_ref;
					}
				}
		} ccv_nnc_graph_visit_endfor
		// Make sure refs reference to the end.
		for (i = 0; i < simplify->tensor_symbol_info_size; i++)
			if (refs[i] >= 0)
			{
				// Path-compress chains so each ref points at its final target.
				int ref = refs[i];
				while (refs[ref] >= 0)
					ref = refs[ref];
				refs[i] = ref;
			}
		updated_refs = _ccv_nnc_symbolic_graph_update_refs(simplify, outputs, output_size, refs, 0 /* We still need these exec that generates the refs. */);
	} while (updated_refs);
	ccfree(refs);
	ccfree(has_alias);
	// Now, all references updated, remove data transfers that sources and destinations are the same.
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		// If already marked as dead, skip.
		if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
			continue;
		if (node->cmd.cmd != CCV_NNC_DATA_TRANSFER_FORWARD &&
			node->cmd.cmd != CCV_NNC_DATA_TRANSFER_BACKWARD &&
			// Conversion, if the datatype is the same, it is the data transfer op.
			node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_FORWARD &&
			node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_BACKWARD &&
			// Format transform, if the format is the same, it is the data transfer op.
			node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_FORWARD &&
			node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_BACKWARD)
			continue;
		for (i = 0; i < node->output_size; i++) // For data transfer, we only respect output size.
			if (node->inputs[i] == node->outputs[i])
			{
				// Swap-with-last removal of the self-transfer pair.
				if (i + 1 < node->output_size)
				{
					node->inputs[i] = node->inputs[node->output_size - 1];
					node->outputs[i] = node->outputs[node->output_size - 1];
				}
				--node->output_size;
				--i; // Re-examine the swapped-in pair at this position.
			}
		node->input_size = node->output_size;
		((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->input_size = node->input_size;
		((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->output_size = node->output_size;
		// Remove this data transfer node if it has no outputs.
		if (node->output_size == 0)
			exec_dead[idx >> 5] |= (1u << (idx & 0x1f));
	} ccv_nnc_graph_visit_endfor
}
574
575
// Describes one fusible pattern: a short sequence of commands that can be
// collapsed into a single fused command, plus how the inputs / outputs of the
// sequence map onto the fused op's inputs / outputs.
typedef struct {
	uint32_t fused_op; // The final fused op id.
	uint32_t ops_seq[2]; // The sequence of commands to identify.
	int ops_seq_size; // How many entries of ops_seq are used.
	int ops_info_select; // Which ops' info will be selected.
	struct {
		int type; // Whether it is input, or output. It doesn't make sense for example, input in ops_seq, but output in fused_op.
		int op_idx; // Index into the ops_seq.
		int from; // The index in ops_seq.
		int to; // The index in fused_op.
	} pos[4]; // maps of positions from ops seq to fused_op for inputs (outputs).
	int pos_size; // How many entries of pos are used.
} ccv_nnc_ops_fusion_t;

// Values for ccv_nnc_ops_fusion_t.pos[].type.
enum {
	CCV_NNC_OPS_FUSION_INPUT_INDEX,
	CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
};
593
594
const static int ccv_nnc_ops_fusion_io_size = 2;
595
const static ccv_nnc_ops_fusion_t ccv_nnc_ops_fusions[] = {
596
  {
597
    .fused_op = CCV_NNC_SOFTMAX_CROSSENTROPY_FORWARD,
598
    .ops_seq = {
599
      CCV_NNC_SOFTMAX_FORWARD, CCV_NNC_CATEGORICAL_CROSSENTROPY_FORWARD,
600
    },
601
    .ops_seq_size = 2,
602
    .ops_info_select = 1,
603
    .pos = {
604
      {
605
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
606
        .op_idx = 0,
607
        .from = 0,
608
        .to = 0,
609
      },
610
      {
611
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
612
        .op_idx = 1,
613
        .from = 1,
614
        .to = 1,
615
      },
616
      {
617
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
618
        .op_idx = 0,
619
        .from = 0,
620
        .to = 1,
621
      },
622
      {
623
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
624
        .op_idx = 1,
625
        .from = 0,
626
        .to = 0,
627
      },
628
    },
629
    .pos_size = 4,
630
  },
631
  {
632
    .fused_op = CCV_NNC_SIGMOID_BINARY_CROSSENTROPY_FORWARD,
633
    .ops_seq = {
634
      CCV_NNC_SIGMOID_FORWARD, CCV_NNC_BINARY_CROSSENTROPY_FORWARD,
635
    },
636
    .ops_seq_size = 2,
637
    .ops_info_select = 1,
638
    .pos = {
639
      {
640
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
641
        .op_idx = 0,
642
        .from = 0,
643
        .to = 0,
644
      },
645
      {
646
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
647
        .op_idx = 1,
648
        .from = 1,
649
        .to = 1,
650
      },
651
      {
652
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
653
        .op_idx = 0,
654
        .from = 0,
655
        .to = 1,
656
      },
657
      {
658
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
659
        .op_idx = 1,
660
        .from = 0,
661
        .to = 0,
662
      },
663
    },
664
    .pos_size = 4,
665
  }
666
};
667
668
static int _ccv_nnc_find_ops_for_fusion(const ccv_nnc_ops_fusion_t* const fusion, const int ops_idx, const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const uint32_t* const exec_dead, const int exec_idx, int* const fusing_exec_symbols)
669
48
{
670
48
  if (exec_dead[exec_idx >> 5] & (1u << (exec_idx & 0x1f)))
671
0
    return 0;
672
48
  const ccv_nnc_graph_exec_symbol_info_t* const node = exec_symbol_info + exec_idx;
673
  // Doesn't match the ops_seq, return 0.
674
48
  if (fusion->ops_seq[ops_idx] != node->cmd.cmd)
675
16
    return 0;
676
32
  fusing_exec_symbols[ops_idx] = exec_idx;
677
  // If already reached the end, we are good.
678
32
  if (ops_idx == fusion->ops_seq_size - 1)
679
5
    return 1;
680
  // Otherwise, we need to go on, but don't have any to follow-up.
681
27
  if (!node->outgoings || 
!node->outgoings->rnum21
)
682
6
    return 0;
683
21
  int i;
684
37
  for (i = 0; i < node->outgoings->rnum; 
i++16
)
685
21
    if (_ccv_nnc_find_ops_for_fusion(fusion, ops_idx + 1, exec_symbol_info, exec_dead, *(int*)ccv_array_get(node->outgoings, i), fusing_exec_symbols))
686
5
      return 1;
687
16
  return 0;
688
21
}
689
690
// Ops-fusion pass: scan the graph in topological order for any of the
// patterns in ccv_nnc_ops_fusions. When a pattern matches, the first node of
// the chain is rewritten in place into the fused command (remapping its
// inputs / outputs per the pattern's pos table), and the remaining nodes of
// the chain are marked dead.
static void _ccv_nnc_symbolic_graph_ops_fusion(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
{
	uint32_t* const exec_dead = simplify->exec_dead;
	ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = simplify->exec_symbol_info;
	int i, j;
	// Scratch for the matched exec indices; sized by the longest ops_seq.
	int fusing_exec_symbols[sizeof(ccv_nnc_ops_fusions->ops_seq)];
	int fused_inputs[ccv_nnc_ops_fusion_io_size]; // 2 is just a number based on the ops_fusion.
	int fused_outputs[ccv_nnc_ops_fusion_io_size];
	ccv_nnc_graph_visit_for(simplify->visit, exec_symbol_info, node, idx) {
		// If already marked as dead, skip.
		if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
			continue;
		if (node->flags & CCV_NNC_GRAPH_EXEC_DISABLE_OPT) // If optimization pass disabled, skip.
			continue;
		// Run through rudimentary pattern matching for ops_seq. There are better ways to do this (immediately come to mind, Boyer-Moore). However, this is simpler to code.
		// If I am running into performance issues with this, I would be very happy.
		for (i = 0; i < sizeof(ccv_nnc_ops_fusions) / sizeof(ccv_nnc_ops_fusion_t); i++)
		{
			const ccv_nnc_ops_fusion_t* const ops_fusion = ccv_nnc_ops_fusions + i;
			// Check to see if a list of symbols are possible.
			if (ops_fusion->ops_seq[0] == node->cmd.cmd &&
				_ccv_nnc_find_ops_for_fusion(ops_fusion, 0, exec_symbol_info, exec_dead, idx, fusing_exec_symbols))
			{
				// Go through all the inputs and outputs, check if they exists and are mapped.
				// TODO: the mapping can be more sophisticated than what we have here.
				// Also, I need to check if some inputs / outputs cannot be mapped, then we cannot continue.
				for (j = 0; j < ccv_nnc_ops_fusion_io_size; j++)
					fused_inputs[j] = fused_outputs[j] = CCV_NNC_NO_TENSOR_SYMBOL;
				int input_size = 0, output_size = 0;
				// Collect the fused op's inputs / outputs from the matched ops per the pos table.
				for (j = 0; j < ops_fusion->pos_size; j++)
				{
					ccv_nnc_graph_exec_symbol_info_t* const fusing_op = exec_symbol_info + fusing_exec_symbols[ops_fusion->pos[j].op_idx];
					switch (ops_fusion->pos[j].type)
					{
						case CCV_NNC_OPS_FUSION_INPUT_INDEX:
							fused_inputs[ops_fusion->pos[j].to] = ops_fusion->pos[j].from < fusing_op->input_size ? fusing_op->inputs[ops_fusion->pos[j].from] : CCV_NNC_NO_TENSOR_SYMBOL;
							input_size = ccv_max(input_size, ops_fusion->pos[j].to + 1);
							break;
						case CCV_NNC_OPS_FUSION_OUTPUT_INDEX:
							fused_outputs[ops_fusion->pos[j].to] = ops_fusion->pos[j].from < fusing_op->output_size ? fusing_op->outputs[ops_fusion->pos[j].from] : CCV_NNC_NO_TENSOR_SYMBOL;
							output_size = ccv_max(output_size, ops_fusion->pos[j].to + 1);
							break;
					}
				}
				const ccv_nnc_cmd_param_t info = exec_symbol_info[fusing_exec_symbols[ops_fusion->ops_info_select]].cmd.info;
				// Modify the first node so it is the correct type and value. I need to look back to the actual graph to get the info.
				// NOTE: both the visit-local copy (node) and the graph's own record (actual_node)
				// are updated in lockstep so they stay consistent.
				ccv_nnc_graph_exec_symbol_info_t* const actual_node = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx);
				actual_node->cmd.cmd = node->cmd.cmd = ops_fusion->fused_op;
				actual_node->cmd.info = node->cmd.info = info;
				// Grow the single allocation that backs both inputs and outputs if needed.
				if (node->input_size + node->output_size < input_size + output_size)
					actual_node->inputs = node->inputs = node->inputs ? ccrealloc(node->inputs, sizeof(int) * (input_size + output_size)) : ccmalloc(sizeof(int) * (input_size + output_size));
				// outputs points into the same buffer, right after the inputs.
				actual_node->outputs = node->outputs = node->inputs + input_size;
				actual_node->input_size = node->input_size = input_size;
				actual_node->output_size = node->output_size = output_size;
				memcpy(node->inputs, fused_inputs, sizeof(int) * input_size);
				memcpy(node->outputs, fused_outputs, sizeof(int) * output_size);
				// Afterwards, mark the rest as dead.
				for (j = 1; j < ops_fusion->ops_seq_size; j++)
					exec_dead[fusing_exec_symbols[j] >> 5] |= (1u << (fusing_exec_symbols[j] & 0x1f));
				break;
			}
		}
	} ccv_nnc_graph_visit_endfor
}
754
755
// Revive (un-mark as dead) one exec during the pruning BFS, then decide which
// of its tensors must also stay alive: those are pushed onto `next` for the
// following BFS round (tensor_visited guards against pushing twice). For
// opaque sub-graph / custom commands every input and output is kept; for
// regular commands a bitmask fixed-point (via ccv_nnc_cmd_bitmask) drops
// whatever dead outputs — and, for backward ops, inputs — the command can do
// without, clearing the dropped slots to CCV_NNC_NO_TENSOR_SYMBOL.
static void _ccv_nnc_symbolic_graph_pruning_undead_exec(ccv_nnc_symbolic_graph_simplify_t* const simplify, const int exec_idx, uint32_t* const tensor_visited, ccv_array_t* const next)
{
	assert(exec_idx >= 0);
	uint32_t* const exec_dead = simplify->exec_dead;
	uint32_t* const tensor_dead = simplify->tensor_dead;
	exec_dead[exec_idx >> 5] &= ~(1u << (exec_idx & 0x1f)); // Undead the exec.
	ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = simplify->exec_symbol_info + exec_idx;
	int i;
	// Opaque commands: we cannot reason about which i/o is needed.
	if (exec_symbol_info->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
		exec_symbol_info->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
		exec_symbol_info->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
		exec_symbol_info->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD)
	{
		// All of its inputs / outputs need to be undead for these commands.
		for (i = 0; i < exec_symbol_info->input_size; i++)
		{
			const int d = exec_symbol_info->inputs[i];
			if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
			{
				ccv_array_push(next, &d); // Push to the next round to be undead.
				tensor_visited[d >> 5] |= (1u << (d & 0x1f));
			}
		}
		for (i = 0; i < exec_symbol_info->output_size; i++)
		{
			const int d = exec_symbol_info->outputs[i];
			if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
			{
				ccv_array_push(next, &d); // Push to the next round to be undead.
				tensor_visited[d >> 5] |= (1u << (d & 0x1f));
			}
		}
		return;
	}
	// Go through the input / output, to make sure that all of them can be available.
	const int input_bitmask_size = (exec_symbol_info->input_size + 63) >> 6;
	const int output_bitmask_size = (exec_symbol_info->output_size + 63) >> 6;
	// VLAs, one bit per i/o slot (at least one word each so the arrays are never zero-length).
	uint64_t input_bitmasks[ccv_max(1, input_bitmask_size)];
	for (i = 0; i < input_bitmask_size; i++)
		input_bitmasks[i] = 0;
	uint64_t output_bitmasks[ccv_max(1, output_bitmask_size)];
	for (i = 0; i < output_bitmask_size; i++)
		output_bitmasks[i] = 0;
	for (i = 0; i < exec_symbol_info->input_size; i++)
		if (exec_symbol_info->inputs[i] >= 0)
			input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
	for (i = 0; i < exec_symbol_info->output_size; i++)
		if (exec_symbol_info->outputs[i] >= 0)
			output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
	// First, mark everything with bitmasks, and verify it works.
	assert(ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size));
	int flag;
	// Fixed point: keep trying to drop slots until no single drop is accepted.
	do {
		flag = 0;
		// Try to eliminate one at a time. Go over output first.
		for (i = 0; i < exec_symbol_info->output_size; i++)
		{
			const int d = exec_symbol_info->outputs[i];
			// If this tensor currently is marked as dead, try to see whether it works when we don't have this tensor at all.
			if (d >= 0 && (tensor_dead[d >> 5] & (1u << (d & 0x1f))) &&
				(output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))))
			{
				output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
				if (ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size))
					flag = 1;
				else // Reset the bitmask.
					output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
			}
		}
		// Only check if there are inputs we can remove if it is backward.
		if (!ccv_nnc_cmd_is_forward(exec_symbol_info->cmd))
			// For inputs, no matter if it s dead or not, we try to limit our input to the smallest number.
			for (i = 0; i < exec_symbol_info->input_size; i++)
			{
				const int d = exec_symbol_info->inputs[i];
				// If this tensor currently is marked as dead, try to see whether it works when we don't have this tensor at all.
				if (d >= 0 && (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))))
				{
					input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
					if (ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size))
						flag = 1;
					else // Reset the bitmask.
						input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
				}
			}
	} while (flag);
	// Now we know which one to keep, which one to undead.
	for (i = 0; i < exec_symbol_info->input_size; i++)
		if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
		{
			const int d = exec_symbol_info->inputs[i];
			if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
			{
				ccv_array_push(next, &d); // Push to the next round to be undead.
				tensor_visited[d >> 5] |= (1u << (d & 0x1f));
			}
		} else {
			// Clean up the inputs.
			exec_symbol_info->inputs[i] = CCV_NNC_NO_TENSOR_SYMBOL;
			// The visit copy shares its inputs buffer with the graph's record, so the
			// graph is updated too — assert that invariant still holds.
			assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, exec_idx))->inputs == exec_symbol_info->inputs);
		}
	for (i = 0; i < exec_symbol_info->output_size; i++)
		if (output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
		{
			const int d = exec_symbol_info->outputs[i];
			if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
			{
				ccv_array_push(next, &d); // Push to the next round to be undead.
				tensor_visited[d >> 5] |= (1u << (d & 0x1f));
			}
		} else {
			// Clean up the outputs.
			exec_symbol_info->outputs[i] = CCV_NNC_NO_TENSOR_SYMBOL;
			// Same shared-buffer invariant for the outputs.
			assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, exec_idx))->outputs == exec_symbol_info->outputs);
		}
}
871
872
// Graph-pruning pass (dead code elimination): first pessimistically marks
// every visited exec and every tensor it touches as dead, then runs a BFS
// backwards from the requested outputs, reviving only the execs / tensors
// actually needed to produce them. Aliases are tracked both ways
// (alias -> referenced tensor via alias_ref, and the reverse via
// r_alias_refs) so reviving one view keeps its written siblings alive.
static void _ccv_nnc_symbolic_graph_pruning(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
{
	// One bit per tensor: already queued for revival in the BFS.
	uint32_t* const tensor_visited = (uint32_t*)cccalloc(sizeof(uint32_t), (simplify->tensor_symbol_info_size + 31) >> 5);
	// Double-buffered BFS frontier: preserve[p] is the current round, preserve[q] the next.
	ccv_array_t* const preserve[2] = {
		ccv_array_new(sizeof(int), output_size, 0),
		ccv_array_new(sizeof(int), 0, 0),
	};
	int i, j;
	// Build the reverse alias map: r_alias_refs[t] lists every tensor that aliases t.
	ccv_array_t** const r_alias_refs = (ccv_array_t**)cccalloc(sizeof(ccv_array_t*), simplify->tensor_symbol_info_size);
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		if (simplify->tensor_symbol_info[i].alias_ref)
		{
			const int alias_ref = simplify->tensor_symbol_info[i].alias_ref - 1;
			assert(alias_ref < simplify->tensor_symbol_info_size);
			if (!r_alias_refs[alias_ref])
				r_alias_refs[alias_ref] = ccv_array_new(sizeof(int), 1, 0);
			ccv_array_push(r_alias_refs[alias_ref], &i);
		}
	uint32_t* const exec_dead = simplify->exec_dead;
	uint32_t* const tensor_dead = simplify->tensor_dead;
	int* const output_execs = simplify->output_execs;
	// Refresh the tensor -> producing-exec map before we rely on it below.
	_ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
	// Mark everything visited as dead.
	ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
		if (node->flags & CCV_NNC_GRAPH_EXEC_DISABLE_OPT) // If optimization pass disabled, skip.
			continue;
		exec_dead[idx >> 5] |= (1u << (idx & 0x1f));
		for (i = 0; i < node->input_size; i++)
		{
			const int d = node->inputs[i];
			if (d >= 0)
				tensor_dead[d >> 5] |= (1u << (d & 0x1f));
		}
		for (i = 0; i < node->output_size; i++)
		{
			const int d = node->outputs[i];
			if (d >= 0)
				tensor_dead[d >> 5] |= (1u << (d & 0x1f));
		}
	} ccv_nnc_graph_visit_endfor
	// If the tensor symbol is used by other exec that is not visited, unmark it.
	for (i = 0; i < simplify->exec_symbol_info_size; i++)
	{
		if (exec_dead[i >> 5] & (1u << (i & 0x1f)))
			continue;
		const ccv_nnc_graph_exec_symbol_info_t* const node = simplify->exec_symbol_info + i;
		for (j = 0; j < node->input_size; j++)
		{
			const int d = node->inputs[j];
			// Undead it.
			if (d >= 0)
				tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
		}
		for (j = 0; j < node->output_size; j++)
		{
			const int d = node->outputs[j];
			// Undead it.
			if (d >= 0)
				tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
		}
	}
	// Seed the BFS with the tensors the caller wants preserved.
	for (i = 0; i < output_size; i++)
		ccv_array_push(preserve[0], &outputs[i].d);
	int p = 0, q = 1;
	// BFS to mark execs / tensors as not dead.
	while (preserve[p]->rnum > 0)
	{
		ccv_array_clear(preserve[q]);
		// First, undead all relevant tensors.
		for (i = 0; i < preserve[p]->rnum; i++)
		{
			const int d = *(int*)ccv_array_get(preserve[p], i);
			// Undead the outputs.
			tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
			int alias_ref = d;
			if (simplify->tensor_symbol_info[d].alias_ref)
			{
				// d is a view: the tensor it references must survive too.
				alias_ref = simplify->tensor_symbol_info[d].alias_ref - 1;
				tensor_dead[alias_ref >> 5] &= ~(1u << (alias_ref & 0x1f));
				assert(r_alias_refs[alias_ref]);
			}
			if (r_alias_refs[alias_ref])
				for (j = 0; j < r_alias_refs[alias_ref]->rnum; j++)
				{
					const int b = *(int*)ccv_array_get(r_alias_refs[alias_ref], j);
					if (output_execs[b] >= 0) // Only revive if it is written alias.
						tensor_dead[b >> 5] &= ~(1u << (b & 0x1f));
				}
		}
		// Then revive the execs that produce those tensors (or their alias siblings),
		// collecting the next frontier into preserve[q].
		for (i = 0; i < preserve[p]->rnum; i++)
		{
			const int d = *(int*)ccv_array_get(preserve[p], i);
			const int output_exec = output_execs[d];
			// Undead the exec.
			if (output_exec >= 0)
				_ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
			int alias_ref = d;
			if (simplify->tensor_symbol_info[d].alias_ref)
			{
				alias_ref = simplify->tensor_symbol_info[d].alias_ref - 1;
				const int output_exec = output_execs[alias_ref];
				if (output_exec >= 0)
					_ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
			}
			if (r_alias_refs[alias_ref])
				for (j = 0; j < r_alias_refs[alias_ref]->rnum; j++)
				{
					const int b = *(int*)ccv_array_get(r_alias_refs[alias_ref], j);
					const int output_exec = output_execs[b];
					if (output_exec >= 0)
						_ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
				}
		}
		// Swap frontier buffers for the next round.
		CCV_SWAP(p, q, i);
	}
	ccfree(tensor_visited);
	ccv_array_free(preserve[0]);
	ccv_array_free(preserve[1]);
	for (i = 0; i < simplify->tensor_symbol_info_size; i++)
		if (r_alias_refs[i])
			ccv_array_free(r_alias_refs[i]);
	ccfree(r_alias_refs);
}
995
996
// Public entry point for graph simplification: run the requested passes over
// the symbolic graph in the order given, then write the accumulated changes
// back into the graph and release the simplify bookkeeping. Unrecognized pass
// ids are ignored, matching the original dispatch behavior.
void ccv_nnc_symbolic_graph_simplify(ccv_nnc_symbolic_graph_t* const graph, const int* const passes, const int pass_size, const ccv_nnc_tensor_symbol_t* const binds, const int bind_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
{
	ccv_nnc_symbolic_graph_simplify_t* const simplify = _ccv_nnc_symbolic_graph_simplify_new(graph, sources, source_size, destinations, destination_size);
	int pass_idx;
	for (pass_idx = 0; pass_idx < pass_size; pass_idx++)
	{
		const int pass = passes[pass_idx];
		if (pass == CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION)
			_ccv_nnc_symbolic_graph_common_subexpression_elimination(simplify, outputs, output_size);
		else if (pass == CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT)
			_ccv_nnc_symbolic_graph_data_transfer_opt(simplify, binds, bind_size, outputs, output_size);
		else if (pass == CCV_NNC_SIMPLIFY_GRAPH_PRUNING)
			_ccv_nnc_symbolic_graph_pruning(simplify, outputs, output_size);
		else if (pass == CCV_NNC_SIMPLIFY_OPS_FUSION)
			_ccv_nnc_symbolic_graph_ops_fusion(simplify, outputs, output_size);
	}
	_ccv_nnc_symbolic_graph_simplify_apply(simplify);
	_ccv_nnc_symbolic_graph_simplify_free(simplify);
}