Coverage Report

Created: 2021-04-05 03:19

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/ccv_nnc_symbolic_graph_simplify.c
Line
Count
Source
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_nnc_symbolic_graph.h"
6
#include "3rdparty/siphash/siphash24.h"
7
8
// MARK - Level-3.5 API
9
10
static uint8_t key_siphash[16] = "graphcsekvlibnnc";
11
12
typedef struct {
13
  int tensor_symbol_info_size;
14
  int exec_symbol_info_size;
15
  ccv_nnc_symbolic_graph_t* graph;
16
  ccv_nnc_graph_visit_t* visit;
17
  ccv_nnc_tensor_symbol_info_t* tensor_symbol_info;
18
  ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info;
19
  uint32_t* exec_dead; // Mark an exec as dead and needing to be cleared; each bit represents an exec.
20
  uint32_t* tensor_dead; // Mark a tensor as dead and needing to be cleared; each bit represents a tensor.
21
  int* output_execs; // Mapping from tensor to the exec that generates this tensor.
22
} ccv_nnc_symbolic_graph_simplify_t;
23
24
static void _ccv_nnc_symbolic_graph_simplify_update_output_execs(ccv_nnc_symbolic_graph_simplify_t* const simplify)
25
6.79k
{
26
6.79k
  int i;
27
36.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
28
29.2k
    simplify->output_execs[i] = -1;
29
8.63k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
30
8.63k
    if (simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
31
14
      continue;
32
18.2k
    for (i = 0; i < node->output_size; i++)
33
9.67k
      if (node->outputs[i] >= 0)
34
9.66k
        simplify->output_execs[node->outputs[i]] = idx; // A tensor can only be written once.
35
8.62k
  } ccv_nnc_graph_visit_endfor
36
6.79k
}
37
38
static ccv_nnc_symbolic_graph_simplify_t* _ccv_nnc_symbolic_graph_simplify_new(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
39
4.52k
{
40
4.52k
  ccv_nnc_symbolic_graph_simplify_t* const simplify = (ccv_nnc_symbolic_graph_simplify_t*)ccmalloc(sizeof(ccv_nnc_symbolic_graph_simplify_t));
41
4.52k
  simplify->graph = graph;
42
4.52k
  simplify->visit = ccv_nnc_graph_visit_new(graph, (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0), graph->exec_symbol_info->rnum, sources, source_size, destinations, destination_size, 0);
43
4.52k
  simplify->tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_tensor_symbol_info_t) * graph->tensor_symbol_info->rnum);
44
4.52k
  simplify->exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_info_t) * graph->exec_symbol_info->rnum);
45
4.52k
  ccv_nnc_symbolic_graph_symbol_infer(graph, simplify->visit, sources, source_size, destinations, destination_size, 0, 0, simplify->tensor_symbol_info, simplify->exec_symbol_info);
46
4.52k
  simplify->tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
47
4.52k
  simplify->exec_symbol_info_size = graph->exec_symbol_info->rnum;
48
4.52k
  simplify->exec_dead = cccalloc(((simplify->exec_symbol_info_size + 31) >> 5) + ((simplify->tensor_symbol_info_size + 31) >> 5), sizeof(uint32_t));
49
4.52k
  simplify->tensor_dead = simplify->exec_dead + ((simplify->exec_symbol_info_size + 31) >> 5);
50
4.52k
  simplify->output_execs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
51
4.52k
  return simplify;
52
4.52k
}
53
54
static void _ccv_nnc_symbolic_graph_simplify_apply(ccv_nnc_symbolic_graph_simplify_t* const simplify)
55
4.52k
{
56
4.52k
  int i, j;
57
10.2k
  for (i = 0; i < simplify->exec_symbol_info_size; i++)
58
5.73k
    if (simplify->exec_dead[i >> 5] & (1u << (i & 0x1f)))
59
15
      ccv_nnc_graph_exec_symbol_free(simplify->graph, (ccv_nnc_graph_exec_symbol_t){
60
15
        .d = i,
61
15
        .graph = simplify->graph,
62
15
      });
63
5.72k
    else // If it is not marked as dead, go through its outputs to unmark the tensors
64
12.1k
      for (j = 0; j < simplify->exec_symbol_info[i].output_size; j++)
65
6.41k
      {
66
6.41k
        const int d = simplify->exec_symbol_info[i].outputs[j];
67
6.41k
        if (d >= 0)
68
6.41k
          simplify->tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
69
6.41k
      }
70
23.9k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
71
19.4k
    if (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f)))
72
13
      ccv_nnc_tensor_symbol_free(simplify->graph, (ccv_nnc_tensor_symbol_t){
73
13
        .d = i,
74
13
        .graph = simplify->graph,
75
13
      });
76
4.52k
  // autogen the sources / destinations.
77
4.52k
  ccv_nnc_graph_exec_symbol_autogen(simplify->graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
78
4.52k
}
79
80
static void _ccv_nnc_symbolic_graph_simplify_free(ccv_nnc_symbolic_graph_simplify_t* const simplify)
81
4.52k
{
82
4.52k
  ccv_nnc_graph_visit_free(simplify->visit);
83
4.52k
  ccfree(simplify->tensor_symbol_info);
84
4.52k
  ccfree(simplify->exec_symbol_info);
85
4.52k
  ccfree(simplify->exec_dead);
86
4.52k
  ccfree(simplify->output_execs);
87
4.52k
  ccfree(simplify);
88
4.52k
}
89
90
typedef struct {
91
  int d;
92
  int ifbit;
93
  uint64_t hash;
94
} ccv_nnc_cse_hash_t;
95
96
static int _ccv_nnc_cse_hash_find(ccv_nnc_cse_hash_t* const hash_map, const uint64_t hash, const int map_size)
97
6.14k
{
98
6.14k
  assert(hash > 0);
99
6.14k
  int i, j;
100
6.14k
  i = (hash - 1) % map_size;
101
6.14k
  for (j = 0; ; j++, i++)
102
11.8k
  {
103
11.8k
    if (i >= map_size)
104
2
      i = 0;
105
11.8k
    if (j > hash_map[i].ifbit)
106
5.57k
      return -1;
107
6.26k
    if (hash_map[i].hash == hash)
108
573
      return hash_map[i].d;
109
6.26k
  }
110
6.14k
}
111
112
static void _ccv_nnc_cse_hash_add(ccv_nnc_cse_hash_t* const hash_map, uint64_t hash, int d, const int map_size)
113
3.23k
{
114
3.23k
  assert(hash > 0);
115
3.23k
  int i, j;
116
3.23k
  i = (hash - 1) % map_size;
117
3.23k
  for (j = 0; ; j++, i++)
118
3.47k
  {
119
3.47k
    if (i >= map_size)
120
2
      i = 0;
121
3.47k
    if (hash_map[i].hash == hash) // Already exists, do nothing.
122
6
      return;
123
3.47k
    if (hash_map[i].hash == 0)
124
3.22k
    {
125
3.22k
      // Insert.
126
3.22k
      hash_map[i].d = d;
127
3.22k
      hash_map[i].ifbit = j;
128
3.22k
      hash_map[i].hash = hash;
129
3.22k
      return;
130
3.22k
    }
131
246
    if (j > hash_map[i].ifbit)
132
57
    {
133
57
      const ccv_nnc_cse_hash_t old_hash = hash_map[i];
134
57
      // Swap, and continue, until find an empty slot.
135
57
      hash_map[i].d = d;
136
57
      hash_map[i].ifbit = j;
137
57
      hash_map[i].hash = hash;
138
57
      d = old_hash.d;
139
57
      j = old_hash.ifbit;
140
57
      hash = old_hash.hash;
141
57
    }
142
246
  }
143
3.23k
}
144
145
static int _ccv_nnc_symbolic_graph_update_refs(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const int* const refs, const int output_exec_ref_dead)
146
4.53k
{
147
4.53k
  int i, j;
148
4.53k
  // Go over refs, if a tensor is an alias, mark it reference to the new one.
149
24.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
150
19.5k
    if (refs[i] >= 0)
151
12
      // Mark this tensor as dead.
152
12
      simplify->tensor_dead[i >> 5] |= (1u << (i & 0x1f));
153
19.4k
    else if (simplify->tensor_symbol_info[i].alias_ref && refs[simplify->tensor_symbol_info[i].alias_ref - 1] >= 0) {
154
0
      const int alias_ref = simplify->tensor_symbol_info[i].alias_ref - 1;
155
0
      simplify->tensor_symbol_info[i].alias_ref = refs[alias_ref] + 1;
156
0
      ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->alias_ref = refs[alias_ref] + 1;
157
0
    }
158
9.07k
  for (i = 0; i < output_size; i++)
159
4.54k
    // If the output is an alias, that's fine, because if the alias is re-bound, we are good.
160
4.54k
    simplify->tensor_dead[outputs[i].d >> 5] &= ~(1u << (outputs[i].d & 0x1f)); // Undead for output tensor symbols.
161
4.53k
  // Merge s_refs if the tensor is dead.
162
24.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
163
19.5k
    if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))))
164
12
    {
165
12
      const int ref = refs[i];
166
12
      if (simplify->tensor_symbol_info[i].s_ref && simplify->tensor_symbol_info[i].s_ref->rnum)
167
2
      {
168
2
        if (!simplify->tensor_symbol_info[ref].s_ref) // If there is no s_ref, simple, just assign the pointer and set the old one to nil.
169
0
        {
170
0
          simplify->tensor_symbol_info[ref].s_ref = simplify->tensor_symbol_info[i].s_ref;
171
0
          simplify->tensor_symbol_info[i].s_ref = 0;
172
0
          ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref = 0;
173
2
        } else {
174
2
          ccv_array_t* const ref_s_ref = simplify->tensor_symbol_info[ref].s_ref;
175
2
          ccv_array_t* const i_s_ref = simplify->tensor_symbol_info[i].s_ref;
176
2
          const int ref_s_ref_rnum = ref_s_ref->rnum;
177
2
          int flag = 0;
178
2
          // Detect conflict, if there is, undead.
179
4
          for (j = 0; !flag && j < ccv_min(ref_s_ref_rnum, i_s_ref->rnum); j++)
180
2
          {
181
2
            const int ref_s_ref_k = *(int*)ccv_array_get(ref_s_ref, j);
182
2
            const int i_s_ref_k = *(int*)ccv_array_get(i_s_ref, j);
183
2
            // If for the same sub-graph, they have different tensors linked, we cannot merge these two.
184
2
            flag = (ref_s_ref_k > 0 && i_s_ref_k > 0 && ref_s_ref_k != i_s_ref_k);
185
2
          }
186
2
          if (flag)
187
1
          {
188
1
            simplify->tensor_dead[i >> 5] &= ~(1u << (i & 0x1f)); // Undead
189
1
            continue;
190
1
          }
191
1
          if (ref_s_ref_rnum < i_s_ref->rnum)
192
1
          {
193
1
            ccv_array_resize(ref_s_ref, i_s_ref->rnum);
194
1
            memcpy(ccv_array_get(ref_s_ref, ref_s_ref_rnum), ccv_array_get(i_s_ref, ref_s_ref_rnum), sizeof(int) * (i_s_ref->rnum - ref_s_ref_rnum));
195
1
          }
196
2
          for (j = 0; j < ccv_min(ref_s_ref_rnum, i_s_ref->rnum); j++)
197
1
          {
198
1
            const int ref_s_ref_k = *(int*)ccv_array_get(ref_s_ref, j);
199
1
            const int i_s_ref_k = *(int*)ccv_array_get(i_s_ref, j);
200
1
            assert(ref_s_ref_k == 0 || i_s_ref_k == 0);
201
1
            if (i_s_ref_k)
202
0
              *(int*)ccv_array_get(ref_s_ref, j) = i_s_ref_k;
203
1
          }
204
1
          ccv_array_free(simplify->tensor_symbol_info[i].s_ref);
205
1
          simplify->tensor_symbol_info[i].s_ref = 0;
206
1
          ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref = 0;
207
3
          for (j = 0; j < ref_s_ref->rnum; j++)
208
2
          {
209
2
            const int ref_k = *(int*)ccv_array_get(ref_s_ref, j) - 1;
210
2
            if (ref_k >= 0)
211
2
            {
212
2
              ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(simplify->graph->sub_graphs, j);
213
2
              assert(sub_graph);
214
2
              // Update its p_ref.
215
2
              ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, ref_k))->p_ref = ref + 1;
216
2
            }
217
2
          }
218
1
        }
219
2
        assert(simplify->tensor_symbol_info[i].s_ref == ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, i))->s_ref);
220
1
        assert(simplify->tensor_symbol_info[ref].s_ref == ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, ref))->s_ref);
221
1
      }
222
12
    }
223
4.53k
  // Going through refs that we are updating, going through its p_ref to make sure both are updated.
224
24.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
225
19.5k
    if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))) && simplify->tensor_symbol_info[i].p_ref)
226
1
    {
227
1
      const int ref = refs[i];
228
1
      const int p_ref = simplify->tensor_symbol_info[i].p_ref - 1;
229
1
      assert(p_ref >= 0);
230
1
      assert(simplify->graph->p);
231
1
      ccv_array_t* const s_ref = ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->p->tensor_symbol_info, p_ref))->s_ref;
232
1
      const int s_idx = simplify->graph->p_idx - 1;
233
1
      assert(s_idx >= 0);
234
1
      assert(s_ref && s_ref->rnum > s_idx);
235
1
      *(int*)ccv_array_get(s_ref, s_idx) = ref + 1; // Update so it references to the new s_ref.
236
1
      assert(!simplify->tensor_symbol_info[ref].p_ref);
237
1
      simplify->tensor_symbol_info[ref].p_ref = p_ref + 1;
238
1
      ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(simplify->graph->tensor_symbol_info, ref))->p_ref = p_ref + 1;
239
1
    }
240
4.53k
  // Now go over exec to mark them as dead because we don't need these to generate refs.
241
4.53k
  if (output_exec_ref_dead)
242
12.0k
    for (i = 0; i < simplify->tensor_symbol_info_size; i++)
243
9.74k
      if (refs[i] >= 0 && (simplify->tensor_dead[i >> 5] & (1u << (i & 0x1f))))
244
4
      {
245
4
        const int output_exec = simplify->output_execs[i];
246
4
        assert(output_exec >= 0);
247
4
        const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = simplify->exec_symbol_info + output_exec;
248
4
        int flag = 0;
249
8
        for (j = 0; !flag && j < symbol_info->output_size; j++)
250
4
        {
251
4
          const int d = symbol_info->outputs[j];
252
4
          if (d >= 0)
253
4
            flag = (!(simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f)))); // If some of the output is not dead, we cannot proceed.
254
4
        }
255
4
        if (!flag) // If all outputs are dead, mark the exec as dead.
256
4
          simplify->exec_dead[output_exec >> 5] |= (1u << (output_exec & 0x1f));
257
4
      }
258
4.53k
  int updated_refs = 0;
259
4.53k
  // Go over replace inputs / outputs.
260
5.77k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
261
18.0k
    for (i = 0; i < node->input_size; i++)
262
12.3k
    {
263
12.3k
      const int d = node->inputs[i];
264
12.3k
      if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
265
13
      {
266
13
          node->inputs[i] = refs[d]; // It can be replaced.
267
13
          updated_refs = 1;
268
13
      }
269
12.3k
    }
270
12.2k
    for (i = 0; i < node->output_size; i++)
271
6.47k
    {
272
6.47k
      const int d = node->outputs[i];
273
6.47k
      if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
274
10
      {
275
10
          node->outputs[i] = refs[d]; // It can be replaced.
276
10
          updated_refs = 1;
277
10
      }
278
6.47k
    }
279
5.77k
    assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->inputs == node->inputs);
280
5.77k
    assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->outputs == node->outputs);
281
5.77k
  } ccv_nnc_graph_visit_endfor
282
4.53k
  const ccv_nnc_graph_exec_symbol_info_t* const p_node_info = simplify->graph->p ? (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->p->exec_symbol_info, simplify->graph->exec_idx - 1) : 0;
283
4.53k
  if (p_node_info && (p_node_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE))
284
6
    // Go over the while inputs as well.
285
12
    for (i = 0; i < p_node_info->p_while.input_size; i++)
286
6
    {
287
6
      const int d = p_node_info->p_while.inputs[i];
288
6
      if (d >= 0 && refs[d] >= 0 && (simplify->tensor_dead[d >> 5] & (1u << (d & 0x1f))))
289
0
      {
290
0
        p_node_info->p_while.inputs[i] = refs[d];
291
0
        updated_refs = 1;
292
0
      }
293
6
    }
294
4.53k
  return updated_refs;
295
4.53k
}
296
297
// This is a simple common sub-expression elimination implementation; in particular, we only replace the later computed output
298
// with the identical earlier computed output, and leave the "elimination" part to the graph pruning.
299
static void _ccv_nnc_symbolic_graph_common_subexpression_elimination(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
300
2.26k
{
301
2.26k
  _ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
302
2.26k
  // tensor_hash starts with 0s, and it is either marked with the tensor index + 1, or the hash of the computations.
303
2.26k
  uint64_t* const tensor_hash = (uint64_t*)cccalloc(simplify->tensor_symbol_info_size, sizeof(uint64_t));
304
2.26k
  int i;
305
2.88k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
306
2.88k
    // We cannot support graph / custom command (we cannot model them properly).
307
2.88k
    if (node->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
308
2.88k
      node->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
309
2.88k
      node->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
310
2.88k
      node->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD)
311
1
      continue;
312
2.88k
    // If already marked as dead, skip.
313
2.88k
    if (simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
314
0
      continue;
315
2.88k
    uint64_t hashout, hashin[3];
316
2.88k
    siphash((uint8_t*)&hashin[0], (const uint8_t*)&node->cmd.info, sizeof(node->cmd.info), key_siphash);
317
2.88k
    siphash((uint8_t*)&hashin[1], (const uint8_t*)&node->hint, sizeof(node->hint), key_siphash);
318
2.88k
    hashin[2] = node->cmd.cmd; // Now actually hash the cmd name.
319
2.88k
    siphash((uint8_t*)&hashout, (const uint8_t*)hashin, sizeof(hashin), key_siphash);
320
2.88k
    // First, hash the cmd and the hints with the cmd.
321
2.88k
    // Note on aliases: we cannot really generate a proper hash for an alias (yet). Thus, just treat aliases as normal tensors.
322
9.03k
    for (i = 0; i < node->input_size; i++)
323
6.15k
    {
324
6.15k
      assert(node->inputs[i] < simplify->tensor_symbol_info_size);
325
6.15k
      // If no hash for the input, use the index + 1 as the hash.
326
6.15k
      if (node->inputs[i] >= 0 && tensor_hash[node->inputs[i]] == 0)
327
5.51k
        tensor_hash[node->inputs[i]] = node->inputs[i] + 1;
328
6.15k
      if (node->inputs[i] >= 0)
329
6.14k
      {
330
6.14k
        // Hash using the tensor hash.
331
6.14k
        hashin[0] = hashout;
332
6.14k
        hashin[1] = i; // Encode the positional information.
333
6.14k
        hashin[2] = tensor_hash[node->inputs[i]];
334
6.14k
      } else {
335
10
        // Hash using the input integer (could be special integer).
336
10
        hashin[0] = hashout;
337
10
        hashin[1] = i; // Encode the positional information.
338
10
        hashin[2] = node->inputs[i];
339
10
      }
340
6.15k
      siphash((uint8_t*)&hashout, (const uint8_t*)hashin, sizeof(hashin), key_siphash);
341
6.15k
    }
342
6.11k
    for (i = 0; i < node->output_size; i++)
343
3.23k
      if (node->outputs[i] >= 0)
344
3.23k
      {
345
3.23k
        assert(node->outputs[i] < simplify->tensor_symbol_info_size);
346
3.23k
        // Assigning once, especially now we don't consider aliases.
347
3.23k
        assert(tensor_hash[node->outputs[i]] == 0);
348
3.23k
        hashin[0] = hashout;
349
3.23k
        hashin[1] = i; // Positional information.
350
3.23k
        siphash((uint8_t*)&hashin[2], (const uint8_t*)&simplify->tensor_symbol_info[node->outputs[i]].info,
351
3.23k
            sizeof(simplify->tensor_symbol_info[node->outputs[i]].info), key_siphash);
352
3.23k
        // Generate hash for the output.
353
3.23k
        siphash((uint8_t*)&tensor_hash[node->outputs[i]], (const uint8_t*)hashin, sizeof(hashin), key_siphash);
354
3.23k
      }
355
2.88k
  } ccv_nnc_graph_visit_endfor
356
2.26k
  // Allocate 3/2 as much space, for the simple robin-hood open-addressing hash map.
357
2.26k
  const int map_size = (simplify->tensor_symbol_info_size * 3 + 1) / 2;
358
2.26k
  int* const refs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
359
12.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
360
9.74k
    refs[i] = -1;
361
2.26k
  ccv_nnc_cse_hash_t* const hash_map = (ccv_nnc_cse_hash_t*)cccalloc(map_size, sizeof(ccv_nnc_cse_hash_t));
362
2.26k
  // Now that all tensors are hashed, identify tensors with the same hash code and replace the ones that are accessed later.
363
2.88k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
364
2.88k
    // If already marked as dead, skip.
365
2.88k
    if (simplify->exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
366
0
      continue;
367
9.04k
    for (i = 0; i < node->input_size; i++)
368
6.15k
      if (node->inputs[i] >= 0)
369
6.14k
      {
370
6.14k
        const int d = node->inputs[i];
371
6.14k
        assert(tensor_hash[d]);
372
6.14k
        const int new_d = _ccv_nnc_cse_hash_find(hash_map, tensor_hash[d], map_size);
373
6.14k
        if (new_d >= 0 && new_d != d)
374
7
        {
375
7
          // Check whether this can be replaced.
376
7
          assert(refs[d] == -1 || refs[d] == new_d);
377
7
          assert(!simplify->tensor_symbol_info[d].assign_ref);
378
7
          assert(!simplify->tensor_symbol_info[d].r_assign_ref);
379
7
          assert(!simplify->tensor_symbol_info[d].bypass_ref);
380
7
          assert(!simplify->tensor_symbol_info[new_d].assign_ref);
381
7
          assert(!simplify->tensor_symbol_info[new_d].r_assign_ref);
382
7
          assert(!simplify->tensor_symbol_info[new_d].bypass_ref);
383
7
          // Ignore if there is a pair_ref (again, pair_ref has side effect that is deeper (using tape))
384
7
          if (simplify->tensor_symbol_info[d].pair_ref)
385
0
            continue;
386
7
          // If both have p_ref, we cannot merge.
387
7
          if (simplify->tensor_symbol_info[d].p_ref && simplify->tensor_symbol_info[new_d].p_ref)
388
1
            continue;
389
6
          // Merge s_refs from ref[d] later.
390
6
          if (refs[d] != new_d)
391
5
            refs[d] = new_d;
392
6
          assert(simplify->output_execs[new_d] >= 0);
393
6
          // Establish new dependency.
394
6
          ccv_nnc_graph_exec_symbol_concat(simplify->graph, (ccv_nnc_graph_exec_symbol_t){
395
6
            .d = simplify->output_execs[new_d],
396
6
            .graph = simplify->graph,
397
6
          }, (ccv_nnc_graph_exec_symbol_t){
398
6
            .d = idx,
399
6
            .graph = simplify->graph,
400
6
          });
401
6
        }
402
6.14k
      }
403
2.88k
    // We can reuse the input, but we cannot do that for output of these commands.
404
2.88k
    if (node->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
405
2.88k
      node->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
406
2.88k
      node->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
407
2.88k
      node->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD)
408
1
      continue;
409
6.11k
    for (i = 0; i < node->output_size; i++)
410
3.23k
      if (node->outputs[i] >= 0) // This tensor can be reused by others.
411
3.23k
        _ccv_nnc_cse_hash_add(hash_map, tensor_hash[node->outputs[i]], node->outputs[i], map_size);
412
2.88k
  } ccv_nnc_graph_visit_endfor
413
2.26k
  _ccv_nnc_symbolic_graph_update_refs(simplify, outputs, output_size, refs, 1 /* For these exec that generates refs, we don't need them any more. */);
414
2.26k
  ccfree(tensor_hash);
415
2.26k
  ccfree(hash_map);
416
2.26k
  ccfree(refs);
417
2.26k
}
418
419
static void _ccv_nnc_symbolic_graph_data_transfer_opt(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const binds, const int bind_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
420
2.26k
{
421
2.26k
  _ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
422
2.26k
  uint32_t* const exec_dead = simplify->exec_dead;
423
2.26k
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = simplify->tensor_symbol_info;
424
2.26k
  int i, j;
425
2.26k
  uint32_t* const has_alias = ccmalloc(sizeof(uint32_t) * ((simplify->tensor_symbol_info_size + 31) >> 5));
426
2.26k
  int* const refs = (int*)ccmalloc(sizeof(int) * simplify->tensor_symbol_info_size);
427
2.26k
  int updated_refs;
428
2.26k
  do {
429
2.26k
    memset(has_alias, 0, sizeof(uint32_t) * ((simplify->tensor_symbol_info_size + 31) >> 5));
430
2.26k
    // Go through until no update is possible. This won't result in an infinite loop because every time,
431
2.26k
    // a tensor is eliminated.
432
12.0k
    for (i = 0; i < simplify->tensor_symbol_info_size; i++)
433
9.75k
    {
434
9.75k
      refs[i] = -1;
435
9.75k
      if (tensor_symbol_info[i].alias_ref)
436
1.11k
      {
437
1.11k
        const int alias_ref = tensor_symbol_info[i].alias_ref - 1;
438
1.11k
        has_alias[alias_ref >> 5] |= (1u << (alias_ref & 0x1f));
439
1.11k
      }
440
9.75k
    }
441
2.88k
    ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
442
2.88k
      // If already marked as dead, skip.
443
2.88k
      if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
444
4
        continue;
445
2.88k
      if (node->cmd.cmd != CCV_NNC_DATA_TRANSFER_FORWARD &&
446
2.88k
        node->cmd.cmd != CCV_NNC_DATA_TRANSFER_BACKWARD &&
447
2.88k
        // Conversion, if the datatype is the same, it is the data transfer op.
448
2.88k
        node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_FORWARD &&
449
2.88k
        node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_BACKWARD &&
450
2.88k
        // Format transform, if the format is the same, it is the data transfer op.
451
2.88k
        node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_FORWARD &&
452
2.88k
        node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_BACKWARD)
453
2.87k
        continue;
454
36
      for (i = 0; i < node->output_size; i++) // For data transfer, we only respect output size.
455
21
        if (node->inputs[i] >= 0 && node->outputs[i] >= 0)
456
21
        {
457
21
          assert(node->inputs[i] < simplify->tensor_symbol_info_size);
458
21
          assert(node->outputs[i] < simplify->tensor_symbol_info_size);
459
21
          const ccv_nnc_tensor_symbol_info_t* const input = tensor_symbol_info + node->inputs[i];
460
21
          const ccv_nnc_tensor_symbol_info_t* const output = tensor_symbol_info + node->outputs[i];
461
21
          // If they are not the same data type, skip. (Likely data conversion op).
462
21
          if (input->info.datatype != output->info.datatype)
463
0
            continue;
464
21
          // If they are not the same format, skip. (Likely format transform op).
465
21
          if (input->info.format != output->info.format)
466
0
            continue;
467
21
          // If they are not on the same device (even for NUMA), skip.
468
21
          if (input->info.type != output->info.type)
469
0
            continue;
470
21
          // If both are alias, we cannot consolidate this.
471
21
          if (input->alias_ref && output->alias_ref)
472
9
            continue;
473
12
          // If input is alias, and output has alias reference to it, output cannot be the same as input.
474
12
          if (input->alias_ref && (has_alias[node->outputs[i] >> 5] & (1u << (node->outputs[i] & 0x1f))))
475
0
            continue;
476
12
          // If output is alias, and input has alias reference to it, input cannot be the same as output.
477
12
          if (output->alias_ref && (has_alias[node->inputs[i] >> 5] & (1u << (node->inputs[i] & 0x1f))))
478
0
            continue;
479
12
          // If either are carry overs (for while), we cannot do anything.
480
12
          if (input->assign_ref || output->assign_ref ||
481
12
            input->r_assign_ref || output->r_assign_ref)
482
0
            continue;
483
12
          // If either are bypasses (for case..of), we cannot do anything.
484
12
          if (input->bypass_ref || output->bypass_ref ||
485
12
            input->r_bypass_ref || output->r_bypass_ref)
486
0
            continue;
487
12
          // If either are inputs / outputs connecting the parent graph, we cannot do anything.
488
12
          if (input->p_ref || output->p_ref)
489
0
            continue;
490
12
          int flag = 0;
491
26
          for (j = 0; !flag && j < bind_size; j++)
492
14
            flag = (binds[j].d == node->inputs[i] || binds[j].d == node->outputs[i]);
493
20
          for (j = 0; !flag && j < output_size; j++)
494
8
            flag = (outputs[j].d == node->inputs[i] || outputs[j].d == node->outputs[i]);
495
12
          // Either input or output cannot be in the outputs nor the binds.
496
12
          if (flag)
497
4
            continue;
498
8
          // If the type is the same, check which one is the alias.
499
8
          // We always prefer alias.
500
8
          if (output->alias_ref)
501
7
            refs[node->inputs[i]] = node->outputs[i];
502
1
          else // if (input->alias_ref), else
503
1
            refs[node->outputs[i]] = node->inputs[i];
504
8
        }
505
15
    } ccv_nnc_graph_visit_endfor
506
2.26k
    // Make sure refs reference to the end.
507
12.0k
    for (i = 0; i < simplify->tensor_symbol_info_size; i++)
508
9.75k
      if (refs[i] >= 0)
509
7
      {
510
7
        int ref = refs[i];
511
7
        while (refs[ref] >= 0)
512
0
          ref = refs[ref];
513
7
        refs[i] = ref;
514
7
      }
515
2.26k
    updated_refs = _ccv_nnc_symbolic_graph_update_refs(simplify, outputs, output_size, refs, 0 /* We still need these exec that generates the refs. */);
516
2.26k
  } while (updated_refs);
517
2.26k
  ccfree(refs);
518
2.26k
  ccfree(has_alias);
519
2.26k
  // Now, all references updated, remove data transfers that sources and destinations are the same.
520
2.87k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
521
2.87k
    // If already marked as dead, skip.
522
2.87k
    if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
523
4
      continue;
524
2.86k
    if (node->cmd.cmd != CCV_NNC_DATA_TRANSFER_FORWARD &&
525
2.86k
      node->cmd.cmd != CCV_NNC_DATA_TRANSFER_BACKWARD &&
526
2.86k
      // Conversion, if the datatype is the same, it is the data transfer op.
527
2.86k
      node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_FORWARD &&
528
2.86k
      node->cmd.cmd != CCV_NNC_DATATYPE_CONVERSION_BACKWARD &&
529
2.86k
      // Format transform, if the format is the same, it is the data transfer op.
530
2.86k
      node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_FORWARD &&
531
2.86k
      node->cmd.cmd != CCV_NNC_FORMAT_TRANSFORM_BACKWARD)
532
2.86k
      continue;
533
18
    for (i = 0; i < node->output_size; i++) // For data transfer, we only respect output size.
534
11
      if (node->inputs[i] == node->outputs[i])
535
7
      {
536
7
        if (i + 1 < node->output_size)
537
2
        {
538
2
          node->inputs[i] = node->inputs[i + 1];
539
2
          node->outputs[i] = node->outputs[i + 1];
540
2
        }
541
7
        --node->output_size;
542
7
        --i;
543
7
      }
544
7
    node->input_size = node->output_size;
545
7
    ((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->input_size = node->input_size;
546
7
    ((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx))->output_size = node->output_size;
547
7
    // Remove this data transfer node if it has no outputs.
548
7
    if (node->output_size == 0)
549
5
      exec_dead[idx >> 5] |= (1u << (idx & 0x1f));
550
7
  } ccv_nnc_graph_visit_endfor
551
2.26k
}
552
553
typedef struct {
554
  uint32_t fused_op; // The final fused op id.
555
  uint32_t ops_seq[2]; // The sequence of commands to identify.
556
  int ops_seq_size;
557
  int ops_info_select; // Which ops' info will be selected.
558
  struct {
559
    int type; // Whether it is an input or an output. It doesn't make sense, for example, to be an input in ops_seq but an output in fused_op.
560
    int op_idx; // Index into the ops_seq.
561
    int from; // The index in ops_seq.
562
    int to; // The index in fused_op.
563
  } pos[4]; // maps of positions from ops seq to fused_op for inputs (outputs).
564
  int pos_size;
565
} ccv_nnc_ops_fusion_t;
566
567
enum {
568
  CCV_NNC_OPS_FUSION_INPUT_INDEX,
569
  CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
570
};
571
572
const static int ccv_nnc_ops_fusion_io_size = 2;
573
const static ccv_nnc_ops_fusion_t ccv_nnc_ops_fusions[] = {
574
  {
575
    .fused_op = CCV_NNC_SOFTMAX_CROSSENTROPY_FORWARD,
576
    .ops_seq = {
577
      CCV_NNC_SOFTMAX_FORWARD, CCV_NNC_CATEGORICAL_CROSSENTROPY_FORWARD,
578
    },
579
    .ops_seq_size = 2,
580
    .ops_info_select = 1,
581
    .pos = {
582
      {
583
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
584
        .op_idx = 0,
585
        .from = 0,
586
        .to = 0,
587
      },
588
      {
589
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
590
        .op_idx = 1,
591
        .from = 1,
592
        .to = 1,
593
      },
594
      {
595
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
596
        .op_idx = 0,
597
        .from = 0,
598
        .to = 1,
599
      },
600
      {
601
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
602
        .op_idx = 1,
603
        .from = 0,
604
        .to = 0,
605
      },
606
    },
607
    .pos_size = 4,
608
  },
609
  {
610
    .fused_op = CCV_NNC_SIGMOID_BINARY_CROSSENTROPY_FORWARD,
611
    .ops_seq = {
612
      CCV_NNC_SIGMOID_FORWARD, CCV_NNC_BINARY_CROSSENTROPY_FORWARD,
613
    },
614
    .ops_seq_size = 2,
615
    .ops_info_select = 1,
616
    .pos = {
617
      {
618
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
619
        .op_idx = 0,
620
        .from = 0,
621
        .to = 0,
622
      },
623
      {
624
        .type = CCV_NNC_OPS_FUSION_INPUT_INDEX,
625
        .op_idx = 1,
626
        .from = 1,
627
        .to = 1,
628
      },
629
      {
630
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
631
        .op_idx = 0,
632
        .from = 0,
633
        .to = 1,
634
      },
635
      {
636
        .type = CCV_NNC_OPS_FUSION_OUTPUT_INDEX,
637
        .op_idx = 1,
638
        .from = 0,
639
        .to = 0,
640
      },
641
    },
642
    .pos_size = 4,
643
  }
644
};
645
646
static int _ccv_nnc_find_ops_for_fusion(const ccv_nnc_ops_fusion_t* const fusion, const int ops_idx, const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const uint32_t* const exec_dead, const int exec_idx, int* const fusing_exec_symbols)
647
38
{
648
38
  if (exec_dead[exec_idx >> 5] & (1u << (exec_idx & 0x1f)))
649
0
    return 0;
650
38
  const ccv_nnc_graph_exec_symbol_info_t* const node = exec_symbol_info + exec_idx;
651
38
  // Doesn't match the ops_seq, return 0.
652
38
  if (fusion->ops_seq[ops_idx] != node->cmd.cmd)
653
12
    return 0;
654
26
  fusing_exec_symbols[ops_idx] = exec_idx;
655
26
  // If already reached the end, we are good.
656
26
  if (ops_idx == fusion->ops_seq_size - 1)
657
5
    return 1;
658
21
  // Otherwise, we need to go on, but don't have any to follow-up.
659
21
  if (!node->outgoings || !node->outgoings->rnum)
660
4
    return 0;
661
17
  int i;
662
29
  for (i = 0; i < node->outgoings->rnum; i++)
663
17
    if (_ccv_nnc_find_ops_for_fusion(fusion, ops_idx + 1, exec_symbol_info, exec_dead, *(int*)ccv_array_get(node->outgoings, i), fusing_exec_symbols))
664
5
      return 1;
665
17
  return 0;
666
17
}
667
668
static void _ccv_nnc_symbolic_graph_ops_fusion(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
669
4.52k
{
670
4.52k
  uint32_t* const exec_dead = simplify->exec_dead;
671
4.52k
  ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = simplify->exec_symbol_info;
672
4.52k
  int i, j;
673
4.52k
  int fusing_exec_symbols[sizeof(ccv_nnc_ops_fusions->ops_seq)];
674
4.52k
  int fused_inputs[ccv_nnc_ops_fusion_io_size]; // 2 is just a number based on the ops_fusion.
675
4.52k
  int fused_outputs[ccv_nnc_ops_fusion_io_size];
676
5.72k
  ccv_nnc_graph_visit_for(simplify->visit, exec_symbol_info, node, idx) {
677
5.72k
    // If already marked as dead, skip.
678
5.72k
    if (exec_dead[idx >> 5] & (1u << (idx & 0x1f)))
679
14
      continue;
680
5.71k
    // Run through rudimentary pattern matching for ops_seq. There are better ways to do this (Boyer-Moore immediately comes to mind). However, this is simpler to code.
681
5.71k
    // If I am running into performance issues with this, I would be very happy.
682
17.1k
    for (i = 0; i < sizeof(ccv_nnc_ops_fusions) / sizeof(ccv_nnc_ops_fusion_t); i++)
683
11.4k
    {
684
11.4k
      const ccv_nnc_ops_fusion_t* const ops_fusion = ccv_nnc_ops_fusions + i;
685
11.4k
      // Check to see if a list of symbols are possible.
686
11.4k
      if (ops_fusion->ops_seq[0] == node->cmd.cmd &&
687
11.4k
        _ccv_nnc_find_ops_for_fusion(ops_fusion, 0, exec_symbol_info, exec_dead, idx, fusing_exec_symbols))
688
5
      {
689
5
        // Go through all the inputs and outputs, check if they exists and are mapped.
690
5
        // TODO: the mapping can be more sophisticated than what we have here.
691
5
        // Also, I need to check if some inputs / outputs cannot be mapped, then we cannot continue.
692
15
        for (j = 0; j < ccv_nnc_ops_fusion_io_size; j++)
693
10
          fused_inputs[j] = fused_outputs[j] = CCV_NNC_NO_TENSOR_SYMBOL;
694
5
        int input_size = 0, output_size = 0;
695
25
        for (j = 0; j < ops_fusion->pos_size; j++)
696
20
        {
697
20
          ccv_nnc_graph_exec_symbol_info_t* const fusing_op = exec_symbol_info + fusing_exec_symbols[ops_fusion->pos[j].op_idx];
698
20
          switch (ops_fusion->pos[j].type)
699
20
          {
700
10
            case CCV_NNC_OPS_FUSION_INPUT_INDEX:
701
10
              fused_inputs[ops_fusion->pos[j].to] = ops_fusion->pos[j].from < fusing_op->input_size ? fusing_op->inputs[ops_fusion->pos[j].from] : CCV_NNC_NO_TENSOR_SYMBOL;
702
10
              input_size = ccv_max(input_size, ops_fusion->pos[j].to + 1);
703
10
              break;
704
10
            case CCV_NNC_OPS_FUSION_OUTPUT_INDEX:
705
10
              fused_outputs[ops_fusion->pos[j].to] = ops_fusion->pos[j].from < fusing_op->output_size ? fusing_op->outputs[ops_fusion->pos[j].from] : CCV_NNC_NO_TENSOR_SYMBOL;
706
10
              output_size = ccv_max(output_size, ops_fusion->pos[j].to + 1);
707
10
              break;
708
20
          }
709
20
        }
710
5
        const ccv_nnc_cmd_param_t info = exec_symbol_info[fusing_exec_symbols[ops_fusion->ops_info_select]].cmd.info;
711
5
        // Modify the first node so it is the correct type and value. I need to look back to the actual graph to get the info.
712
5
        ccv_nnc_graph_exec_symbol_info_t* const actual_node = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, idx);
713
5
        actual_node->cmd.cmd = node->cmd.cmd = ops_fusion->fused_op;
714
5
        actual_node->cmd.info = node->cmd.info = info;
715
5
        if (node->input_size + node->output_size < input_size + output_size)
716
5
          actual_node->inputs = node->inputs = node->inputs ? ccrealloc(node->inputs, sizeof(int) * (input_size + output_size)) : ccmalloc(sizeof(int) * (input_size + output_size));
717
5
        actual_node->outputs = node->outputs = node->inputs + input_size;
718
5
        actual_node->input_size = node->input_size = input_size;
719
5
        actual_node->output_size = node->output_size = output_size;
720
5
        memcpy(node->inputs, fused_inputs, sizeof(int) * input_size);
721
5
        memcpy(node->outputs, fused_outputs, sizeof(int) * output_size);
722
5
        // Afterwards, mark the rest as dead.
723
10
        for (j = 1; j < ops_fusion->ops_seq_size; j++)
724
5
          exec_dead[fusing_exec_symbols[j] >> 5] |= (1u << (fusing_exec_symbols[j] & 0x1f));
725
5
        break;
726
5
      }
727
11.4k
    }
728
5.71k
  } ccv_nnc_graph_visit_endfor
729
4.52k
}
730
731
static void _ccv_nnc_symbolic_graph_pruning_undead_exec(ccv_nnc_symbolic_graph_simplify_t* const simplify, const int exec_idx, uint32_t* const tensor_visited, ccv_array_t* const next)
732
5.64k
{
733
5.64k
  assert(exec_idx >= 0);
734
5.64k
  uint32_t* const exec_dead = simplify->exec_dead;
735
5.64k
  uint32_t* const tensor_dead = simplify->tensor_dead;
736
5.64k
  exec_dead[exec_idx >> 5] &= ~(1u << (exec_idx & 0x1f)); // Undead the exec.
737
5.64k
  ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = simplify->exec_symbol_info + exec_idx;
738
5.64k
  int i;
739
5.64k
  if (exec_symbol_info->cmd.cmd == CCV_NNC_GRAPH_FORWARD ||
740
5.64k
    exec_symbol_info->cmd.cmd == CCV_NNC_GRAPH_BACKWARD ||
741
5.64k
    exec_symbol_info->cmd.cmd == CCV_NNC_CUSTOM_FORWARD ||
742
5.64k
    exec_symbol_info->cmd.cmd == CCV_NNC_CUSTOM_BACKWARD)
743
2
  {
744
2
    // All of its inputs / outputs need to be undead for these commands.
745
12
    for (i = 0; i < exec_symbol_info->input_size; i++)
746
10
    {
747
10
      const int d = exec_symbol_info->inputs[i];
748
10
      if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
749
4
      {
750
4
        ccv_array_push(next, &d); // Push to the next round to be undead.
751
4
        tensor_visited[d >> 5] |= (1u << (d & 0x1f));
752
4
      }
753
10
    }
754
4
    for (i = 0; i < exec_symbol_info->output_size; i++)
755
2
    {
756
2
      const int d = exec_symbol_info->outputs[i];
757
2
      if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
758
1
      {
759
1
        ccv_array_push(next, &d); // Push to the next round to be undead.
760
1
        tensor_visited[d >> 5] |= (1u << (d & 0x1f));
761
1
      }
762
2
    }
763
2
    return;
764
2
  }
765
5.63k
  // Go through the input / output, to make sure that all of them can be available.
766
5.63k
  const int input_bitmask_size = (exec_symbol_info->input_size + 63) >> 6;
767
5.63k
  const int output_bitmask_size = (exec_symbol_info->output_size + 63) >> 6;
768
5.63k
  uint64_t input_bitmasks[ccv_max(1, input_bitmask_size)];
769
11.2k
  for (i = 0; i < input_bitmask_size; i++)
770
5.63k
    input_bitmasks[i] = 0;
771
5.63k
  uint64_t output_bitmasks[ccv_max(1, output_bitmask_size)];
772
11.2k
  for (i = 0; i < output_bitmask_size; i++)
773
5.63k
    output_bitmasks[i] = 0;
774
18.4k
  for (i = 0; i < exec_symbol_info->input_size; i++)
775
12.8k
    if (exec_symbol_info->inputs[i] >= 0)
776
12.7k
      input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
777
12.9k
  for (i = 0; i < exec_symbol_info->output_size; i++)
778
7.30k
    if (exec_symbol_info->outputs[i] >= 0)
779
7.29k
      output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
780
5.63k
  // First, mark everything with bitmasks, and verify it works.
781
5.63k
  assert(ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size));
782
5.63k
  int flag;
783
5.64k
  do {
784
5.64k
    flag = 0;
785
5.64k
    // Try to eliminate one at a time. Go over output first.
786
12.9k
    for (i = 0; i < exec_symbol_info->output_size; i++)
787
7.31k
    {
788
7.31k
      const int d = exec_symbol_info->outputs[i];
789
7.31k
      // If this tensor currently is marked as dead, try to see whether it works when we don't have this tensor at all.
790
7.31k
      if (d >= 0 && (tensor_dead[d >> 5] & (1u << (d & 0x1f))) &&
791
7.31k
        (output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))))
792
346
      {
793
346
        output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
794
346
        if (ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size))
795
1
          flag = 1;
796
345
        else // Reset the bitmask.
797
345
          output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
798
346
      }
799
7.31k
    }
800
5.64k
    // For inputs, no matter if it is dead or not, we try to limit our input to the smallest number.
801
18.4k
    for (i = 0; i < exec_symbol_info->input_size; i++)
802
12.8k
    {
803
12.8k
      const int d = exec_symbol_info->inputs[i];
804
12.8k
      // If this tensor currently is marked as dead, try to see whether it works when we don't have this tensor at all.
805
12.8k
      if (d >= 0 && (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))))
806
12.7k
      {
807
12.7k
        input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
808
12.7k
        if (ccv_nnc_cmd_bitmask(exec_symbol_info->cmd, exec_symbol_info->input_size, exec_symbol_info->output_size, input_bitmasks, input_bitmask_size, output_bitmasks, output_bitmask_size))
809
0
          flag = 1;
810
12.7k
        else // Reset the bitmask.
811
12.7k
          input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
812
12.7k
      }
813
12.8k
    }
814
5.64k
  } while (flag);
815
5.63k
  // Now we know which one to keep, which one to undead.
816
18.4k
  for (i = 0; i < exec_symbol_info->input_size; i++)
817
12.8k
    if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
818
12.7k
    {
819
12.7k
      const int d = exec_symbol_info->inputs[i];
820
12.7k
      if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
821
6.00k
      {
822
6.00k
        ccv_array_push(next, &d); // Push to the next round to be undead.
823
6.00k
        tensor_visited[d >> 5] |= (1u << (d & 0x1f));
824
6.00k
      }
825
12.7k
    } else {
826
41
      // Clean up the inputs.
827
41
      exec_symbol_info->inputs[i] = CCV_NNC_NO_TENSOR_SYMBOL;
828
41
      assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, exec_idx))->inputs == exec_symbol_info->inputs);
829
41
    }
830
12.9k
  for (i = 0; i < exec_symbol_info->output_size; i++)
831
7.30k
    if (output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
832
7.29k
    {
833
7.29k
      const int d = exec_symbol_info->outputs[i];
834
7.29k
      if (d >= 0 && !(tensor_visited[d >> 5] & (1u << (d & 0x1f))))
835
2.72k
      {
836
2.72k
        ccv_array_push(next, &d); // Push to the next round to be undead.
837
2.72k
        tensor_visited[d >> 5] |= (1u << (d & 0x1f));
838
2.72k
      }
839
7.29k
    } else {
840
16
      // Clean up the outputs.
841
16
      exec_symbol_info->outputs[i] = CCV_NNC_NO_TENSOR_SYMBOL;
842
16
      assert(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(simplify->graph->exec_symbol_info, exec_idx))->outputs == exec_symbol_info->outputs);
843
16
    }
844
5.63k
}
845
846
static void _ccv_nnc_symbolic_graph_pruning(ccv_nnc_symbolic_graph_simplify_t* const simplify, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
847
2.26k
{
848
2.26k
  uint32_t* const tensor_visited = (uint32_t*)cccalloc(sizeof(uint32_t), (simplify->tensor_symbol_info_size + 31) >> 5);
849
2.26k
  ccv_array_t* const preserve[2] = {
850
2.26k
    ccv_array_new(sizeof(int), output_size, 0),
851
2.26k
    ccv_array_new(sizeof(int), 0, 0),
852
2.26k
  };
853
2.26k
  int i, j;
854
2.26k
  ccv_array_t** const r_alias_refs = (ccv_array_t**)cccalloc(sizeof(ccv_array_t*), simplify->tensor_symbol_info_size);
855
12.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
856
9.74k
    if (simplify->tensor_symbol_info[i].alias_ref)
857
1.12k
    {
858
1.12k
      const int alias_ref = simplify->tensor_symbol_info[i].alias_ref - 1;
859
1.12k
      assert(alias_ref < simplify->tensor_symbol_info_size);
860
1.12k
      if (!r_alias_refs[alias_ref])
861
1.10k
        r_alias_refs[alias_ref] = ccv_array_new(sizeof(int), 1, 0);
862
1.12k
      ccv_array_push(r_alias_refs[alias_ref], &i);
863
1.12k
    }
864
2.26k
  uint32_t* const exec_dead = simplify->exec_dead;
865
2.26k
  uint32_t* const tensor_dead = simplify->tensor_dead;
866
2.26k
  int* const output_execs = simplify->output_execs;
867
2.26k
  _ccv_nnc_symbolic_graph_simplify_update_output_execs(simplify);
868
2.26k
  // Mark everything visited as dead.
869
2.88k
  ccv_nnc_graph_visit_for(simplify->visit, simplify->exec_symbol_info, node, idx) {
870
2.88k
    exec_dead[idx >> 5] |= (1u << (idx & 0x1f));
871
9.03k
    for (i = 0; i < node->input_size; i++)
872
6.15k
    {
873
6.15k
      const int d = node->inputs[i];
874
6.15k
      if (d >= 0)
875
6.14k
        tensor_dead[d >> 5] |= (1u << (d & 0x1f));
876
6.15k
    }
877
6.11k
    for (i = 0; i < node->output_size; i++)
878
3.22k
    {
879
3.22k
      const int d = node->outputs[i];
880
3.22k
      if (d >= 0)
881
3.22k
        tensor_dead[d >> 5] |= (1u << (d & 0x1f));
882
3.22k
    }
883
2.88k
  } ccv_nnc_graph_visit_endfor
884
2.26k
  // If the tensor symbol is used by other exec that is not visited, unmark it.
885
5.14k
  for (i = 0; i < simplify->exec_symbol_info_size; i++)
886
2.88k
  {
887
2.88k
    if (exec_dead[i >> 5] & (1u << (i & 0x1f)))
888
2.88k
      continue;
889
0
    const ccv_nnc_graph_exec_symbol_info_t* const node = simplify->exec_symbol_info + i;
890
0
    for (j = 0; j < node->input_size; j++)
891
0
    {
892
0
      const int d = node->inputs[j];
893
0
      // Undead it.
894
0
      if (d >= 0)
895
0
        tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
896
0
    }
897
0
    for (j = 0; j < node->output_size; j++)
898
0
    {
899
0
      const int d = node->outputs[j];
900
0
      // Undead it.
901
0
      if (d >= 0)
902
0
        tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
903
0
    }
904
0
  }
905
4.53k
  for (i = 0; i < output_size; i++)
906
2.27k
    ccv_array_push(preserve[0], &outputs[i].d);
907
2.26k
  int p = 0, q = 1;
908
2.26k
  // BFS to mark execs / tensors as not dead.
909
7.08k
  while (preserve[p]->rnum > 0)
910
4.81k
  {
911
4.81k
    ccv_array_clear(preserve[q]);
912
4.81k
    // First, undead all relevant tensors.
913
15.8k
    for (i = 0; i < preserve[p]->rnum; i++)
914
11.0k
    {
915
11.0k
      const int d = *(int*)ccv_array_get(preserve[p], i);
916
11.0k
      // Undead the outputs.
917
11.0k
      tensor_dead[d >> 5] &= ~(1u << (d & 0x1f));
918
11.0k
      int alias_ref = d;
919
11.0k
      if (simplify->tensor_symbol_info[d].alias_ref)
920
1.12k
      {
921
1.12k
        alias_ref = simplify->tensor_symbol_info[d].alias_ref - 1;
922
1.12k
        tensor_dead[alias_ref >> 5] &= ~(1u << (alias_ref & 0x1f));
923
1.12k
        assert(r_alias_refs[alias_ref]);
924
1.12k
      }
925
11.0k
      if (r_alias_refs[alias_ref])
926
4.50k
        for (j = 0; j < r_alias_refs[alias_ref]->rnum; j++)
927
2.27k
        {
928
2.27k
          const int b = *(int*)ccv_array_get(r_alias_refs[alias_ref], j);
929
2.27k
          if (output_execs[b] >= 0) // Only revive if it is written alias.
930
56
            tensor_dead[b >> 5] &= ~(1u << (b & 0x1f));
931
2.27k
        }
932
11.0k
    }
933
15.8k
    for (i = 0; i < preserve[p]->rnum; i++)
934
11.0k
    {
935
11.0k
      const int d = *(int*)ccv_array_get(preserve[p], i);
936
11.0k
      const int output_exec = output_execs[d];
937
11.0k
      // Undead the exec.
938
11.0k
      if (output_exec >= 0)
939
4.48k
        _ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
940
11.0k
      int alias_ref = d;
941
11.0k
      if (simplify->tensor_symbol_info[d].alias_ref)
942
1.12k
      {
943
1.12k
        alias_ref = simplify->tensor_symbol_info[d].alias_ref - 1;
944
1.12k
        const int output_exec = output_execs[alias_ref];
945
1.12k
        if (output_exec >= 0)
946
1.09k
          _ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
947
1.12k
      }
948
11.0k
      if (r_alias_refs[alias_ref])
949
4.50k
        for (j = 0; j < r_alias_refs[alias_ref]->rnum; j++)
950
2.27k
        {
951
2.27k
          const int b = *(int*)ccv_array_get(r_alias_refs[alias_ref], j);
952
2.27k
          const int output_exec = output_execs[b];
953
2.27k
          if (output_exec >= 0)
954
56
            _ccv_nnc_symbolic_graph_pruning_undead_exec(simplify, output_exec, tensor_visited, preserve[q]);
955
2.27k
        }
956
11.0k
    }
957
4.81k
    CCV_SWAP(p, q, i);
958
4.81k
  }
959
2.26k
  ccfree(tensor_visited);
960
2.26k
  ccv_array_free(preserve[0]);
961
2.26k
  ccv_array_free(preserve[1]);
962
12.0k
  for (i = 0; i < simplify->tensor_symbol_info_size; i++)
963
9.74k
    if (r_alias_refs[i])
964
1.10k
      ccv_array_free(r_alias_refs[i]);
965
2.26k
  ccfree(r_alias_refs);
966
2.26k
}
967
968
void ccv_nnc_symbolic_graph_simplify(ccv_nnc_symbolic_graph_t* const graph, const int* const passes, const int pass_size, const ccv_nnc_tensor_symbol_t* const binds, const int bind_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
969
4.52k
{
970
4.52k
  ccv_nnc_symbolic_graph_simplify_t* const simplify = _ccv_nnc_symbolic_graph_simplify_new(graph, sources, source_size, destinations, destination_size);
971
4.52k
  int i;
972
15.8k
  for (i = 0; i < pass_size; i++)
973
11.3k
    switch (passes[i])
974
11.3k
    {
975
2.26k
      case CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION:
976
2.26k
        _ccv_nnc_symbolic_graph_common_subexpression_elimination(simplify, outputs, output_size);
977
2.26k
        break;
978
2.26k
      case CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT:
979
2.26k
        _ccv_nnc_symbolic_graph_data_transfer_opt(simplify, binds, bind_size, outputs, output_size);
980
2.26k
        break;
981
2.26k
      case CCV_NNC_SIMPLIFY_GRAPH_PRUNING:
982
2.26k
        _ccv_nnc_symbolic_graph_pruning(simplify, outputs, output_size);
983
2.26k
        break;
984
4.52k
      case CCV_NNC_SIMPLIFY_OPS_FUSION:
985
4.52k
        _ccv_nnc_symbolic_graph_ops_fusion(simplify, outputs, output_size);
986
4.52k
        break;
987
11.3k
    }
988
4.52k
  _ccv_nnc_symbolic_graph_simplify_apply(simplify);
989
4.52k
  _ccv_nnc_symbolic_graph_simplify_free(simplify);
990
4.52k
}
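
Note (not part of the generated report): a minimal usage sketch of the API covered above, showing how ccv_nnc_symbolic_graph_simplify() is typically driven with the four passes exercised in this file. The pass constants and the function signature are taken from the listing; the graph-construction helpers (CPU_TENSOR_NHWC, TENSOR_SYMBOL_LIST, CMD_EWSUM_FORWARD, CMD_EWPROD_FORWARD) and the source/destination accessors are assumed to come from ccv_nnc.h / ccv_nnc_easy.h.

#include <ccv.h>
#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>

int main(void)
{
	ccv_nnc_symbolic_graph_t* const graph = ccv_nnc_symbolic_graph_new();
	const ccv_nnc_tensor_symbol_t a = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 10), "a");
	const ccv_nnc_tensor_symbol_t b = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 10), "b");
	const ccv_nnc_tensor_symbol_t c0 = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 10), "c0");
	const ccv_nnc_tensor_symbol_t c1 = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 10), "c1");
	const ccv_nnc_tensor_symbol_t d = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 10), "d");
	// Two identical sums: common sub-expression elimination can redirect one to the other,
	// and graph pruning then drops the exec that no longer feeds any preserved output.
	ccv_nnc_graph_exec_symbol_new(graph, CMD_EWSUM_FORWARD(), TENSOR_SYMBOL_LIST(a, b), TENSOR_SYMBOL_LIST(c0), "sum0");
	ccv_nnc_graph_exec_symbol_new(graph, CMD_EWSUM_FORWARD(), TENSOR_SYMBOL_LIST(a, b), TENSOR_SYMBOL_LIST(c1), "sum1");
	ccv_nnc_graph_exec_symbol_new(graph, CMD_EWPROD_FORWARD(), TENSOR_SYMBOL_LIST(c0, c1), TENSOR_SYMBOL_LIST(d), "prod");
	ccv_nnc_graph_exec_symbol_autogen(graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
	// The four passes covered in this file, in a typical order.
	const int passes[] = {
		CCV_NNC_SIMPLIFY_OPS_FUSION,
		CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION,
		CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT,
		CCV_NNC_SIMPLIFY_GRAPH_PRUNING,
	};
	const ccv_nnc_tensor_symbol_t outputs[] = { d }; // Outputs that must survive pruning.
	ccv_nnc_symbolic_graph_simplify(graph, passes, sizeof(passes) / sizeof(int),
		0, 0, // no extra binds to protect
		outputs, 1,
		ccv_nnc_symbolic_graph_sources(graph), ccv_nnc_symbolic_graph_source_size(graph),
		ccv_nnc_symbolic_graph_destinations(graph), ccv_nnc_symbolic_graph_destination_size(graph));
	ccv_nnc_symbolic_graph_free(graph);
	return 0;
}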