Coverage Report

Created: 2024-08-19 11:27

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph_backward.c
Line
Count
Source
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_nnc_symbolic_graph.h"
6
7
// MARK - Level-3.5 API
8
9
typedef struct {
10
  int f_wrt; // Check if both f_symbols and wrt_symbols flow through this node.
11
  ccv_array_t* outgoings; // backward traverse nodes.
12
  uint64_t* input_bitmasks;
13
  int input_bitmask_size;
14
  uint64_t* output_bitmasks;
15
  int output_bitmask_size;
16
} ccv_nnc_graph_backward_info_t;
17
18
typedef struct {
19
  int input_size;
20
  int* inputs;
21
  int output;
22
  ccv_array_t* outgoings;
23
  float value;
24
  ccv_nnc_graph_exec_symbol_t symbol;
25
} ccv_nnc_sum_or_set_graph_exec_symbol_t;
26
27
typedef struct {
28
  int input_size;
29
  int output_size;
30
  int* inputs;
31
  int* outputs;
32
  ccv_array_t* outgoings;
33
  ccv_nnc_cmd_t cmd;
34
  ccv_nnc_graph_exec_symbol_t symbol;
35
} ccv_nnc_autograd_graph_exec_symbol_t;
36
37
typedef struct {
38
  int d; // The pointer to the forward level object.
39
  int alias_ref; // The alias ref to itself (autograd_tensor_symbols array).
40
  int flags; // Flags for this symbol.
41
  ccv_nnc_tensor_symbol_t symbol;
42
} ccv_nnc_autograd_tensor_symbol_t;
43
44
typedef struct {
45
  int d; // The tensor symbol ref.
46
  int x; // The exec symbol ref.
47
  ccv_array_t* exec_registry; // Additional exec symbol refs, similar to x, only useful for aliasing.
48
  ccv_array_t* alias_registry; // int point to all the alias (if this is not an alias). The alias is the object in autograd_tensor_symbols, you need another level of indirection to get the actual forward level alias.
49
} ccv_nnc_tensor_ref_t;
50
51
typedef struct {
52
  int c; // The start non-accumulated version.
53
  ccv_array_t* ref_version; // tensor ref point to the reverse tensor symbol.
54
} ccv_nnc_autograd_tensor_version_t;
55
56
typedef struct {
57
  int d;
58
  int alias_ref;
59
} ccv_nnc_sum_variable_t;
60
61
// This method tries to figure out if a set of aliases can cover the whole tensor dim.
62
// This is not a precise implementation though. The requirement is to answer this question
63
// under a given memory constraint, so it only allows up to 65536 distinct tensor locations.
64
// If you have more than that, it will assume the tensor doesn't have fully assigned aliases
65
// and will return 0.
66
67
// Return 1 if inserted successfully.
68
static inline int _ccv_nnc_try_mix(int* const md, const int ins, const int c)
69
43
{
70
43
  if (!c)
71
25
  {
72
25
    md[0] = ins;
73
25
    return 1;
74
25
  }
75
18
  int ll = 0, uu = c - 1;
76
18
  int mm;
77
20
  do {
78
20
    mm = ll + ((uu - ll) >> 1);
79
20
    if (ins == md[mm])
80
16
      return 0;
81
4
    else if (ins < md[mm])
82
2
      uu = mm - 1;
83
2
    else if (ins > md[mm])
84
2
      ll = mm + 1;
85
20
  } while (ll <= uu);
86
2
  if (ll < c)
87
2
    memmove(md + ll + 1, md + ll, sizeof(int) * (c - ll));
88
2
  md[ll] = ins;
89
2
  return 1;
90
18
}
91
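// [Editorial sketch, not part of this coverage run.] _ccv_nnc_try_mix above is a
// binary-search insert that keeps md sorted and deduplicated; its return value
// lets the caller accumulate the count of distinct coordinates. A hypothetical
// standalone harness (try_mix mirrors the function above) behaves like this:

#include <stdio.h>
#include <string.h>

static int try_mix(int* const md, const int ins, const int c)
{
  if (!c)
  {
    md[0] = ins;
    return 1;
  }
  int ll = 0, uu = c - 1, mm;
  do {
    mm = ll + ((uu - ll) >> 1);
    if (ins == md[mm])
      return 0; // Already recorded, nothing inserted.
    else if (ins < md[mm])
      uu = mm - 1;
    else
      ll = mm + 1;
  } while (ll <= uu);
  if (ll < c) // Shift the tail up to make room at the insertion point.
    memmove(md + ll + 1, md + ll, sizeof(int) * (c - ll));
  md[ll] = ins;
  return 1;
}

int main(void)
{
  int md[8], c = 0;
  const int xs[] = { 4, 2, 4, 7, 2 };
  for (int i = 0; i < 5; i++)
    c += try_mix(md, xs[i], c);
  for (int i = 0; i < c; i++)
    printf("%d ", md[i]); // Prints: 2 4 7 (sorted, duplicates dropped).
  printf("\n");
  return 0;
}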
92
static inline int _ccv_nnc_mix_idx(const int* const md, const int ins, const int c)
93
30
{
94
30
  if (c <= 1)
95
22
    return 0;
96
8
  int ll = 0, uu = c - 1;
97
8
  int mm;
98
14
  do {
99
14
    mm = ll + ((uu - ll) >> 1);
100
14
    if (ins == md[mm])
101
8
      return mm;
102
6
    else if (ins < md[mm])
103
0
      uu = mm - 1;
104
6
    else if (ins > md[mm])
105
6
      ll = mm + 1;
106
14
  } while (ll <= uu);
107
0
  assert(0 && "Shouldn't reach here");
108
0
  return -1;
109
0
}
110
111
static inline void _ccv_nnc_try_set_pix_0(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset)
112
6
{
113
6
  const int s = (ofs[0] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[0], ofs[0], cube_dim[0]) + 1;
114
6
  const int d = ((ofs[0] + dim[0] == tensor_dim[0]) ? cube_dim[0] : _ccv_nnc_mix_idx(scmd[0], ofs[0] + ccv_max(1, dim[0]), cube_dim[0])) + 1;
115
6
  assert(s >= 0 && d > s);
116
6
  int i;
117
12
  for (i = s; i < d; i++)
118
    // Fill this pix. I could make this faster by looping through full words (divided by 8), but too lazy.
119
6
    cube[(offset + i) >> 5] |= (1u << ((offset + i) & 0x1f));
120
6
}
121
122
static inline void _ccv_nnc_try_set_pix_1(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset)
123
16
{
124
16
  const int s0 = (ofs[0] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[0], ofs[0], cube_dim[0]) + 1;
125
16
  const int d0 = ((ofs[0] + dim[0] == tensor_dim[0]) ? cube_dim[0] : _ccv_nnc_mix_idx(scmd[0], ofs[0] + ccv_max(1, dim[0]), cube_dim[0])) + 1;
126
16
  assert(s0 >= 0 && d0 > s0);
127
16
  const int s1 = (ofs[1] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[1], ofs[1], cube_dim[1]) + 1;
128
16
  const int d1 = ((ofs[1] + dim[1] == tensor_dim[1]) ? cube_dim[1] : _ccv_nnc_mix_idx(scmd[1], ofs[1] + ccv_max(1, dim[1]), cube_dim[1])) + 1;
129
16
  assert(s1 >= 0 && d1 > s1);
130
16
  int i, j;
131
16
  const int step1 = cube_step[1];
132
16
  if (step1 == d0 - s0)
133
12
  {
134
    // Faster one, we can simply loop through.
135
26
    for (i = s1 * step1; i < d1 * step1; i++)
136
14
      cube[(offset + i) >> 5] |= (1u << ((offset + i) & 0x1f));
137
12
  } else {
138
4
    offset += s1 * step1;
139
    // There are gaps, slow one.
140
8
    for (i = s1; i < d1; i++, offset += step1)
141
8
      for (j = s0; j < d0; j++)
142
4
        cube[(offset + j) >> 5] |= (1u << ((offset + j) & 0x1f));
143
4
  }
144
16
}
145
146
static inline void _ccv_nnc_try_set_pix(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset, const int dim_idx)
147
30
{
148
30
  switch (dim_idx)
149
30
  {
150
16
    case 1:
151
16
      _ccv_nnc_try_set_pix_1(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset);
152
16
      return;
153
6
    case 0:
154
6
      _ccv_nnc_try_set_pix_0(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset);
155
6
      return;
156
30
  }
157
8
  int i;
158
8
  const int s = (ofs[dim_idx] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[dim_idx], ofs[dim_idx], cube_dim[dim_idx]) + 1;
159
8
  const int d = ((ofs[dim_idx] + dim[dim_idx] == tensor_dim[dim_idx]) ? cube_dim[dim_idx] : _ccv_nnc_mix_idx(scmd[dim_idx], ofs[dim_idx] + ccv_max(1, dim[dim_idx]), cube_dim[dim_idx])) + 1;
160
8
  assert(s >= 0 && d > s);
161
16
  for (i = s; i < d; i++)
162
8
    _ccv_nnc_try_set_pix(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset + i * cube_step[dim_idx], dim_idx - 1);
163
8
}
164
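// [Editorial sketch, not part of this coverage run.] The pix setters above fill
// cells of a compressed "cube" whose per-dimension boundaries were collected by
// _ccv_nnc_try_mix. A simplified, hypothetical 1-D version of that idea:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  // A length-8 tensor covered by two aliases: [0, 3) and [3, 8).
  const int tensor_dim = 8;
  const int ofs[2] = { 0, 3 }, dim[2] = { 3, 5 };
  // One interior boundary (3) was recorded, so the compressed grid has 2
  // cells: cell 0 = [0, 3), cell 1 = [3, 8).
  const int cells = 2;
  uint32_t cube = 0;
  int i, j;
  for (i = 0; i < 2; i++)
  {
    const int s = (ofs[i] == 0) ? 0 : 1; // First covered cell (simplified lookup).
    const int d = (ofs[i] + dim[i] == tensor_dim) ? cells : 1; // One past the last.
    for (j = s; j < d; j++)
      cube |= (1u << j);
  }
  // Fully assigned iff every cell is set, mirroring the tail check in
  // _ccv_nnc_tensor_ref_fully_assigned_with_aliases below.
  printf("fully assigned: %d\n", cube == (1u << cells) - 1); // Prints 1.
  return 0;
}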
165
static int _ccv_nnc_tensor_ref_fully_assigned_with_aliases(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info)
166
2.11k
{
167
  // Only work with tensor_ref of aliases.
168
2.11k
  assert(tensor_ref->alias_registry);
169
2.11k
  const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
170
2.11k
  assert(tensor_symbol_info[autograd->d].alias_ref == 0);
171
2.11k
  const int* tensor_dim = tensor_symbol_info[autograd->d].info.dim;
172
2.11k
  const int tensor_count = ccv_nnc_dimension_count(tensor_dim);
173
2.11k
  int i, j;
174
2.15k
  for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
175
2.12k
  {
176
2.12k
    const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
177
2.12k
    assert(d < autograd_tensor_symbols->rnum);
178
2.12k
    const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
179
2.12k
    assert(tensor_symbol_info[autograd->d].alias_ref);
180
2.12k
    const int* stride = tensor_symbol_info[autograd->d].stride;
181
    // If this is just reshaped (i.e., the dimension is the same and the stride covers the whole tensor), we have it fully assigned.
182
2.12k
    if (ccv_nnc_is_tensor_stride_packed(stride, tensor_symbol_info[autograd->d].info.dim) && ccv_nnc_dimension_count(tensor_symbol_info[autograd->d].info.dim) == tensor_count)
183
2.09k
      return 1;
184
    // Otherwise, if the stride doesn't match the original dim, it is not covered.
185
39
    if (!ccv_nnc_is_tensor_stride_packed(stride, tensor_dim))
186
0
      return 0;
187
39
  }
188
  /* We need a solid cube (potentially hyper dimensional) to compute if there are overlaps.
189
   * To make this cube as small as possible, we need to map the actual tensor dimension
190
   * (therefore, we don't actually allocate the whole tensor to compute overlaps) to a smaller
191
   * cube given the ofs and dim size of its aliases.
192
   *
193
   * The following code generates the dimension mapping (using scratch space) with binary search + insertion
194
   * and then we fill the cube with a given tensor alias's dimensional information (ofs, dim).
195
   * Afterwards, we simply need to check if the cube is totally filled up to know if this tensor
196
   * is fully assigned with its aliases (if that is the case, we can skip zeroing for this tensor).
197
   *
198
   * There are several restrictions though to make this faster: 1). I cannot handle any cube whose side
199
   * lengths combined are larger than 1023 (scm only has 1024 ints of scratch space). 2). I cannot handle any cube
200
   * whose total volume is larger than 2048 * 8 (I only allocate 2K on the stack for this).
201
   * */
202
26
  int scm[1024]; // Having 1024 int scratch space for mapping dimensions. (Or sparse coordinate mapping).
203
26
  int cube_dim[CCV_NNC_MAX_DIM_ALLOC] = {}; // Mapping dimension size.
204
26
  int cube_size = 1;
205
26
  int* scmptr = scm;
206
50
  for (i = 0; i < CCV_NNC_MAX_DIM_ALLOC && tensor_dim[i]; i++)
207
40
  {
208
40
    int head = 0, tail = 0; // Note that we touched both the head and tail (otherwise this dimension is not fully covered).
209
40
    int len = 0;
210
105
    for (j = 0; j < tensor_ref->alias_registry->rnum; j++)
211
65
    {
212
65
      const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, j);
213
65
      assert(d < autograd_tensor_symbols->rnum);
214
65
      const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
215
65
      assert(tensor_symbol_info[autograd->d].alias_ref);
216
65
      const int* ofs = tensor_symbol_info[autograd->d].ofs;
217
65
      const int* dim = tensor_symbol_info[autograd->d].info.dim;
218
65
      head = head || (ofs[i] == 0);
219
65
      tail = tail || (ofs[i] + ccv_max(1, dim[i]) == tensor_dim[i]);
220
65
      if (ofs[i] != 0)
221
14
        len += _ccv_nnc_try_mix(scmptr, ofs[i], len);
222
65
      if (scmptr - scm + len >= 1024) // Cannot handle that much, abort.
223
0
        return 0;
224
65
      if (ofs[i] + ccv_max(1, dim[i]) < tensor_dim[i])
225
29
        len += _ccv_nnc_try_mix(scmptr, ofs[i] + ccv_max(1, dim[i]), len);
226
65
      if (scmptr - scm + len >= 1024) // Cannot handle that much, abort.
227
0
        return 0;
228
65
    }
229
40
    if (!head || !tail)
230
16
      return 0;
231
24
    cube_size *= (len + 1);
232
24
    cube_dim[i] = len;
233
24
    scmptr += len; // Moving to next level.
234
24
  }
235
  // The cube map is too large, cannot do the computation, assume it is not fully assigned.
236
10
  if (cube_size > 2048 * 8)
237
0
    return 0;
238
  // binary map to see if it fills up.
239
10
  uint32_t cube[(cube_size + 31) >> 5];
240
10
  memset(cube, 0, sizeof(uint32_t) * ((cube_size + 31) >> 5));
241
10
  int* scmd[CCV_NNC_MAX_DIM_ALLOC] = {}; // Sparse coordinate map at dimension x.
242
10
  int cube_step[CCV_NNC_MAX_DIM_ALLOC] = {};
243
32
  for (i = 0; i < CCV_NNC_MAX_DIM_ALLOC && tensor_dim[i]; i++)
244
22
  {
245
22
    cube_step[i] = (i > 0) ? cube_step[i - 1] * (cube_dim[i - 1] + 1) : 1;
246
22
    scmd[i] = (i > 0) ? scmd[i - 1] + cube_dim[i - 1] : scm;
247
22
  }
248
10
  const int max_dim = i;
249
32
  for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
250
22
  {
251
22
    const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
252
22
    assert(d < autograd_tensor_symbols->rnum);
253
22
    const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
254
22
    assert(tensor_symbol_info[autograd->d].alias_ref);
255
22
    const int* ofs = tensor_symbol_info[autograd->d].ofs;
256
22
    const int* dim = tensor_symbol_info[autograd->d].info.dim;
257
22
    _ccv_nnc_try_set_pix(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, 0, max_dim - 1);
258
22
  }
259
  // Now check whether the binary map filled up. If it did, we know the tensor is fully assigned.
260
10
  for (i = 0; i < (cube_size >> 5); i++)
261
0
    if (cube[i] < 0xffffffff)
262
0
      return 0;
263
10
  if ((cube_size & 0x1f) > 0)
264
10
  {
265
    // Fetch the rest.
266
10
    uint32_t r = 0;
267
32
    for (i = 0; i < (cube_size & 0x1f); i++)
268
22
      r |= (1u << i);
269
10
    assert(cube[((cube_size + 31) >> 5) - 1] <= r);
270
10
    if (cube[((cube_size + 31) >> 5) - 1] < r)
271
0
      return 0;
272
10
  }
273
10
  return 1;
274
10
}
275
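/* [Editorial note.] A worked trace of the function above with illustrative
 * numbers (not taken from the test suite): take a 4x4 tensor whose gradient is
 * written through two aliases, A with ofs = {0, 0}, dim = {2, 4} and B with
 * ofs = {2, 0}, dim = {2, 4}. Neither alias alone matches the whole dimension
 * count, so the cube path runs: dimension 0 records one interior boundary (2),
 * giving cube_dim[0] = 1 (2 cells); dimension 1 records none, giving
 * cube_dim[1] = 0 (1 cell). cube_size = 2; A fills cell 0 and B fills cell 1,
 * the cube is full, and the function returns 1, i.e. this tensor needs no
 * zero-initialization before its aliases are assigned. */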
276
static int _ccv_nnc_tensor_ref_version_find_init(const ccv_nnc_autograd_tensor_version_t* const tensor_ver)
277
5
{
278
5
  int i;
279
10
  for (i = 0; i < tensor_ver->ref_version->rnum; i++)
280
7
    if (((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i))->x < 0)
281
2
      return i;
282
3
  return -1;
283
5
}
284
285
static void _ccv_nnc_graph_sum_autograd_tensor_versions(const int idx, const int d, const int exec_symbol_info_size, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, ccv_nnc_autograd_tensor_version_t* const tensor_ver, ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs, ccv_array_t* const autograd_tensor_symbols, ccv_array_t* const sum_or_set_execs)
286
4.27k
{
287
4.27k
  int i, j;
288
4.27k
  assert(tensor_ver->c < tensor_ver->ref_version->rnum);
289
4.27k
  const int input_size = tensor_ver->ref_version->rnum - tensor_ver->c;
290
4.27k
  int* inputs = (int*)ccmalloc(sizeof(int) * input_size);
291
12.8k
  for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
292
8.57k
    inputs[i] = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i))->d;
293
4.27k
  const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
294
4.27k
    .d = d
295
4.27k
  };
296
4.27k
  ccv_array_push(autograd_tensor_symbols, &tensor_sym);
297
4.27k
  ccv_nnc_sum_or_set_graph_exec_symbol_t sum_exec = {
298
4.27k
    .input_size = input_size,
299
4.27k
    .inputs = inputs,
300
4.27k
    .output = autograd_tensor_symbols->rnum - 1
301
4.27k
  };
302
4.27k
  if (idx >= 0)
303
4.24k
  {
304
4.24k
    sum_exec.outgoings = ccv_array_new(sizeof(int), 1, 0);
305
4.24k
    ccv_array_push(sum_exec.outgoings, &idx);
306
4.24k
  }
307
4.27k
  ccv_array_push(sum_or_set_execs, &sum_exec);
308
4.27k
  const int outgoing = exec_symbol_info_size + sum_or_set_execs->rnum - 1;
309
12.8k
  for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
310
8.57k
  {
311
8.57k
    const ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i);
312
8.57k
    const int x = tensor_ref->x;
313
8.57k
    if (x < 0) /* This is an initialization tensor; it has to occur before the execution anyway. */
314
1
    {
315
      // No alias.
316
1
      assert(!tensor_ref->alias_registry);
317
      // No associated additional execs.
318
1
      assert(!tensor_ref->exec_registry);
319
1
      continue;
320
1
    }
321
8.57k
    if (x < exec_symbol_info_size)
322
8.57k
    {
323
8.57k
      ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
324
8.57k
      if (!back_exec->outgoings)
325
39
        back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
326
8.57k
      ccv_array_replace_unique_int(back_exec->outgoings, idx, outgoing);
327
8.57k
    } else {
328
      // This tensor_ref is generated by the sum operation.
329
0
      ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size);
330
0
      ccv_array_replace_unique_int(sum_or_set->outgoings, idx, outgoing);
331
0
    }
332
    // If this tensor has an associated alias, we need to init it to zeros when it is allocated (we only need to set a flag here);
333
    // it is handled at compilation phase.
334
8.57k
    if (tensor_ref->alias_registry &&
335
      // Loop over to see if this tensor is fully occupied to avoid extra zero step.
336
8.57k
      !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info))
337
8
    {
338
8
      ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
339
      // By having alias_registry, what this symbol represents must not be an alias.
340
8
      assert(tensor_sym->alias_ref == 0);
341
8
      tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
342
8
    }
343
8.57k
    if (tensor_ref->exec_registry)
344
4
      for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
345
2
      {
346
2
        const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
347
2
        assert(x >= 0);
348
        // The exec_registry can only be generated by the alias registry, therefore, it cannot reference a sum operation.
349
2
        assert(x < exec_symbol_info_size);
350
2
        ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
351
2
        if (!back_exec->outgoings)
352
1
          back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
353
2
        ccv_array_replace_unique_int(back_exec->outgoings, idx, outgoing);
354
2
      }
355
8.57k
  }
356
4.27k
  const ccv_nnc_tensor_ref_t tensor_ref = {
357
4.27k
    .d = autograd_tensor_symbols->rnum - 1,
358
4.27k
    .x = outgoing
359
4.27k
  };
360
4.27k
  ccv_array_push(tensor_ver->ref_version, &tensor_ref);
361
  /* Move the c pointer up to the latest summed result. */
362
4.27k
  tensor_ver->c = tensor_ver->ref_version->rnum - 1;
363
4.27k
}
364
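// [Editorial sketch, not part of this coverage run.] The versioning contract
// above, in miniature: every ref at index >= c is a live, not-yet-summed
// gradient version; the sum exec consumes all of them and leaves one new
// version behind. A hypothetical standalone sketch of just that bookkeeping:

#include <stdio.h>

int main(void)
{
  int ref_version[8] = { 10, 11, 12 }; // Symbol ids of three gradient versions.
  int rnum = 3, c = 0;
  // One sum exec takes versions [c, rnum) as inputs -> new symbol 13.
  const int summed = 13;
  ref_version[rnum++] = summed;
  c = rnum - 1; // Move the c pointer up to the latest summed result.
  printf("live version: %d, consumed: %d\n", ref_version[c], c); // 13, 3
  return 0;
}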
365
static int _ccv_nnc_tensor_ref_version_involve_alias(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
366
69
{
367
69
  assert(alias->alias_ref > 0);
368
  // No alias_registry, must conflict (owns the whole band).
369
69
  if (!tensor_ref->alias_registry)
370
25
    return 1;
371
44
  int i;
372
63
  for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
373
54
  {
374
54
    const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
375
54
    assert(d < autograd_tensor_symbols->rnum);
376
54
    ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
377
54
    if (ccv_nnc_over_tensor_symbol_aliases(tensor_symbol_info + autograd->d, alias))
378
35
      return 1;
379
54
  }
380
  // None of the aliases referenced by this ref_version overlaps with the provided one, thus, there is no conflict at all.
381
9
  return 0;
382
44
}
383
384
static int _ccv_nnc_tensor_ref_version_find_alias(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
385
30
{
386
30
  assert(alias->alias_ref > 0);
387
  // No alias_registry, thus, cannot find the exact matched alias.
388
30
  if (!tensor_ref->alias_registry)
389
11
    return -1;
390
19
  int i;
391
34
  for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
392
26
  {
393
26
    const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
394
26
    assert(d < autograd_tensor_symbols->rnum);
395
26
    ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
396
    // This must reference an alias.
397
26
    assert(tensor_symbol_info[autograd->d].alias_ref);
398
26
    const int* stride = tensor_symbol_info[autograd->d].stride;
399
26
    const int* ofs = tensor_symbol_info[autograd->d].ofs;
400
26
    const int* dim = tensor_symbol_info[autograd->d].info.dim;
401
    // If everything matches, this is the required alias.
402
26
    if (memcmp(stride, alias->stride, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0 &&
403
26
      memcmp(ofs, alias->ofs, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0 &&
404
26
      memcmp(dim, alias->info.dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0)
405
11
      return d;
406
26
  }
407
8
  return -1;
408
19
}
409
410
static int _ccv_nnc_tensor_ref_version_has_this_alias_exclusively(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
411
4
{
412
4
  assert(alias->alias_ref > 0);
413
  // No alias_registry, thus, cannot find the exact matched alias.
414
4
  if (!tensor_ref->alias_registry)
415
0
    return 0;
416
4
  int i;
417
8
  for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
418
5
  {
419
5
    const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
420
5
    assert(d < autograd_tensor_symbols->rnum);
421
5
    ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
422
    // This must reference an alias.
423
5
    assert(tensor_symbol_info[autograd->d].alias_ref);
424
5
    const int* stride = tensor_symbol_info[autograd->d].stride;
425
5
    const int* ofs = tensor_symbol_info[autograd->d].ofs;
426
5
    const int* dim = tensor_symbol_info[autograd->d].info.dim;
427
5
    if (memcmp(stride, alias->stride, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0 ||
428
5
      memcmp(ofs, alias->ofs, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0 ||
429
5
      memcmp(dim, alias->info.dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0)
430
1
      return 0;
431
5
  }
432
  // If everything matches for every alias in registry, we can use any of the alias directly.
433
3
  return 1;
434
4
}
435
436
static int _ccv_nnc_graph_sum_autograd_tensor_versions_alias(const int idx, const int d, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const int exec_symbol_info_size, const ccv_nnc_tensor_symbol_info_t* const alias, ccv_nnc_autograd_tensor_version_t* const tensor_ver, ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs, ccv_array_t* const autograd_tensor_symbols, ccv_array_t* const sum_or_set_execs)
437
21
{
438
21
  assert(tensor_ver->c < tensor_ver->ref_version->rnum);
439
21
  int i, j = 0;
440
21
  struct {
441
21
    int k;
442
21
    int i;
443
21
  } kd[tensor_ver->ref_version->rnum - tensor_ver->c];
444
51
  for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
445
30
  {
446
30
    ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i);
447
30
    const int k = _ccv_nnc_tensor_ref_version_find_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias);
448
30
    if (k >= 0)
449
11
      kd[j++] = (typeof(kd[0])){
450
11
        .k = k, .i = i
451
11
      };
452
19
    else if (_ccv_nnc_tensor_ref_version_involve_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias))
453
19
      kd[j++] = (typeof(kd[0])) {
454
19
        .k = -1, .i = i // It has dependency to the original tensor (non-alias) now, label this with highest bit.
455
19
      };
456
30
  }
457
  // Can only find one. This is the easy case, we can simply return that symbol (or its alias).
458
21
  if (j == 1)
459
15
  {
460
15
    if (kd[0].k >= 0)
461
4
      return kd[0].k; // We can only find one alias; that is the one.
462
    // Otherwise, need to create a new alias.
463
11
    ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[0].i);
464
11
    ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
465
    // Since we create a new alias, we need the referenced tensor to be allocated with 0s.
466
11
    if (ref->alias_ref) // If this is an alias, it has to be zero initialized.
467
0
    {
468
0
      ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, ref->alias_ref - 1);
469
0
      assert(ref->alias_ref == 0); // This is original.
470
0
      ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
471
11
    } else if (tensor_ref->alias_registry && // Otherwise, to see if this symbol is fully occupied.
472
        // Loop over to see if this tensor is fully occupied to avoid extra zero step.
473
11
        !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info)) {
474
1
      ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
475
1
    }
476
11
    ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
477
11
      .d = d,
478
11
      .alias_ref = tensor_ref->d + 1
479
11
    };
480
11
    ccv_array_push(autograd_tensor_symbols, &tensor_sym);
481
11
    const int ad = autograd_tensor_symbols->rnum - 1;
482
11
    if (tensor_ref->alias_registry) // Only push this when it has an alias registry (otherwise it already conflict with everyone).
483
3
      ccv_array_push(tensor_ref->alias_registry, &ad);
484
11
    if (tensor_ref->x >= exec_symbol_info_size && idx >= 0)
485
2
    {
486
2
      ccv_nnc_sum_or_set_graph_exec_symbol_t* const sum_or_set_exec = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, tensor_ref->x - exec_symbol_info_size);
487
      // This may be summed, thus, we need to create a connection between this and the sum.
488
2
      if (!sum_or_set_exec->outgoings)
489
0
        sum_or_set_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
490
2
      ccv_array_push(sum_or_set_exec->outgoings, &idx);
491
2
    }
492
    // The newly inserted tensor symbol.
493
11
    return ad;
494
11
  }
495
  // Otherwise, we need to create the sum operation out of these.
496
6
  const int input_size = j;
497
6
  int has_this_alias_exclusively = 1;
498
6
  int* inputs = input_size > 0 ? (int*)ccmalloc(sizeof(int) * input_size) : 0;
499
21
  for (i = 0; i < input_size; i++)
500
15
  {
501
15
    ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
502
    // Can take a fast path if every ref involved has the same alias, our sum operation can be faster (using alias directly).
503
15
    if (has_this_alias_exclusively && kd[i].k >= 0 && _ccv_nnc_tensor_ref_version_has_this_alias_exclusively(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias))
504
3
      inputs[i] = *(int*)ccv_array_get(tensor_ref->alias_registry, 0); // Assigning the alias.
505
12
    else {
506
12
      if (has_this_alias_exclusively)
507
5
      {
508
5
        has_this_alias_exclusively = 0;
509
5
        for (j = 0; j < i; j++)
510
0
          inputs[j] = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[j].i))->d;
511
5
      }
512
12
      inputs[i] = tensor_ref->d;
513
12
    }
514
15
  }
515
6
  ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
516
6
    .d = alias->alias_ref - 1
517
6
  };
518
6
  ccv_array_push(autograd_tensor_symbols, &tensor_sym);
519
6
  const int tensor_ref_d = autograd_tensor_symbols->rnum - 1;
520
6
  tensor_sym.d = d;
521
6
  tensor_sym.alias_ref = tensor_ref_d + 1;
522
6
  ccv_array_push(autograd_tensor_symbols, &tensor_sym);
523
6
  const int ad = autograd_tensor_symbols->rnum - 1;
524
6
  ccv_nnc_sum_or_set_graph_exec_symbol_t sum_exec = {
525
6
    .input_size = input_size,
526
6
    .inputs = inputs,
527
6
    .output = has_this_alias_exclusively ? ad : tensor_ref_d /* If has this alias exclusively, the output should be the alias as well. Otherwise the output is the real tensor. */
528
6
  };
529
6
  if (idx >= 0)
530
6
  {
531
6
    sum_exec.outgoings = ccv_array_new(sizeof(int), 1, 0);
532
6
    ccv_array_push(sum_exec.outgoings, &idx);
533
6
  }
534
6
  ccv_array_push(sum_or_set_execs, &sum_exec);
535
6
  const int outgoing = exec_symbol_info_size + sum_or_set_execs->rnum - 1;
536
6
  int no_alias_registry = 0;
537
21
  for (i = 0; i < input_size; i++)
538
15
  {
539
15
    ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
540
15
    if (!has_this_alias_exclusively)
541
12
    {
542
      // If the sum operation is not operating on one alias, I need to zero this tensor out when it is first
543
      // allocated (see discussions around the flags I use).
544
12
      ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
545
12
      if (tensor_sym->alias_ref)
546
0
      {
547
        // Find the original tensor_sym and set its flags (I prefer to set flags on its original).
548
0
        ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_sym->alias_ref - 1);
549
0
        assert(ref->alias_ref == 0); // This is original.
550
0
        ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
551
12
      } else if (tensor_ref->alias_registry && // Otherwise, to see if this symbol is fully occupied.
552
          // Loop over to see if this tensor is fully occupied to avoid extra zero step.
553
12
          !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info)) {
554
6
        tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
555
6
      }
556
12
    }
557
    // Check to see if any of these tensors doesn't have an alias.
558
15
    no_alias_registry |= (!tensor_ref->alias_registry);
559
15
    const int x = tensor_ref->x;
560
15
    assert(x >= 0); /* Otherwise, this is an initialization tensor, which cannot be one of the summed inputs. */
561
15
    if (x < exec_symbol_info_size)
562
15
    {
563
15
      ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
564
15
      if (!back_exec->outgoings)
565
0
        back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
566
15
      ccv_array_push(back_exec->outgoings, &outgoing);
567
15
    } else {
568
0
      ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size);
569
0
      ccv_array_push(sum_or_set->outgoings, &outgoing);
570
0
    }
571
15
    if (tensor_ref->exec_registry)
572
6
      for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
573
3
      {
574
3
        const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
575
3
        assert(x >= 0); /* Otherwise, this is an initialization tensor, which cannot be one of the summed inputs. */
576
3
        assert(x < exec_symbol_info_size); // exec_registry is only used by alias_registry; it simply cannot reference a sum operation.
577
3
        ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
578
3
        if (!back_exec->outgoings)
579
0
          back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
580
3
        ccv_array_push(back_exec->outgoings, &outgoing);
581
3
      }
582
15
  }
583
6
  const ccv_nnc_tensor_ref_t tensor_ref = {
584
6
    .d = tensor_ref_d,
585
6
    .x = outgoing,
586
6
    .exec_registry = 0, // I don't need to take execution dependencies because this tensor is generated by the sum, which already carries that dependency.
587
6
    .alias_registry = !no_alias_registry || has_this_alias_exclusively ? ccv_array_new(sizeof(int), 1, 0) : 0
588
6
  };
589
  // If there is no alias registry, then we take the whole tensor ref as one.
590
6
  if (!no_alias_registry || has_this_alias_exclusively)
591
4
  {
592
    // If this tensor ref contains multiple different types of alias, we have to add them together (otherwise
593
    // the computation of whether there is an empty slot in this tensor ref is not correct without all the
594
    // occupancy information).
595
4
    if (!has_this_alias_exclusively)
596
10
      for (i = 0; i < input_size; i++)
597
7
      {
598
7
        ccv_nnc_tensor_ref_t* ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
599
7
        assert(ref->alias_registry);
600
        // It may get duplicates, but that doesn't matter for the computation.
601
19
        for (j = 0; j < ref->alias_registry->rnum; j++)
602
12
          ccv_array_push(tensor_ref.alias_registry, ccv_array_get(ref->alias_registry, j));
603
7
      }
604
4
    ccv_array_push(tensor_ref.alias_registry, &ad);
605
4
  }
606
6
  assert(input_size <= tensor_ver->ref_version->rnum - tensor_ver->c);
607
6
  ccv_nnc_tensor_ref_t x;
608
21
  for (i = 0; i < input_size; i++)
609
    // If the current one (i + tensor_ver->c) is smaller than the one it references, exchange.
610
15
    if (kd[i].i > i + tensor_ver->c)
611
0
      CCV_SWAP(*(ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i + tensor_ver->c), *(ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i), x);
612
6
  ccv_array_push(tensor_ver->ref_version, &tensor_ref);
613
  // We've consumed input_size tensor refs; now move c up to point at the non-consumed tensors.
614
6
  tensor_ver->c += input_size;
615
6
  return ad;
616
6
}
617
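/* [Editorial note.] An illustrative trace of the function above with
 * hypothetical indices: with two live versions where version 0 carries an exact
 * matching alias (kd[0].k >= 0) and version 1 merely overlaps the requested
 * region (kd[1].k == -1), j == 2, so the single-version fast path does not
 * apply. The function then emits a sum exec over both versions, materializes a
 * fresh non-alias tensor plus an alias over it, pushes the summed ref, advances
 * tensor_ver->c past the two consumed versions, and returns the new alias. */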
618
typedef struct ccv_nnc_symbolic_graph_backward_prep_s {
619
  int exec_symbol_info_size; // Number of graph exec symbols before adding any new symbols related to automatic differentiation.
620
  int tensor_symbol_info_size; // Number of tensor symbols before adding anything new.
621
  int sub_prep_size;
622
  ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info;
623
  ccv_nnc_tensor_symbol_info_t* tensor_symbol_info;
624
  ccv_nnc_graph_backward_info_t* backward_info; // Corresponding to forward graph exec symbol info, it is exactly in reverse.
625
  ccv_nnc_graph_visit_t* forward_visit; // The visitor structure (top sorted index) when doing traversal.
626
  ccv_nnc_graph_visit_t* backward_visit; // The visitor structure (top sorted index) when doing reverse traversal.
627
  ccv_nnc_autograd_graph_exec_symbol_t* autograd_execs; // The graph exec symbols we need for automatic differentiation. This is a 1:1 mapping for forward graph exec symbols, however, unlike backward_info, its outgoings may be more complex (may contain outgoing flows to sum nodes).
628
  ccv_nnc_autograd_tensor_version_t* autograd_tensor_versions; // Corresponding to forward tensor symbols, each may contain multiple versions (due to multi-write).
629
  ccv_array_t* autograd_tensor_symbols; // The tensor symbols we need for automatic differentiation (it may not be 1:1 mapping).
630
  ccv_array_t* sum_or_set_execs; // The sum nodes, because in reverse mode, a tensor could have multiple versions, we need to sum them up before use.
631
  struct ccv_nnc_symbolic_graph_backward_prep_s* sub_preps; // The preps of its sub-graphs.
632
  // Pointers not managed by this struct
633
  ccv_nnc_symbolic_graph_t* graph;
634
} ccv_nnc_symbolic_graph_backward_prep_t;
635
636
static ccv_nnc_symbolic_graph_backward_prep_t _ccv_nnc_symbolic_graph_backward_prep(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
637
6.78k
{
638
6.78k
  const int exec_symbol_info_size = graph->exec_symbol_info->rnum;
639
6.78k
  assert(exec_symbol_info_size > 0);
640
6.78k
  const int tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
641
6.78k
  assert(tensor_symbol_info_size > 0);
642
6.78k
  ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_info_t) * exec_symbol_info_size);
643
6.78k
  ccv_nnc_tensor_symbol_info_t* tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_tensor_symbol_info_t) * tensor_symbol_info_size);
644
13.5k
  ccv_nnc_graph_visit_t* forward_visit = ccv_nnc_graph_visit_new(graph, (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0), exec_symbol_info_size, sources, source_size, destinations, destination_size, 0);
645
0
  ccv_nnc_symbolic_graph_symbol_infer(graph, forward_visit, sources, source_size, destinations, destination_size, 0, 0, tensor_symbol_info, exec_symbol_info);
646
13.5k
  int i;
647
  // Now, for each one of these, find a reverse graph.
648
13.5k
  ccv_nnc_graph_backward_info_t* backward_info = (ccv_nnc_graph_backward_info_t*)cccalloc(exec_symbol_info_size, sizeof(ccv_nnc_graph_backward_info_t));
649
19.1k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, node, idx) {
650
19.1k
    assert(ccv_nnc_cmd_is_forward(node->cmd) || node->cmd.cmd == CCV_NNC_NOOP);
651
19.1k
    if (node->outgoings)
652
24.7k
      for (i = 0; i < node->outgoings->rnum; i++)
653
12.4k
      {
654
12.4k
        int d = *(int*)ccv_array_get(node->outgoings, i);
655
12.4k
        if (!backward_info[d].outgoings)
656
12.3k
          backward_info[d].outgoings = ccv_array_new(sizeof(int32_t), 1, 0);
657
12.4k
        ccv_array_push(backward_info[d].outgoings, &idx);
658
12.4k
      }
659
19.1k
  } ccv_nnc_graph_visit_endfor
660
  // Also mark only the output bits that we use.
661
25.9k
  for (i = 0; i < exec_symbol_info_size; i++)
662
19.1k
  {
663
19.1k
    backward_info[i].input_bitmask_size = ((exec_symbol_info[i].output_size * 2 + exec_symbol_info[i].input_size + 63) >> 6);
664
19.1k
    backward_info[i].output_bitmask_size = ((exec_symbol_info[i].input_size + 63) >> 6);
665
    // Allocate input / output bitmasks
666
19.1k
    if (backward_info[i].input_bitmask_size + backward_info[i].output_bitmask_size > 0)
667
19.1k
    {
668
19.1k
      backward_info[i].input_bitmasks = (uint64_t*)cccalloc(backward_info[i].input_bitmask_size + backward_info[i].output_bitmask_size, sizeof(uint64_t));
669
19.1k
      if (backward_info[i].output_bitmask_size)
670
19.1k
        backward_info[i].output_bitmasks = backward_info[i].input_bitmasks + backward_info[i].input_bitmask_size;
671
19.1k
    }
672
19.1k
  }
673
6.78k
  ccv_nnc_graph_visit_t* backward_visit = ccv_nnc_graph_visit_new(graph, backward_info, exec_symbol_info_size, destinations, destination_size, sources, source_size, 0);
674
6.78k
  const int sub_prep_size = graph->sub_graphs ? graph->sub_graphs->rnum : 0;
675
6.78k
  ccv_nnc_symbolic_graph_backward_prep_t* sub_preps = sub_prep_size > 0 ? (ccv_nnc_symbolic_graph_backward_prep_t*)cccalloc(sub_prep_size, sizeof(ccv_nnc_symbolic_graph_backward_prep_t)) : 0;
676
6.78k
  for (i = 0; i < sub_prep_size; i++)
677
4
  {
678
4
    const ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i);
679
4
    sub_preps[i] = _ccv_nnc_symbolic_graph_backward_prep(sub_graph, ccv_nnc_symbolic_graph_sources(sub_graph), ccv_nnc_symbolic_graph_source_size(sub_graph), ccv_nnc_symbolic_graph_destinations(sub_graph), ccv_nnc_symbolic_graph_destination_size(sub_graph));
680
4
  }
681
6.78k
  return (ccv_nnc_symbolic_graph_backward_prep_t){
682
6.78k
    .exec_symbol_info_size = exec_symbol_info_size,
683
6.78k
    .tensor_symbol_info_size = tensor_symbol_info_size,
684
6.78k
    .sub_prep_size = sub_prep_size,
685
6.78k
    .exec_symbol_info = exec_symbol_info,
686
6.78k
    .tensor_symbol_info = tensor_symbol_info,
687
6.78k
    .backward_info = backward_info,
688
6.78k
    .forward_visit = forward_visit,
689
6.78k
    .backward_visit = backward_visit,
690
6.78k
    .sub_preps = sub_preps,
691
6.78k
    .graph = (ccv_nnc_symbolic_graph_t*)graph,
692
6.78k
  };
693
6.78k
}
694
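// [Editorial sketch, not part of this coverage run.] The bitmask sizing above
// follows the backward layout used throughout this file: a backward node's
// inputs are [gradients of forward outputs | forward inputs | forward outputs]
// and its outputs are one gradient per forward input, one bit per slot packed
// into 64-bit words. With assumed sizes:

#include <stdio.h>

int main(void)
{
  const int input_size = 3, output_size = 1; // E.g. a GEMM-like forward node.
  const int back_input_slots = output_size * 2 + input_size; // 5 slots.
  printf("input words: %d, output words: %d\n",
    (back_input_slots + 63) >> 6, (input_size + 63) >> 6); // 1, 1
  return 0;
}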
695
static void _ccv_nnc_symbolic_graph_backward_exec_io(const ccv_nnc_graph_exec_symbol_info_t* const node, int** const back_input_map, int** const back_output_map, int* const back_input_size, int* const back_output_size)
696
19.1k
{
697
19.1k
  int i;
698
19.1k
  if (node->flags & CCV_NNC_GRAPH_EXEC_CASE_OF)
699
7
  {
700
7
    *back_input_map = node->outputs;
701
7
    *back_input_size = node->output_size;
702
14
    for (i = 0; i < node->case_of.argument.offset; i++)
703
7
      (*back_output_map)[i] = node->inputs[i];
704
7
    const int argument_offset = node->case_of.argument.offset;
705
7
    const int argument_size = node->case_of.argument.size;
706
    // Skip the argument range.
707
7
    for (i = argument_offset + argument_size; i < node->input_size; i++)
708
0
      (*back_output_map)[i - argument_size] = node->inputs[i];
709
7
    *back_output_size = node->input_size - node->case_of.argument.size;
710
19.1k
  } else { // if (node->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) {
711
19.1k
    *back_input_map = node->outputs;
712
19.1k
    *back_input_size = node->output_size;
713
19.1k
    *back_output_map = node->inputs;
714
19.1k
    *back_output_size = node->input_size;
715
19.1k
  }
716
19.1k
}
717
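/* [Editorial note.] A concrete reading of the mapping above for a hypothetical
 * node: a plain forward node with inputs {a, w} and outputs {b} yields
 * back_input_map = {b} (the gradient flowing in) and back_output_map = {a, w}
 * (the gradients flowing out). A CASE_OF node additionally drops its argument
 * range [offset, offset + size) from the output map. */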
718
static void _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(const ccv_nnc_graph_exec_symbol_info_t* const forw_exec, const ccv_nnc_symbolic_graph_t* const sub_graph, const int graph_ref, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const uint64_t* const input_bitmasks, const uint64_t* const output_bitmasks, ccv_array_t* const sub_f_symbols, ccv_array_t* const sub_wrt_symbols)
719
8
{
720
8
  int i, j;
721
8
  ccv_array_clear(sub_wrt_symbols);
722
8
  int forw_outputs[ccv_max(1, forw_exec->output_size)];
723
8
  int forw_inputs[ccv_max(1, forw_exec->input_size)];
724
8
  int* back_input_map = forw_outputs;
725
8
  int* back_output_map = forw_inputs;
726
8
  int back_input_size, back_output_size;
727
8
  _ccv_nnc_symbolic_graph_backward_exec_io(forw_exec, &back_input_map, &back_output_map, &back_input_size, &back_output_size);
728
18
  for (i = 0; i < back_output_size; i++)
729
10
    if (output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
730
8
    {
731
8
      const int d = back_output_map[i];
732
8
      const ccv_array_t* const s_refs = tensor_symbol_info[d].s_ref;
733
8
      const int s_ref = s_refs && s_refs->rnum > graph_ref ? *(int*)ccv_array_get(s_refs, graph_ref) - 1 : -1;
734
8
      if (s_ref >= 0)
735
4
      {
736
4
        ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
737
4
          .d = s_ref,
738
4
          .graph = sub_graph,
739
4
        };
740
4
        ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
741
4
      } else
742
4
        ccv_array_push(sub_wrt_symbols, &NO_TENSOR_SYMBOL);
743
8
    }
744
8
  ccv_array_clear(sub_f_symbols);
745
16
  for (i = 0; i < back_input_size; i++)
746
8
    if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
747
8
    {
748
8
      const int d = back_input_map[i];
749
8
      ccv_nnc_tensor_symbol_t sub_f_symbol = {
750
8
        .d = *(int*)ccv_array_get(tensor_symbol_info[d].s_ref, graph_ref) - 1,
751
8
        .graph = sub_graph,
752
8
      };
753
8
      ccv_array_push(sub_f_symbols, &sub_f_symbol);
754
8
    }
755
  // Go through all its assignments (parameterized loop), making them either wrt or f.
756
  // The reason is these must flow through the graph, otherwise we cannot form a full
757
  // enclosed loop. Also because they are the additional f / wrt symbols, there is
758
  // no case that we cannot find their corresponding gradients in the backward sub graphs
759
  // (these gradients have to be parameterized to form an enclosed loop as well).
760
30
  for (i = 0; i < sub_graph->tensor_symbol_info->rnum; i++)
761
22
  {
762
22
    const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, i);
763
22
    if (tensor_symbol_info->assign_ref)
764
2
    {
765
2
      const int assign_ref = tensor_symbol_info->assign_ref - 1;
766
      // i is the wrt, assign_ref is the f.
767
2
      int flag = 0;
768
4
      for (j = 0; !flag && j < sub_wrt_symbols->rnum; j++)
769
2
        flag = (((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, j))->d == i);
770
2
      if (!flag)
771
2
      {
772
2
        ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
773
2
          .d = i,
774
2
          .graph = sub_graph,
775
2
        };
776
2
        ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
777
2
      }
778
2
      flag = 0;
779
4
      for (j = 0; !flag && j < sub_f_symbols->rnum; j++)
780
2
        flag = (((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, j))->d == assign_ref);
781
2
      if (!flag)
782
0
      {
783
0
        ccv_nnc_tensor_symbol_t sub_f_symbol = {
784
0
          .d = assign_ref,
785
0
          .graph = sub_graph,
786
0
        };
787
0
        ccv_array_push(sub_f_symbols, &sub_f_symbol);
788
0
      }
789
2
    }
790
22
  }
791
8
}
792
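/* [Editorial note, hypothetical indices.] If a parent-graph tensor has
 * s_ref = {4} recorded for this sub-graph, its backward counterpart inside the
 * sub-graph is symbol d = 3 (s_ref - 1); a tensor with no usable s_ref entry
 * contributes NO_TENSOR_SYMBOL instead. A loop-carried symbol (assign_ref set)
 * is force-added to both lists: itself as a wrt symbol and assign_ref - 1 as an
 * f symbol, which keeps the parameterized loop closed. */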
793
// Check whether, for a given f_symbol, we can compute wrt_symbols at all. If we can, tag the minimal io and ops (some ops can be replaced with noop) required to do so.
794
static int _ccv_nnc_symbolic_graph_backward_prep_prune_ops(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
795
6.78k
{
796
6.78k
  int i, j, p;
797
6.78k
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
798
6.78k
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
799
6.78k
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info;
800
6.78k
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
801
  // Now, for each one of these, find a reverse graph.
802
6.78k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
803
6.78k
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
804
  // Find the f_symbols, and tag its flows.
805
19.1k
  ccv_nnc_graph_visit_for(backward_visit, backward_info, node, idx) {
806
19.1k
    int f = node->f_wrt & 0x1;
807
26.2k
    for (i = 0; i < exec_symbol_info[idx].output_size && !f; i++)
808
7.04k
    {
809
7.04k
      int d = exec_symbol_info[idx].outputs[i];
810
7.04k
      if (d < 0)
811
206
        continue;
812
6.84k
      while (tensor_symbol_info[d].alias_ref)
813
3
        d = tensor_symbol_info[d].alias_ref - 1;
814
13.6k
      for (j = 0; j < f_symbol_size && !f; j++)
815
6.85k
        if (d == f_symbols[j].d)
816
6.79k
          f = 1;
817
6.83k
    }
818
19.1k
    if (f)
819
19.1k
    {
820
19.1k
      node->f_wrt |= f;
821
19.1k
      if (node->outgoings)
822
24.6k
        for (i = 0; i < node->outgoings->rnum; i++)
823
12.3k
        {
824
12.3k
          int d = *(int*)ccv_array_get(node->outgoings, i);
825
12.3k
          backward_info[d].f_wrt |= f;
826
12.3k
        }
827
19.1k
    }
828
19.1k
  } ccv_nnc_graph_visit_endfor
829
  // Find the wrt_symbols, and tag its flows.
830
19.1k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, node, idx) {
831
19.1k
    int wrt = backward_info[idx].f_wrt & 0x2;
832
30.1k
    for (i = 0; i < node->input_size && !wrt; i++)
833
10.9k
    {
834
10.9k
      int d = node->inputs[i];
835
10.9k
      if (d < 0)
836
1
        continue;
837
10.9k
      while (tensor_symbol_info[d].alias_ref)
838
7
        d = tensor_symbol_info[d].alias_ref - 1;
839
24.6k
      for (j = 0; j < wrt_symbol_size && !wrt; j++)
840
13.7k
      {
841
13.7k
        int wrt_d = wrt_symbols[j].d;
842
13.7k
        if (wrt_d < 0)
843
29
          continue;
844
        // Find the root of this tensor alias.
845
13.7k
        if (tensor_symbol_info[wrt_d].alias_ref)
846
2
          wrt_d = tensor_symbol_info[wrt_d].alias_ref - 1;
847
13.7k
        if (d == wrt_d)
848
6.85k
          wrt = 0x2;
849
13.7k
      }
850
10.9k
    }
851
19.1k
    if (wrt)
852
19.1k
    {
853
19.1k
      backward_info[idx].f_wrt |= wrt;
854
19.1k
      if (node->outgoings)
855
24.7k
        for (i = 0; i < node->outgoings->rnum; i++)
856
12.3k
        {
857
12.3k
          int d = *(int*)ccv_array_get(node->outgoings, i);
858
12.3k
          backward_info[d].f_wrt |= wrt;
859
12.3k
        }
860
19.1k
    }
861
19.1k
  } ccv_nnc_graph_visit_endfor
862
6.78k
  enum {
863
6.78k
    WRT_SYMBOL_USE = 1,
864
6.78k
    F_SYMBOL_USE = 2
865
6.78k
  };
866
6.78k
  uint8_t* used_grad = (uint8_t*)cccalloc(tensor_symbol_info_size, sizeof(uint8_t));
867
  // First, all f_symbols and wrt_symbols are used.
868
13.5k
  for (i = 0; i < f_symbol_size; i++)
869
6.79k
    if (f_symbols[i].d >= 0)
870
6.79k
      used_grad[tensor_symbol_info[f_symbols[i].d].alias_ref ? tensor_symbol_info[f_symbols[i].d].alias_ref - 1 : f_symbols[i].d] |= F_SYMBOL_USE;
871
16.3k
  for (i = 0; i < wrt_symbol_size; i++)
872
9.53k
    if (wrt_symbols[i].d >= 0)
873
9.52k
      used_grad[tensor_symbol_info[wrt_symbols[i].d].alias_ref ? tensor_symbol_info[wrt_symbols[i].d].alias_ref - 1 : wrt_symbols[i].d] |= WRT_SYMBOL_USE;
874
  // Make an optimistic assumption first, then compute used_grad.
875
19.1k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
876
19.1k
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
877
    /* Only interested in the ones on the f / wrt flow */
878
19.1k
    if ((node->f_wrt & 0x3) == 0x3)
879
19.1k
    {
880
19.1k
      const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
881
19.1k
      ccv_nnc_cmd_t cmd = forw_exec->cmd;
882
19.1k
      if (cmd.cmd != CCV_NNC_NOOP)
883
19.1k
        cmd.cmd += 1; /* Backward command is the one after forward command. */
884
19.1k
      assert(ccv_nnc_cmd_is_backward(cmd) || cmd.cmd == CCV_NNC_NOOP);
885
92.9k
      for (i = 0; i < forw_exec->output_size * 2 + forw_exec->input_size; i++)
886
73.8k
        if (!(i >= forw_exec->output_size && i < forw_exec->output_size + forw_exec->input_size &&
          forw_exec->inputs[i - forw_exec->output_size] < 0) && // If the input is empty, no need.
          !(i >= forw_exec->output_size + forw_exec->input_size && i < forw_exec->output_size * 2 + forw_exec->input_size &&
          forw_exec->outputs[i - forw_exec->output_size - forw_exec->input_size] < 0) && // If the output is empty, no need.
          !(i < forw_exec->output_size && forw_exec->outputs[i] < 0)) // If the output is empty for gradient, no need.
891
73.4k
          node->input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
892
53.8k
      for (i = 0; i < forw_exec->input_size; i++)
893
34.7k
        if (!(forw_exec->inputs[i] < 0)) // If the input is empty, no need.
894
34.7k
          node->output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
895
19.1k
      int maybe_noop = 1;
896
23.1k
      for (i = 0; i < forw_exec->input_size; i++)
897
        /* See if it is used as wrt, if not, no need to run this node at all. */
898
23.1k
        if (forw_exec->inputs[i] >= 0 && used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 1 : forw_exec->inputs[i]] & WRT_SYMBOL_USE)
899
19.1k
        {
900
19.1k
          maybe_noop = 0;
901
19.1k
          break;
902
19.1k
        }
903
19.1k
      if (maybe_noop)
904
0
      {
905
0
        for (i = 0; i < node->input_bitmask_size; i++)
906
0
          node->input_bitmasks[i] = 0;
907
0
        for (i = 0; i < node->output_bitmask_size; i++)
908
0
          node->output_bitmasks[i] = 0;
909
0
        node->output_bitmask_size = 0;
910
19.1k
      } else if (cmd.cmd == CCV_NNC_GRAPH_FORWARD || cmd.cmd == CCV_NNC_GRAPH_BACKWARD) {
911
        // Clear out all potential outputs if we think it is not a wrt symbol.
912
6
        for (i = 0; i < forw_exec->input_size; i++)
913
4
          if ((node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
914
4
            !(used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 1 : forw_exec->inputs[i]] & WRT_SYMBOL_USE))
915
1
            node->output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
916
        // But for now, assuming we need all input gradients.
917
        // Clear out all inputs / outputs from forward op.
918
8
        for (i = forw_exec->output_size; i < forw_exec->output_size * 2 + forw_exec->input_size; i++)
919
6
          node->input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
920
19.1k
      } else if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size)) {
921
16.7k
        int flag; /* Only continue if it changed */
922
32.1k
        do {
923
32.1k
          flag = 0;
924
          /* Check if the output first */
925
93.3k
          for (i = 0; i < forw_exec->input_size; i++)
926
            /* Only try to eliminate the one that is not used. */
927
61.2k
            if ((node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
928
61.2k
              !(used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 1 : forw_exec->inputs[i]] & WRT_SYMBOL_USE))
929
8.62k
            {
930
8.62k
              node->output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
931
              /* If it worked, mark it as flagged. */
932
8.62k
              if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size))
933
8.58k
                flag = 1;
934
46
              else /* Refit this with the bit back again. */
935
46
                node->output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
936
8.62k
            }
937
159k
          for (i = 0; i < forw_exec->output_size * 2 + forw_exec->input_size; i++)
938
127k
            if ((node->input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
939
127k
              (i >= forw_exec->output_size ||
               !(used_grad[tensor_symbol_info[forw_exec->outputs[i]].alias_ref ? tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 1 : forw_exec->outputs[i]] & F_SYMBOL_USE)))
941
82.0k
            { /* Try to eliminate one of the input. */
942
82.0k
              node->input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
943
              /* If it worked, mark it as flagged. */
944
82.0k
              if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size))
945
28.8k
                flag = 1;
946
53.2k
              else /* Refit this with the bit back again. */
947
53.2k
                node->input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
948
82.0k
            }
949
32.1k
        } while (flag);
950
16.7k
      }
951
38.6k
      for (i = 0; i < forw_exec->output_size; i++)
952
19.5k
        if (node->input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
953
          /* Mark that it is used as wrt. */
954
19.1k
          used_grad[tensor_symbol_info[forw_exec->outputs[i]].alias_ref ? tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 1 : forw_exec->outputs[i]] |= WRT_SYMBOL_USE;
955
53.8k
      for (i = 0; i < forw_exec->input_size; i++)
956
          /* Mark that it is used as f. */
957
34.7k
        if (node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
958
26.1k
          used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 1 : forw_exec->inputs[i]] |= F_SYMBOL_USE;
959
19.1k
    }
960
19.1k
  } ccv_nnc_graph_visit_endfor
961
6.78k
  ccv_array_t* sub_f_symbols = 0;
962
6.78k
  ccv_array_t* sub_wrt_symbols = 0;
963
19.1k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
964
19.1k
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
965
19.1k
    const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
966
    /* Only interested in the ones on the f / wrt flow */
967
19.1k
    if ((node->f_wrt & 0x3) == 0x3 && forw_exec->graph_ref_size > 0)
968
2
    {
969
2
      uint64_t stack_input_bitmasks1[node->input_bitmask_size];
970
2
      uint64_t stack_input_bitmasks2[node->input_bitmask_size];
971
2
      uint64_t* const input_bitmasks = forw_exec->graph_ref_size > 1 ? stack_input_bitmasks1 : node->input_bitmasks;
972
      // We collect input masks into this location.
973
2
      if (forw_exec->graph_ref_size > 1)
974
1
        memset(stack_input_bitmasks2, 0, sizeof(uint64_t) * node->input_bitmask_size);
975
6
      for (p = 0; p < forw_exec->graph_ref_size; p++)
976
4
      {
977
        // Reset the stack input bitmasks.
978
4
        if (forw_exec->graph_ref_size > 1)
979
3
          memcpy(stack_input_bitmasks1, node->input_bitmasks, sizeof(uint64_t) * node->input_bitmask_size);
980
        // Now call it recursively until we are sure no f_symbols can be removed.
981
4
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[p] - 1;
982
4
        ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep = backward_prep->sub_preps + graph_ref;
983
4
        if (!sub_wrt_symbols)
984
2
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
985
2
        else
986
2
          ccv_array_clear(sub_wrt_symbols);
987
12
        for (i = 0; i < forw_exec->input_size; i++)
988
8
          if (node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
989
7
          {
990
7
            const ccv_array_t* const s_refs = tensor_symbol_info[forw_exec->inputs[i]].s_ref;
991
7
            const int s_ref = s_refs && s_refs->rnum > graph_ref ? *(int*)ccv_array_get(s_refs, graph_ref) - 1 : -1;
992
7
            if (s_ref >= 0)
993
3
            {
994
3
              ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
995
3
                .d = s_ref,
996
3
                .graph = sub_prep->graph,
997
3
              };
998
3
              ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
999
3
            }
1000
7
          }
1001
4
        int flag; // Only continue if it changed.
1002
4
        do {
1003
4
          flag = 0;
1004
8
          for (i = 0; i < forw_exec->output_size; i++)
1005
            // Try to reduce the number of inputs for the backward graph. If it is not tagged as F_SYMBOL_USE, we can reduce it.
1006
            // It is reducible because this sub graph may have multiple computation paths, therefore, some of these may not
1007
            // involve our wrt symbols at all.
1008
4
            if (!(used_grad[tensor_symbol_info[forw_exec->outputs[i]].alias_ref ? tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 1 : forw_exec->outputs[i]] & F_SYMBOL_USE) &&
1009
4
              input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
1010
0
            { /* Try to eliminate one of the inputs. */
1011
0
              input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
1012
0
              if (!sub_f_symbols)
1013
0
                sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1014
0
              else
1015
0
                ccv_array_clear(sub_f_symbols);
1016
0
              for (j = 0; j < forw_exec->output_size; j++)
1017
0
                if (node->input_bitmasks[j >> 6] & ((uint64_t)1 << (j & 63)))
1018
0
                {
1019
0
                  const int s_ref = *(int*)ccv_array_get(tensor_symbol_info[forw_exec->outputs[j]].s_ref, graph_ref) - 1;
1020
0
                  assert(s_ref >= 0);
1021
0
                  ccv_nnc_tensor_symbol_t sub_f_symbol = {
1022
0
                    .d = s_ref,
1023
0
                    .graph = sub_prep->graph,
1024
0
                  };
1025
0
                  ccv_array_push(sub_f_symbols, &sub_f_symbol);
1026
0
                }
1027
0
              if (_ccv_nnc_symbolic_graph_backward_prep_prune_ops(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph)))
1028
0
                flag = 1;
1029
0
              else /* Otherwise, set the bit back again. */
1030
0
                input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
1031
0
            }
1032
4
        } while (flag);
1033
        // We are done; now redo the above for sub_prep, and this time it has to be successful.
1034
4
        if (!sub_f_symbols)
1035
2
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1036
2
        else
1037
2
          ccv_array_clear(sub_f_symbols);
1038
8
        for (i = 0; i < forw_exec->output_size; i++)
1039
4
          if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
1040
4
          {
1041
4
            const int s_ref = *(int*)ccv_array_get(tensor_symbol_info[forw_exec->outputs[i]].s_ref, graph_ref) - 1;
1042
4
            assert(s_ref >= 0);
1043
4
            ccv_nnc_tensor_symbol_t sub_f_symbol = {
1044
4
              .d = s_ref,
1045
4
              .graph = sub_prep->graph,
1046
4
            };
1047
4
            ccv_array_push(sub_f_symbols, &sub_f_symbol);
1048
4
          }
1049
4
        _ccv_nnc_symbolic_graph_backward_prep_prune_ops(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph));
1050
4
        if (forw_exec->graph_ref_size > 1)
1051
6
          for (i = 0; i < node->input_bitmask_size; i++)
1052
3
            stack_input_bitmasks2[i] |= input_bitmasks[i];
1053
4
      }
1054
2
      if (forw_exec->graph_ref_size > 1)
1055
1
        memcpy(node->input_bitmasks, stack_input_bitmasks2, sizeof(uint64_t) * node->input_bitmask_size);
1056
2
    }
1057
19.1k
  } ccv_nnc_graph_visit_endfor
1058
6.78k
  if (sub_f_symbols)
1059
2
    ccv_array_free(sub_f_symbols);
1060
6.78k
  if (sub_wrt_symbols)
1061
2
    ccv_array_free(sub_wrt_symbols);
1062
6.78k
  int flag = 1;
1063
13.5k
  for (i = 0; i < f_symbol_size && flag; i++)
1064
6.79k
    flag = (used_grad[tensor_symbol_info[f_symbols[i].d].alias_ref ? tensor_symbol_info[f_symbols[i].d].alias_ref - 1 : f_symbols[i].d] & WRT_SYMBOL_USE);
1065
6.78k
  ccfree(used_grad);
1066
6.78k
  return flag;
1067
6.78k
}
1068
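/* A minimal standalone sketch (not part of this file) of the uint64_t bitmask
 * convention the pruning pass above relies on: bit i lives in word i >> 6 at
 * position i & 63. The pass speculatively clears a bit, re-checks with
 * ccv_nnc_cmd_bitmask, and sets the bit back if the command rejects it. */
#include <assert.h>
#include <stdint.h>

static void bitmask_set(uint64_t* const mask, const int i)
{
  mask[i >> 6] |= ((uint64_t)1 << (i & 63));
}

static void bitmask_clear(uint64_t* const mask, const int i)
{
  mask[i >> 6] &= ~((uint64_t)1 << (i & 63));
}

static int bitmask_get(const uint64_t* const mask, const int i)
{
  return !!(mask[i >> 6] & ((uint64_t)1 << (i & 63)));
}

int main(void)
{
  uint64_t mask[2] = {0, 0}; /* Two words cover 128 bits. */
  bitmask_set(mask, 70); /* Bit 70 lands in word 1, position 6. */
  assert(bitmask_get(mask, 70));
  bitmask_clear(mask, 70); /* The speculative-clear step of the pruning loop. */
  assert(!bitmask_get(mask, 70));
  return 0;
}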
1069
static void _ccv_nnc_symbolic_graph_backward_prep_gen(ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const int is_while, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
1070
6.78k
{
1071
6.78k
  const int exec_symbol_info_size = backward_prep->exec_symbol_info_size;
1072
6.78k
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
1073
6.78k
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
1074
6.78k
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info;
1075
6.78k
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
1076
  // Now, for each one of these, find a reverse graph.
1077
6.78k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
1078
6.78k
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
1079
6.78k
  int i, j;
1080
  // Now, only the flow from f_symbols back to wrt_symbols is of interest to us.
1081
  // Visit the graph in reverse order, build the AD nodes.
1082
6.78k
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = (ccv_nnc_autograd_graph_exec_symbol_t*)cccalloc(exec_symbol_info_size, sizeof(ccv_nnc_autograd_graph_exec_symbol_t));
1083
6.78k
  int max_forw_input_size = 0, max_forw_output_size = 0;
1084
25.9k
  for (i = 0; i < exec_symbol_info_size; i++)
1085
19.1k
    if ((backward_info[i].f_wrt & 0x3) == 0x3)
1086
19.1k
    {
1087
19.1k
      max_forw_input_size = ccv_max(max_forw_input_size, exec_symbol_info[i].input_size);
1088
19.1k
      max_forw_output_size = ccv_max(max_forw_output_size, exec_symbol_info[i].output_size);
1089
19.1k
      if (backward_info[i].outgoings)
1090
12.3k
      {
1091
        // Copy over the outgoing bits.
1092
12.3k
        autograd_execs[i].outgoings = ccv_array_new(sizeof(int), backward_info[i].outgoings->rnum, 0);
1093
24.6k
        for (j = 0; j < backward_info[i].outgoings->rnum; j++)
1094
12.3k
        {
1095
12.3k
          const int d = *(int*)ccv_array_get(backward_info[i].outgoings, j);
1096
          // Only push the outgoing node if it is in the f_wrt path.
1097
12.3k
          if ((backward_info[d].f_wrt & 0x3) == 0x3)
1098
12.3k
            ccv_array_push(autograd_execs[i].outgoings, &d);
1099
12.3k
        }
1100
12.3k
      }
1101
19.1k
    }
1102
6.78k
  int max_forw_inputs[ccv_max(1, max_forw_input_size)];
1103
6.78k
  int max_forw_outputs[ccv_max(1, max_forw_output_size)];
1104
6.78k
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = (ccv_nnc_autograd_tensor_version_t*)cccalloc(tensor_symbol_info_size, sizeof(ccv_nnc_autograd_tensor_version_t));
1105
6.78k
  ccv_array_t* autograd_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_autograd_tensor_symbol_t), tensor_symbol_info_size, 0);
1106
6.78k
  ccv_array_t* sum_or_set_execs = ccv_array_new(sizeof(ccv_nnc_sum_or_set_graph_exec_symbol_t), 0, 0);
1107
19.1k
  ccv_nnc_graph_visit_for(backward_visit, backward_info, back_info_node, idx) {
1108
    /* This is required by both the f flow and the wrt flow, therefore, of interest to us */
1109
19.1k
    if ((back_info_node->f_wrt & 0x3) == 0x3)
1110
19.1k
    {
1111
19.1k
      const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
1112
19.1k
      ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + idx;
1113
19.1k
      back_exec->cmd = forw_exec->cmd;
1114
19.1k
      if (back_exec->cmd.cmd != CCV_NNC_NOOP)
1115
19.1k
        back_exec->cmd.cmd += 1; /* Backward command is the one after forward command. */
1116
19.1k
      assert(ccv_nnc_cmd_is_backward(back_exec->cmd) || back_exec->cmd.cmd == CCV_NNC_NOOP);
1117
19.1k
      if (!back_info_node->output_bitmask_size) /* This has no output, can be a noop. */
1118
0
        back_exec->cmd.cmd = CCV_NNC_NOOP;
1119
19.1k
      else {
1120
19.1k
        int* back_input_map = max_forw_outputs;
1121
19.1k
        int* back_output_map = max_forw_inputs;
1122
19.1k
        _ccv_nnc_symbolic_graph_backward_exec_io(forw_exec, &back_input_map, &back_output_map, &back_exec->input_size, &back_exec->output_size);
1123
19.1k
        back_exec->inputs = ccmalloc(sizeof(int) * (back_exec->input_size + back_exec->output_size));
1124
19.1k
        back_exec->outputs = back_exec->inputs + back_exec->input_size;
1125
        /* Need to compute inputs before we compute outputs */
1126
38.6k
        for (i = 0; i < back_exec->input_size; 
i++19.5k
)
1127
19.5k
        {
1128
          /* If we can skip this input, do that. */
1129
19.5k
          if (!(back_info_node->input_bitmasks[i >> 6] & ((uint64_t)1 << i)))
1130
424
            continue;
1131
19.1k
          const int d = back_input_map[i];
1132
19.1k
          const int alias_ref = tensor_symbol_info[d].alias_ref;
1133
19.1k
          ccv_nnc_autograd_tensor_version_t* tensor_ver = alias_ref ? autograd_tensor_versions + (alias_ref - 1) : autograd_tensor_versions + d;
1134
          /* Initialization tensor, should correspond to f symbols */
1135
19.1k
          if (!tensor_ver->ref_version)
1136
6.79k
          {
1137
6.79k
            ccv_nnc_autograd_tensor_symbol_t tensor_sym = {};
1138
6.79k
            if (!alias_ref)
1139
6.79k
            {
1140
6.79k
              tensor_sym.d = d;
1141
6.79k
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1142
6.79k
              const ccv_nnc_tensor_ref_t tensor_ref = {
1143
6.79k
                .d = autograd_tensor_symbols->rnum - 1,
1144
6.79k
                .x = idx,
1145
6.79k
                .alias_registry = 0
1146
6.79k
              };
1147
6.79k
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
1148
6.79k
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1149
6.79k
            } else {
1150
2
              tensor_sym.d = alias_ref - 1;
1151
2
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1152
2
              const ccv_nnc_tensor_ref_t tensor_ref = {
1153
2
                .d = autograd_tensor_symbols->rnum - 1,
1154
2
                .x = idx,
1155
2
                .alias_registry = ccv_array_new(sizeof(int), 1, 0)
1156
2
              };
1157
2
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
1158
2
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1159
2
              tensor_sym.d = d; /* set back */
1160
2
              tensor_sym.alias_ref = tensor_ref.d + 1;
1161
2
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1162
2
              const int ad = autograd_tensor_symbols->rnum - 1;
1163
2
              ccv_array_push(tensor_ref.alias_registry, &ad);
1164
2
            }
1165
6.79k
          }
1166
          /* The simplest case (most common), it is not an alias. */
1167
19.1k
          if (!alias_ref)
1168
19.0k
          {
1169
            /* Even simpler, this only has one reference tensor, thus, pass it in as input. */
1170
19.0k
            if (tensor_ver->c == tensor_ver->ref_version->rnum - 1)
1171
14.8k
            {
1172
14.8k
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
1173
              /* There are aliases associated with this tensor ref, zero it out when this tensor is allocated. */
1174
              /* This is required. Consider the case where we have an alias of this tensor used somewhere */
1175
              /* on the forward pass. When we compute backward, that alias is computed first; however, its */
1176
              /* underlying tensor is not zero initialized, and we will end up with garbage values here. */
1177
14.8k
              if (tensor_ref->alias_registry &&
1178
                /* Loop over to see if this tensor is fully occupied, to avoid an extra zeroing step. */
1179
14.8k
                !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info))
1180
1
              {
1181
1
                ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1182
1
                assert(tensor_sym->alias_ref == 0);
1183
1
                tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
1184
1
              }
1185
14.8k
              back_exec->inputs[i] = tensor_ref->d;
1186
14.8k
            } else {
1187
              /* Otherwise, we need to sum them up, and then pass the summed result to the computation. */
1188
4.24k
              _ccv_nnc_graph_sum_autograd_tensor_versions(idx, d, exec_symbol_info_size, tensor_symbol_info, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
1189
4.24k
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
1190
4.24k
              back_exec->inputs[i] = tensor_ref->d;
1191
4.24k
            }
1192
19.0k
          } else
1193
            /* If this is an alias, go through all available tensor ref versions */
1194
21
            back_exec->inputs[i] = _ccv_nnc_graph_sum_autograd_tensor_versions_alias(idx, d, tensor_symbol_info, exec_symbol_info_size, tensor_symbol_info + d, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
1195
19.1k
        }
1196
53.8k
        for (i = 0; i < back_exec->output_size; i++)
1197
34.7k
        {
1198
          /* If we can skip this output, do that. */
1199
34.7k
          if (!(back_info_node->output_bitmasks[i >> 6] & ((uint64_t)1 << i)))
1200
8.59k
            continue;
1201
26.1k
          const int d = back_output_map[i];
1202
26.1k
          const int alias_ref = tensor_symbol_info[d].alias_ref;
1203
26.1k
          ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
1204
26.1k
            .d = d
1205
26.1k
          };
1206
          /* The simplest case (most common), it is not an alias. */
1207
26.1k
          if (!alias_ref)
1208
24.0k
          {
1209
24.0k
            ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1210
24.0k
            const ccv_nnc_tensor_ref_t tensor_ref = {
1211
24.0k
              .d = autograd_tensor_symbols->rnum - 1,
1212
24.0k
              .x = idx,
1213
24.0k
              .exec_registry = 0,
1214
24.0k
              .alias_registry = 0
1215
24.0k
            };
1216
24.0k
            ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + d;
1217
24.0k
            if (!tensor_ver->ref_version)
1218
19.7k
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
1219
24.0k
            ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1220
24.0k
            back_exec->outputs[i] = tensor_ref.d;
1221
24.0k
          } else {
1222
            /* Otherwise, in case this is an alias, we try to find an existing one in tensor_ver and
1223
             * see if it can meet the need (that is, whether the tensor info / ofs fits). */
1224
2.12k
            ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + (alias_ref - 1);
1225
2.12k
            if (!tensor_ver->ref_version)
1226
2.09k
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
1227
            /* If a ref version already exists, check if any of these not-yet-sealed tensors has free space. */
1228
2.12k
            int found = 0;
1229
2.17k
            for (j = tensor_ver->c; !found && j < tensor_ver->ref_version->rnum; j++)
1230
50
            {
1231
50
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, j);
1232
50
              if (!_ccv_nnc_tensor_ref_version_involve_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, tensor_symbol_info + d))
1233
9
              {
1234
9
                tensor_sym.alias_ref = tensor_ref->d + 1;
1235
9
                ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1236
9
                const int ad = autograd_tensor_symbols->rnum - 1;
1237
9
                ccv_array_push(tensor_ref->alias_registry, &ad);
1238
9
                if (!tensor_ref->exec_registry)
1239
7
                  tensor_ref->exec_registry = ccv_array_new(sizeof(int), 1, 0);
1240
9
                ccv_array_push(tensor_ref->exec_registry, &idx);
1241
9
                back_exec->outputs[i] = ad;
1242
9
                found = 1;
1243
9
              }
1244
50
            }
1245
2.12k
            if (!found) /* Cannot find a tensor ref to insert into, create one first */
1246
2.11k
            {
1247
2.11k
              tensor_sym.d = alias_ref - 1; /* Reference back to the non-alias. */
1248
2.11k
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1249
2.11k
              const ccv_nnc_tensor_ref_t tensor_ref = {
1250
2.11k
                .d = autograd_tensor_symbols->rnum - 1,
1251
2.11k
                .x = idx,
1252
2.11k
                .exec_registry = 0,
1253
2.11k
                .alias_registry = ccv_array_new(sizeof(int), 1, 0)
1254
2.11k
              };
1255
2.11k
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1256
2.11k
              tensor_sym.d = d; /* set back */
1257
2.11k
              tensor_sym.alias_ref = tensor_ref.d + 1;
1258
2.11k
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1259
2.11k
              const int ad = autograd_tensor_symbols->rnum - 1;
1260
2.11k
              ccv_array_push(tensor_ref.alias_registry, &ad);
1261
2.11k
              back_exec->outputs[i] = ad;
1262
2.11k
            }
1263
2.12k
          }
1264
26.1k
        }
1265
19.1k
      }
1266
19.1k
    }
1267
19.1k
  } ccv_nnc_graph_visit_endfor
1268
  // Find all relevant wrt symbols, generate sums for them if needed.
1269
16.3k
  for (i = 0; i < wrt_symbol_size; i++)
1270
9.53k
  {
1271
9.53k
    const int d = wrt_symbols[i].d;
1272
9.53k
    if (d < 0)
1273
9
      continue;
1274
9.52k
    const int ref_d = (!tensor_symbol_info[d].alias_ref) ? d : tensor_symbol_info[d].alias_ref - 1;
1275
9.52k
    ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + ref_d;
1276
9.52k
    if (!tensor_ver->ref_version)
1277
1
    {
1278
      // This wrt symbol is not available at all; in this case, we set its flag to init zero.
1279
1
      const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
1280
1
        .d = ref_d
1281
1
      };
1282
1
      ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1283
1
      ccv_nnc_sum_or_set_graph_exec_symbol_t set_exec = {
1284
1
        .value = 0,
1285
1
        .output = autograd_tensor_symbols->rnum - 1,
1286
1
      };
1287
1
      ccv_array_push(sum_or_set_execs, &set_exec);
1288
      // Insert the one to be set to zero.
1289
1
      const ccv_nnc_tensor_ref_t tensor_ref = {
1290
1
        .d = autograd_tensor_symbols->rnum - 1,
1291
1
        .x = exec_symbol_info_size + sum_or_set_execs->rnum - 1,
1292
1
      };
1293
1
      tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
1294
1
      ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1295
1
      continue;
1296
1
    }
1297
    // If it is a while loop, we need to insert an accumulator into the graph (this is expressed as an initialization tensor summed with the existing results).
1298
    // First, insert the initialization tensor if this wrt result is not used directly in the next while loop (if it is, it participates in the computation, therefore, there is no need to accumulate).
1299
9.52k
    if (is_while && !tensor_symbol_info[ref_d].assign_ref &&
1300
9.52k
      _ccv_nnc_tensor_ref_version_find_init(tensor_ver) < 0) // If the initialization tensor is not inserted yet.
1301
1
    {
1302
1
      const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
1303
1
        .d = ref_d
1304
1
      };
1305
1
      ccv_array_push(autograd_tensor_symbols, &tensor_sym);
1306
      // Insert the one to be summed.
1307
1
      const ccv_nnc_tensor_ref_t tensor_ref = {
1308
1
        .d = autograd_tensor_symbols->rnum - 1,
1309
1
        .x = -1, // This denotes it is an initialization vector.
1310
1
      };
1311
1
      ccv_array_push(tensor_ver->ref_version, &tensor_ref);
1312
1
    }
1313
    // If there is more than one tensor in the list, sum them up.
1314
9.52k
    if (tensor_ver->c < tensor_ver->ref_version->rnum - 1)
1315
30
      _ccv_nnc_graph_sum_autograd_tensor_versions(-1, ref_d, exec_symbol_info_size, tensor_symbol_info, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
1316
    // The tensor version should have a ref_version, and only one now (after summing up).
1317
9.52k
    assert(tensor_ver->c == tensor_ver->ref_version->rnum - 1);
1318
9.52k
  }
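/* Summing gradient versions is the graph-level form of the multivariate chain
 * rule: when several backward paths produce gradients for the same tensor, the
 * total gradient is their sum. A toy numeric illustration, independent of the
 * graph machinery above: */
#include <assert.h>

int main(void)
{
  /* f = x * x reaches x along two paths (both multiplicands), so the gradient
   * of x is the sum of the per-path contributions; the EWSUM insertion above
   * performs the same combination at graph level. */
  const float x = 3.f;
  const float dfdx_path1 = x; /* Treating f = x * y: df/dx = y = x. */
  const float dfdx_path2 = x; /* And df/dy = x. */
  assert(dfdx_path1 + dfdx_path2 == 6.f); /* Matches d(x^2)/dx = 2x. */
  return 0;
}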
1319
  // Adding additional fields to backward_prep now.
1320
6.78k
  backward_prep->autograd_execs = autograd_execs;
1321
6.78k
  backward_prep->autograd_tensor_versions = autograd_tensor_versions;
1322
6.78k
  backward_prep->autograd_tensor_symbols = autograd_tensor_symbols;
1323
6.78k
  backward_prep->sum_or_set_execs = sum_or_set_execs;
1324
6.78k
  ccv_array_t* sub_f_symbols = 0;
1325
6.78k
  ccv_array_t* sub_wrt_symbols = 0;
1326
19.1k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
1327
19.1k
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
1328
19.1k
    const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
1329
    /* Only interested in the ones on the f / wrt flow */
1330
19.1k
    if ((node->f_wrt & 0x3) == 0x3)
1331
19.1k
    {
1332
19.1k
      const int is_while = (forw_exec->flags & CCV_NNC_GRAPH_EXEC_P_WHILE);
1333
19.1k
      for (i = 0; i < forw_exec->graph_ref_size; i++)
1334
4
      {
1335
        // Now call it recursively until we are sure no f_symbols can be removed.
1336
4
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[i] - 1;
1337
4
        ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep = backward_prep->sub_preps + graph_ref;
1338
4
        if (!sub_wrt_symbols)
1339
2
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1340
4
        if (!sub_f_symbols)
1341
2
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1342
4
        _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, node->input_bitmasks, node->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
1343
4
        _ccv_nnc_symbolic_graph_backward_prep_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, is_while, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph));
1344
4
      }
1345
19.1k
    }
1346
19.1k
  } ccv_nnc_graph_visit_endfor
1347
6.78k
  if (sub_f_symbols)
1348
2
    ccv_array_free(sub_f_symbols);
1349
6.78k
  if (sub_wrt_symbols)
1350
2
    ccv_array_free(sub_wrt_symbols);
1351
6.78k
}
1352
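/* The generation pass above derives each backward command with cmd.cmd += 1,
 * which assumes the registry lays out every backward command identifier
 * directly after its forward one. A hedged sketch of that pairing with
 * made-up identifiers (the real values come from the command registry): */
#include <assert.h>

enum {
  TOY_NOOP = 0, /* Stand-in for CCV_NNC_NOOP; a noop stays a noop. */
  TOY_GEMM_FORWARD = 2, /* Illustrative values only. */
  TOY_GEMM_BACKWARD = 3,
  TOY_RELU_FORWARD = 4,
  TOY_RELU_BACKWARD = 5,
};

static int toy_backward_of(const int forward_cmd)
{
  return forward_cmd == TOY_NOOP ? TOY_NOOP : forward_cmd + 1;
}

int main(void)
{
  assert(toy_backward_of(TOY_GEMM_FORWARD) == TOY_GEMM_BACKWARD);
  assert(toy_backward_of(TOY_RELU_FORWARD) == TOY_RELU_BACKWARD);
  return 0;
}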
1353
static void _ccv_nnc_symbolic_graph_backward_prep_free(const ccv_nnc_symbolic_graph_backward_prep_t backward_prep)
1354
6.78k
{
1355
6.78k
  int i, j;
1356
6.78k
  const int exec_symbol_info_size = backward_prep.exec_symbol_info_size;
1357
6.78k
  const int tensor_symbol_info_size = backward_prep.tensor_symbol_info_size;
1358
6.78k
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep.autograd_execs;
1359
6.78k
  if (autograd_execs)
1360
6.78k
  {
1361
25.9k
    for (i = 0; i < exec_symbol_info_size; i++)
1362
19.1k
    {
1363
19.1k
      if (autograd_execs[i].inputs)
1364
19.1k
        ccfree(autograd_execs[i].inputs);
1365
19.1k
      if (autograd_execs[i].outgoings)
1366
12.3k
        ccv_array_free(autograd_execs[i].outgoings);
1367
19.1k
    }
1368
6.78k
    ccfree(autograd_execs);
1369
6.78k
  }
1370
6.78k
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = backward_prep.autograd_tensor_versions;
1371
6.78k
  if (autograd_tensor_versions)
1372
6.78k
  {
1373
46.4k
    for (i = 0; i < tensor_symbol_info_size; i++)
1374
39.6k
    {
1375
39.6k
      if (autograd_tensor_versions[i].ref_version)
1376
28.6k
      {
1377
65.9k
        for (j = 0; j < autograd_tensor_versions[i].ref_version->rnum; j++)
1378
37.2k
        {
1379
37.2k
          ccv_nnc_tensor_ref_t* ref_version = (ccv_nnc_tensor_ref_t*)ccv_array_get(autograd_tensor_versions[i].ref_version, j);
1380
37.2k
          if (ref_version->exec_registry)
1381
7
            ccv_array_free(ref_version->exec_registry);
1382
37.2k
          if (ref_version->alias_registry)
1383
2.12k
            ccv_array_free(ref_version->alias_registry);
1384
37.2k
        }
1385
28.6k
        ccv_array_free(autograd_tensor_versions[i].ref_version);
1386
28.6k
      }
1387
39.6k
    }
1388
6.78k
    ccfree(autograd_tensor_versions);
1389
6.78k
  }
1390
6.78k
  if (backward_prep.autograd_tensor_symbols)
1391
6.78k
    ccv_array_free(backward_prep.autograd_tensor_symbols);
1392
6.78k
  ccv_array_t* const sum_or_set_execs = backward_prep.sum_or_set_execs;
1393
6.78k
  if (sum_or_set_execs)
1394
6.78k
  {
1395
11.0k
    for (i = 0; i < sum_or_set_execs->rnum; i++)
1396
4.28k
    {
1397
4.28k
      ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
1398
4.28k
      if (sum_or_set->inputs)
1399
4.28k
        ccfree(sum_or_set->inputs);
1400
4.28k
      if (sum_or_set->outgoings)
1401
4.25k
        ccv_array_free(sum_or_set->outgoings);
1402
4.28k
    }
1403
6.78k
    ccv_array_free(sum_or_set_execs);
1404
6.78k
  }
1405
  // From here on, these are mandatory.
1406
6.78k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep.backward_info;
1407
25.9k
  for (i = 0; i < exec_symbol_info_size; i++)
1408
19.1k
  {
1409
19.1k
    if (backward_info[i].outgoings)
1410
12.3k
      ccv_array_free(backward_info[i].outgoings);
1411
19.1k
    if (backward_info[i].input_bitmasks)
1412
19.1k
      ccfree(backward_info[i].input_bitmasks);
1413
19.1k
  }
1414
6.78k
  ccfree(backward_info);
1415
6.78k
  ccv_nnc_graph_visit_free(backward_prep.backward_visit);
1416
6.78k
  ccv_nnc_graph_visit_free(backward_prep.forward_visit);
1417
6.78k
  ccfree(backward_prep.exec_symbol_info);
1418
6.78k
  ccfree(backward_prep.tensor_symbol_info);
1419
6.78k
  for (i = 0; i < backward_prep.sub_prep_size; i++)
1420
4
    _ccv_nnc_symbolic_graph_backward_prep_free(backward_prep.sub_preps[i]);
1421
6.78k
  if (backward_prep.sub_preps)
1422
2
    ccfree(backward_prep.sub_preps);
1423
6.78k
}
1424
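/* The teardown above frees owned registries before their containers, walking
 * each structure inside-out. A minimal sketch of that ownership discipline
 * with toy stand-in types (not the ccv_array_t API): */
#include <stdlib.h>

typedef struct {
  int* data; /* Owned buffer. */
} toy_array_t;

typedef struct {
  toy_array_t* alias_registry; /* Owned; may be NULL. */
  toy_array_t* exec_registry; /* Owned; may be NULL. */
} toy_ref_t;

static void toy_array_free(toy_array_t* const a)
{
  if (a)
  {
    free(a->data);
    free(a);
  }
}

int main(void)
{
  toy_ref_t* const refs = (toy_ref_t*)calloc(2, sizeof(toy_ref_t));
  refs[0].alias_registry = (toy_array_t*)calloc(1, sizeof(toy_array_t));
  int i;
  for (i = 0; i < 2; i++)
  {
    /* Registries first, then the containing array, mirroring the
     * ref_version cleanup above. */
    toy_array_free(refs[i].alias_registry);
    toy_array_free(refs[i].exec_registry);
  }
  free(refs);
  return 0;
}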
1425
static void _ccv_nnc_add_backward_breakpoint_for_symbol(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_graph_exec_symbol_t breakpoint, ccv_nnc_symbolic_graph_t* const graph, ccv_array_t* const sub_breakpoints)
1426
1
{
1427
1
  const ccv_nnc_graph_exec_symbol_t noop = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(CCV_NNC_NOOP, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, 0);
1428
1
  ccv_array_push(sub_breakpoints, &noop);
1429
  // Now need to hook this up to the graph.
1430
1
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
1431
1
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
1432
  // Now, for each one of these, find a reverse graph.
1433
1
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
1434
1
  int i;
1435
  // Clean up the high bit.
1436
4
  for (i = 0; i < backward_prep->exec_symbol_info_size; i++)
1437
3
    backward_info[i].f_wrt &= ~0x4;
1438
1
  assert((backward_info[breakpoint.d].f_wrt & 0x3) != 0x3);
1439
1
  backward_info[breakpoint.d].f_wrt |= 0x4;
1440
1
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
1441
1
  const ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep->autograd_execs;
1442
  // Going forward to find whether this breakpoint is a source node to some f_wrt nodes.
1443
3
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, forw_exec, idx) {
1444
3
    ccv_nnc_graph_backward_info_t* const node = backward_info + idx;
1445
    // If it is tagged on the breakpoint flow, but not as both f and wrt, flow through it.
1446
3
    if ((node->f_wrt & 0x4) && (node->f_wrt & 0x3) != 0x3)
1447
1
      for (i = 0; forw_exec->outgoings && i < forw_exec->outgoings->rnum; i++)
1448
0
      {
1449
0
        const int outgoing_idx = *(int*)ccv_array_get(forw_exec->outgoings, i);
1450
0
        ccv_nnc_graph_backward_info_t* const outgoing_node = backward_info + outgoing_idx;
1451
        // If this is an f_wrt node, concatenate.
1452
0
        if (!(outgoing_node->f_wrt & 0x4) && (outgoing_node->f_wrt & 0x3) == 0x3)
1453
0
            ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[outgoing_idx].symbol, noop);
1454
0
        outgoing_node->f_wrt |= 0x4;
1455
0
      }
1456
3
  } ccv_nnc_graph_visit_endfor
1457
  // Going backward to find whether this breakpoint is a destination node for some f_wrt nodes.
1458
3
  ccv_nnc_graph_visit_for(backward_visit, backward_info, node, idx) {
1459
3
    if ((node->f_wrt & 0x4) && (node->f_wrt & 0x3) != 0x3)
1460
2
      for (i = 0; node->outgoings && i < node->outgoings->rnum; i++)
1461
1
      {
1462
1
        const int outgoing_idx = *(int*)ccv_array_get(node->outgoings, i);
1463
1
        ccv_nnc_graph_backward_info_t* const outgoing_node = backward_info + outgoing_idx;
1464
        // If this is an f_wrt node, concatenate.
1465
1
        if (!(outgoing_node->f_wrt & 0x4) && (outgoing_node->f_wrt & 0x3) == 0x3)
1466
1
            ccv_nnc_graph_exec_symbol_concat(graph, noop, autograd_execs[outgoing_idx].symbol);
1467
1
        outgoing_node->f_wrt |= 0x4;
1468
1
      }
1469
3
  } ccv_nnc_graph_visit_endfor
1470
1
}
1471
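/* The breakpoint routine above reuses bit 0x4 of f_wrt as a scratch visited
 * marker next to the 0x1 (f flow) and 0x2 (wrt flow) bits tested elsewhere as
 * (f_wrt & 0x3) == 0x3. A small sketch of that flag layout, with names
 * invented here for clarity: */
#include <assert.h>

#define TOY_FLOW_F 0x1 /* Reachable from the f symbols. */
#define TOY_FLOW_WRT 0x2 /* Reaches the wrt symbols. */
#define TOY_FLOW_BREAKPOINT 0x4 /* Scratch bit for the breakpoint pass. */

int main(void)
{
  int f_wrt = TOY_FLOW_F | TOY_FLOW_WRT;
  assert((f_wrt & 0x3) == 0x3); /* On both flows: an interesting node. */
  f_wrt |= TOY_FLOW_BREAKPOINT; /* Tagged during the breakpoint traversal. */
  f_wrt &= ~0x4; /* "Clean up the high bit", as the pass does up front. */
  assert(!(f_wrt & TOY_FLOW_BREAKPOINT));
  return 0;
}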
1472
static ccv_nnc_autograd_tensor_symbol_t* _ccv_nnc_autograd_tensor_symbol_from_tensor_version(ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_autograd_tensor_version_t* const tensor_ver)
1473
7
{
1474
7
  assert(tensor_ver->ref_version);
1475
7
  const ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
1476
7
  return (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1477
7
}
1478
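/* The helper above resolves a tensor version to its current gradient symbol:
 * the cursor c marks the first non-accumulated entry of ref_version, and after
 * a sum exactly one live entry remains. A toy sketch of that lookup, with
 * made-up types: */
#include <assert.h>

typedef struct {
  int d[8]; /* Symbol ref per version (fixed size for the sketch). */
  int rnum; /* Number of versions. */
  int c; /* Start of the non-accumulated versions. */
} toy_tensor_version_t;

static int toy_current_symbol(const toy_tensor_version_t* const ver)
{
  assert(ver->c < ver->rnum); /* Must have a live version. */
  return ver->d[ver->c];
}

int main(void)
{
  /* Two older versions were folded into the last one by a sum exec. */
  const toy_tensor_version_t ver = { .d = { 3, 7, 9 }, .rnum = 3, .c = 2 };
  assert(toy_current_symbol(&ver) == 9);
  return 0;
}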
1479
static void _ccv_nnc_symbolic_graph_set_backward_carry_overs(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph)
1480
1
{
1481
1
  int i;
1482
5
  for (i = 0; i < backward_prep->graph->tensor_symbol_info->rnum; i++)
1483
4
  {
1484
4
    const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info + i;
1485
4
    if (tensor_symbol_info->assign_ref)
1486
1
    {
1487
1
      const int assign_ref = tensor_symbol_info->assign_ref - 1;
1488
1
      ccv_nnc_autograd_tensor_symbol_t* const destination_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + assign_ref);
1489
1
      ccv_nnc_autograd_tensor_symbol_t* const source_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + i);
1490
1
      ccv_nnc_symbolic_graph_set_carry_overs(graph, (ccv_nnc_tensor_symbol_map_t []){
1491
1
        { .source = source_autograd_symbol->symbol, .destination = destination_autograd_symbol->symbol }
1492
1
      }, 1);
1493
1
    }
1494
4
  }
1495
3
  for (i = 0; i < wrt_symbol_size; i++)
1496
2
  {
1497
2
    const int d = wrt_symbols[i].d;
1498
2
    if (d < 0)
1499
0
      continue;
1500
2
    const int ref_d = (!backward_prep->tensor_symbol_info[d].alias_ref) ? d : backward_prep->tensor_symbol_info[d].alias_ref - 1;
1501
2
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = backward_prep->autograd_tensor_versions + ref_d;
1502
2
    const int init_ref_ver = _ccv_nnc_tensor_ref_version_find_init(tensor_ver);
1503
2
    if (init_ref_ver >= 0)
1504
1
    {
1505
1
      const int init_d = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, init_ref_ver))->d;
1506
1
      ccv_nnc_autograd_tensor_symbol_t* const destination_autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(backward_prep->autograd_tensor_symbols, init_d);
1507
1
      ccv_nnc_autograd_tensor_symbol_t* const source_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + ref_d);
1508
1
      ccv_nnc_symbolic_graph_set_carry_overs(graph, (ccv_nnc_tensor_symbol_map_t []){
1509
1
        { .source = source_autograd_symbol->symbol, .destination = destination_autograd_symbol->symbol }
1510
1
      }, 1);
1511
1
    }
1512
2
  }
1513
1
}
1514
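/* A carry-over declares that the tensor written by one loop iteration becomes
 * the tensor read by the next. A toy sketch of that source -> destination
 * semantics, separate from the symbolic graph API: */
#include <assert.h>

typedef struct {
  int source; /* Slot produced by the current iteration. */
  int destination; /* Slot the next iteration reads from. */
} toy_carry_over_t;

int main(void)
{
  int slots[2] = { 0, 0 };
  const toy_carry_over_t map = { .source = 1, .destination = 0 };
  int step;
  for (step = 0; step < 3; step++)
  {
    slots[map.source] = slots[map.destination] + 1; /* Body accumulates. */
    slots[map.destination] = slots[map.source]; /* Carry into the next step. */
  }
  assert(slots[map.destination] == 3); /* Three accumulated iterations. */
  return 0;
}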
1515
static void _ccv_nnc_symbolic_graph_add_init_zeros(const ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const sub_graph, ccv_array_t* const symbols)
1516
1
{
1517
1
  int i;
1518
3
  for (i = 0; i < wrt_symbol_size; i++)
1519
2
  {
1520
2
    const int d = wrt_symbols[i].d;
1521
2
    if (d < 0)
1522
0
      continue;
1523
2
    const int ref_d = (!sub_prep->tensor_symbol_info[d].alias_ref) ? d : sub_prep->tensor_symbol_info[d].alias_ref - 1;
1524
2
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = sub_prep->autograd_tensor_versions + ref_d;
1525
2
    const int init_ref_ver = _ccv_nnc_tensor_ref_version_find_init(tensor_ver);
1526
2
    if (init_ref_ver >= 0)
1527
1
    {
1528
      // Need de-dup logic.
1529
1
      const int init_d = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, init_ref_ver))->d;
1530
1
      ccv_nnc_autograd_tensor_symbol_t* const init_autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(sub_prep->autograd_tensor_symbols, init_d);
1531
1
      const ccv_nnc_tensor_symbol_info_t* const sub_init_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, init_autograd_symbol->symbol.d);
1532
      // If it doesn't have a parent ref yet, create one.
1533
1
      if (!sub_init_symbol_info->p_ref)
1534
1
      {
1535
1
        ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, sub_prep->tensor_symbol_info[ref_d].info, 0);
1536
1
        ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS);
1537
1
        ccv_array_push(symbols, &new_symbol);
1538
1
        ccv_nnc_tensor_symbol_hookup(graph, sub_graph, new_symbol, init_autograd_symbol->symbol);
1539
1
      }
1540
1
    }
1541
2
  }
1542
1
}
1543
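/* The zero-initialization above matters because a gradient accumulator must
 * start from the additive identity; otherwise the first += folds garbage into
 * the result. A minimal sketch of the failure CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS
 * guards against: */
#include <assert.h>
#include <string.h>

int main(void)
{
  float grad_accum[4];
  /* Without this memset, the += below would accumulate into uninitialized
   * memory, the graph-level analogue of a missing INIT_ZEROS flag. */
  memset(grad_accum, 0, sizeof(grad_accum));
  int step, i;
  for (step = 0; step < 3; step++) /* Pretend three loop iterations run. */
    for (i = 0; i < 4; i++)
      grad_accum[i] += 0.5f; /* Each iteration contributes a partial gradient. */
  assert(grad_accum[0] == 1.5f);
  return 0;
}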
1544
static void _ccv_nnc_symbolic_graph_add_tape_vars(const ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep, ccv_nnc_symbolic_graph_t* const root, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const sub_graph, ccv_array_t* const symbols)
1545
4
{
1546
4
  int i;
1547
24
  for (i = 0; i < sub_graph->tensor_symbol_info->rnum; i++)
1548
20
  {
1549
20
    const ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, i);
1550
20
    if ((symbol_info->flags & CCV_NNC_TENSOR_SYMBOL_TAPE_VAR) && symbol_info->pair_ref)
1551
7
    {
1552
7
      const int pair_ref = symbol_info->pair_ref - 1;
1553
7
      const ccv_nnc_tensor_symbol_t root_symbol = ccv_nnc_tensor_symbol_resolve(root, (ccv_nnc_tensor_symbol_t){
1554
7
        .d = pair_ref,
1555
7
        .graph = sub_prep->graph,
1556
7
      });
1557
7
      if (root_symbol.d >= 0)
1558
3
      {
1559
3
        ccv_nnc_tensor_symbol_hookup(root, sub_graph, root_symbol, (ccv_nnc_tensor_symbol_t){
1560
3
          .d = i,
1561
3
          .graph = sub_graph,
1562
3
        });
1563
3
        if (symbols)
1564
2
        {
1565
2
          const ccv_nnc_tensor_symbol_t p_symbol = ccv_nnc_tensor_symbol_resolve(graph, (ccv_nnc_tensor_symbol_t){
1566
2
            .d = i,
1567
2
            .graph = sub_graph,
1568
2
          });
1569
2
          ccv_array_push(symbols, &p_symbol);
1570
2
        }
1571
3
      }
1572
7
    }
1573
20
  }
1574
4
}
1575
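/* Tape variables pair a backward-graph symbol with the forward tensor whose
 * value must be kept alive, via the 1-based pair_ref (0 meaning "no pair"),
 * hence the pair_ref - 1 above. A toy sketch of that indirection: */
#include <assert.h>

typedef struct {
  int pair_ref; /* 1-based index into the forward graph, or 0 for none. */
} toy_symbol_info_t;

static int toy_forward_index(const toy_symbol_info_t* const info)
{
  return info->pair_ref ? info->pair_ref - 1 : -1;
}

int main(void)
{
  const toy_symbol_info_t tape = { .pair_ref = 8 };
  const toy_symbol_info_t plain = { .pair_ref = 0 };
  assert(toy_forward_index(&tape) == 7);
  assert(toy_forward_index(&plain) == -1);
  return 0;
}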
1576
static void _ccv_nnc_symbolic_graph_backward_gen(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const root)
1577
6.78k
{
1578
6.78k
  assert(graph == backward_prep->graph || graph->pair == backward_prep->graph);
1579
6.78k
  const int exec_symbol_info_size = backward_prep->exec_symbol_info_size;
1580
6.78k
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
1581
6.78k
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
1582
6.78k
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info;
1583
6.78k
  int i, j, k, p;
1584
6.78k
  ccv_array_t* const autograd_tensor_symbols = backward_prep->autograd_tensor_symbols;
1585
  // Generate required symbols based on the information gathered above.
1586
46.1k
  for (i = 0; i < autograd_tensor_symbols->rnum; i++)
1587
39.4k
  {
1588
39.4k
    ccv_nnc_autograd_tensor_symbol_t* symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, i);
1589
39.4k
    assert(symbol->d >= 0);
1590
39.4k
    assert(symbol->d < tensor_symbol_info_size);
1591
39.4k
    const ccv_nnc_tensor_symbol_info_t* const forw_symbol = tensor_symbol_info + symbol->d;
1592
39.4k
    if (!symbol->alias_ref)
1593
37.2k
    {
1594
37.2k
      assert(!forw_symbol->alias_ref);
1595
37.2k
      symbol->symbol = ccv_nnc_tensor_symbol_new(graph, forw_symbol->info, 0);
1596
37.2k
      ccv_nnc_tensor_symbol_set_flags(graph, symbol->symbol, symbol->flags);
1597
37.2k
    } else {
1598
2.14k
      assert(forw_symbol->alias_ref);
1599
2.14k
      assert(symbol->flags == 0); // We don't set flags on alias.
1600
      // Due to our generation order, this must be after the original symbol is created.
1601
2.14k
      ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, symbol->alias_ref - 1);
1602
2.14k
      symbol->symbol = ccv_nnc_tensor_symbol_alias_new(graph, ref->symbol, forw_symbol->ofs, forw_symbol->stride, forw_symbol->info, 0);
1603
2.14k
    }
1604
39.4k
  }
1605
6.78k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
1606
6.78k
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep->autograd_execs;
1607
6.78k
  ccv_array_t* symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1608
6.78k
  ccv_array_t* symbol_map = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_map_t), 0, 0);
1609
6.78k
  ccv_array_t* sub_f_symbols = 0;
1610
6.78k
  ccv_array_t* sub_wrt_symbols = 0;
1611
6.78k
  ccv_array_t* sub_execs = 0;
1612
25.9k
  for (i = 0; i < exec_symbol_info_size; i++)
1613
19.1k
  {
1614
    // This is not going to be an interesting node. Skip.
1615
19.1k
    if ((backward_info[i].f_wrt & 0x3) != 0x3)
1616
86
      continue;
1617
19.1k
    ccv_nnc_graph_backward_info_t* const back_info = backward_info + i;
1618
19.1k
    ccv_nnc_autograd_graph_exec_symbol_t* const back_exec = autograd_execs + i;
1619
19.1k
    if (back_exec->cmd.cmd == CCV_NNC_NOOP)
1620
1
    {
1621
1
      back_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, back_exec->cmd, 0, 0, 0, 0, 0);
1622
1
      continue;
1623
1
    }
1624
19.1k
    const ccv_nnc_graph_exec_symbol_info_t* const forw_exec = exec_symbol_info + i;
1625
19.1k
    if (forw_exec->flags & CCV_NNC_GRAPH_EXEC_P_WHILE)
1626
1
    {
1627
1
      ccv_array_clear(symbols);
1628
1
      const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[0] - 1;
1629
1
      ccv_nnc_symbolic_graph_backward_prep_t* sub_prep = backward_prep->sub_preps + graph_ref;
1630
1
      ccv_nnc_symbolic_graph_t* sub_graph = ccv_nnc_symbolic_graph_new();
1631
1
      sub_graph->pair = sub_prep->graph;
1632
1
      if (!sub_wrt_symbols)
1633
1
        sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1634
      // We are done; now redo the above for sub_prep, and this time it has to be successful.
1635
1
      if (!sub_f_symbols)
1636
1
        sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1637
1
      _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, back_info->input_bitmasks, back_info->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
1638
1
      _ccv_nnc_symbolic_graph_backward_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph, root);
1639
1
      back_exec->symbol = ccv_nnc_symbolic_graph_while(graph, back_exec->cmd.cmd, sub_graph, forw_exec->name);
1640
1
      if (!sub_execs)
1641
1
        sub_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
1642
1
      ccv_array_clear(sub_execs);
1643
      // Find the breakpoints in the forward graph, creating the reverse ones.
1644
2
      for (j = 0; j < sub_prep->graph->breakpoint_size; j++)
1645
1
      {
1646
1
        const int d = sub_prep->graph->breakpoints[j].d;
1647
1
        if (sub_prep->autograd_execs[d].symbol.graph)
1648
0
          ccv_array_push(sub_execs, &sub_prep->autograd_execs[d].symbol);
1649
1
        else
1650
1
          _ccv_nnc_add_backward_breakpoint_for_symbol(sub_prep, sub_prep->graph->breakpoints[j], sub_graph, sub_execs);
1651
1
      }
1652
1
      ccv_nnc_symbolic_graph_set_while_expr(sub_graph, NOOP_GRAPH_WHILE_EXPR, 0, 0, 0, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(sub_execs, 0), sub_execs->rnum);
1653
1
      ccv_nnc_graph_exec_symbol_autogen(sub_graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
1654
1
      _ccv_nnc_symbolic_graph_set_backward_carry_overs(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph);
1655
2
      for (j = 0; j < back_exec->input_size; j++)
1656
1
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1657
1
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol));
1658
      // Find whether anything in the wrt symbols needs to be initialized to zero; if so, those need to be inputs here too.
1659
1
      _ccv_nnc_symbolic_graph_add_init_zeros(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, graph, sub_graph, symbols);
1660
1
      _ccv_nnc_symbolic_graph_add_tape_vars(sub_prep, root, graph, sub_graph, symbols);
1661
      // input_size at this point may differ from back_exec->input_size, because we may have added zeroing tensors as inputs.
1662
1
      const int input_size = symbols->rnum;
1663
3
      for (j = 0; j < back_exec->output_size; j++)
1664
2
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1665
1
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol));
1666
1
      const int output_size = symbols->rnum - input_size;
1667
1
      const int p_idx = sub_prep->graph->p_idx - 1;
1668
1
      assert(back_exec->input_size == forw_exec->output_size);
1669
1
      k = 0;
1670
2
      for (j = 0; j < back_exec->input_size; j++)
1671
1
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1672
1
        {
1673
1
          const ccv_nnc_tensor_symbol_info_t* const info = tensor_symbol_info + forw_exec->outputs[j];
1674
1
          const int s_idx = *(int*)ccv_array_get(info->s_ref, p_idx) - 1;
1675
1
          assert(s_idx >= 0);
1676
1
          const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + s_idx);
1677
1
          ccv_nnc_tensor_symbol_hookup(graph, sub_graph, *(ccv_nnc_tensor_symbol_t*)ccv_array_get(symbols, k), autograd_symbol->symbol);
1678
1
          ++k;
1679
1
        }
1680
1
      k = input_size; // Reset k; the extra input symbols were already set up by add_init_zeros.
1681
1
      assert(back_exec->output_size == forw_exec->input_size);
1683
3
      for (j = 0; j < back_exec->output_size; j++)
1683
2
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1684
1
        {
1685
1
          const ccv_nnc_tensor_symbol_info_t* const info = tensor_symbol_info + forw_exec->inputs[j];
1686
1
          const int s_idx = *(int*)ccv_array_get(info->s_ref, p_idx) - 1;
1687
1
          assert(s_idx >= 0);
1688
1
          const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + s_idx);
1689
1
          ccv_nnc_tensor_symbol_hookup(graph, sub_graph, *(ccv_nnc_tensor_symbol_t*)ccv_array_get(symbols, k), autograd_symbol->symbol);
1690
1
          ++k;
1691
1
        }
1692
1
      ccv_nnc_graph_exec_symbol_set_io(graph, back_exec->symbol, ccv_array_get(symbols, 0), input_size, ccv_array_get(symbols, input_size), output_size);
1693
19.1k
    } else if (forw_exec->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) {
1694
1
      ccv_array_clear(symbol_map);
1695
2
      for (j = 0; j < back_exec->output_size; j++)
1696
1
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1697
1
        {
1698
1
          ccv_nnc_tensor_symbol_map_t symbol = {
1699
1
            .source = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol,
1700
1
            .destination = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol,
1701
1
          };
1702
1
          ccv_array_push(symbol_map, &symbol);
1703
1
        }
1704
1
      const int symbol_map_size = symbol_map->rnum;
1705
1
      back_exec->symbol = ccv_nnc_symbolic_graph_case_of_new(graph, back_exec->cmd.cmd, 0, 0, ccv_array_get(symbol_map, 0), symbol_map_size, forw_exec->name);
1706
1
      ccv_nnc_symbolic_graph_set_case_of_expr(graph, back_exec->symbol, NOOP_GRAPH_CASE_OF_EXPR, 0);
1707
4
      for (p = 0; p < forw_exec->graph_ref_size; p++)
1708
3
      {
1709
3
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[p] - 1;
1710
3
        ccv_nnc_symbolic_graph_backward_prep_t* sub_prep = backward_prep->sub_preps + graph_ref;
1711
3
        ccv_nnc_symbolic_graph_t* sub_graph = ccv_nnc_symbolic_graph_new();
1712
3
        sub_graph->pair = sub_prep->graph;
1713
3
        if (!sub_wrt_symbols)
1714
1
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1715
        // We are done; now redo the above for sub_prep, and this time it has to be successful.
1716
3
        if (!sub_f_symbols)
1717
1
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1718
3
        _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, back_info->input_bitmasks, back_info->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
1719
3
        _ccv_nnc_symbolic_graph_backward_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph, root);
1720
3
        ccv_array_clear(symbol_map);
1721
3
        k = 0;
1722
6
        for (j = 0; j < back_exec->output_size; j++)
1723
3
          if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1724
3
          {
1725
3
            const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, k))->d;
1726
3
            if (d >= 0)
1727
1
            {
1728
1
              const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + d);
1729
1
              ccv_nnc_tensor_symbol_map_t symbol = {
1730
1
                .source = autograd_symbol->symbol,
1731
1
                .destination = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol,
1732
1
              };
1733
1
              ccv_array_push(symbol_map, &symbol);
1734
2
            } else {
1735
              // Create a new tensor in the sub-graph and set it to 0.
1736
2
              const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]);
1737
              // autograd_symbol->d points to the corresponding forward tensor.
1738
2
              ccv_nnc_tensor_symbol_t zero_symbol = ccv_nnc_tensor_symbol_new(sub_graph, tensor_symbol_info[autograd_symbol->d].info, 0);
1739
2
              ccv_nnc_graph_exec_symbol_new(sub_graph, CMD_SET_FORWARD(0), 0, 0, &zero_symbol, 1, 0);
1740
2
              ccv_nnc_tensor_symbol_map_t symbol = {
1741
2
                .source = zero_symbol,
1742
2
                .destination = autograd_symbol->symbol,
1743
2
              };
1744
2
              ccv_array_push(symbol_map, &symbol);
1745
2
            }
1746
3
            ++k;
1747
3
          }
1748
3
        ccv_nnc_graph_exec_symbol_autogen(sub_graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
1749
3
        const int symbol_map_size = symbol_map->rnum;
1750
3
        ccv_nnc_symbolic_graph_set_case_of(graph, back_exec->symbol, sub_graph, p, ccv_array_get(symbol_map, 0), symbol_map_size);
1751
        // Hook up inputs only after this becomes a sub-graph of the graph.
1752
3
        k = 0;
1753
6
        for (j = 0; j < back_exec->input_size; j++)
1754
3
          if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1755
3
          {
1756
3
            const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, k))->d;
1757
3
            assert(d >= 0);
1758
            // No corresponding sub tensors allocated. Skip.
1759
3
            if (!sub_prep->autograd_tensor_versions[d].ref_version ||
1760
3
              !sub_prep->autograd_tensor_versions[d].ref_version->rnum)
1761
2
              continue;
1762
1
            const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + d);
1763
1
            ccv_nnc_tensor_symbol_hookup(graph, sub_graph, ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol, autograd_symbol->symbol);
1764
1
            ++k;
1765
1
          }
1766
        // Need to make sure tape vars are hooked up.
1767
3
        _ccv_nnc_symbolic_graph_add_tape_vars(sub_prep, root, graph, sub_graph, 0);
1768
3
      }
1769
19.1k
    } else {
1770
19.1k
      ccv_array_clear(symbols);
1771
      // Gradient inputs.
1772
38.6k
      for (j = 0; j < back_exec->input_size; j++)
1773
19.5k
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1774
19.1k
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol));
1775
424
        else
1776
424
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1777
      // Inputs from forward function.
1778
53.8k
      for (j = 0; j < forw_exec->input_size; j++)
1779
34.7k
        if (!(back_info->input_bitmasks[(j + back_exec->input_size) >> 6] & ((uint64_t)1 << (j + back_exec->input_size))))
1780
14.4k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1781
20.2k
        else {
1782
20.2k
          const ccv_nnc_tensor_symbol_t symbol = {
1783
20.2k
            .d = forw_exec->inputs[j],
1784
20.2k
            .graph = backward_prep->graph
1785
20.2k
          };
1786
20.2k
          if (graph == backward_prep->graph)
1787
20.2k
            ccv_array_push(symbols, &symbol);
1788
5
          else { // Otherwise, create a new symbol, and set its pair to the old symbol.
1789
5
            const ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, tensor_symbol_info[forw_exec->inputs[j]].info, tensor_symbol_info[forw_exec->inputs[j]].name);
1790
5
            ccv_nnc_tensor_symbol_pair_with(graph, new_symbol, symbol);
1791
5
            const int flags = ccv_nnc_tensor_symbol_flags(backward_prep->graph, symbol) | CCV_NNC_TENSOR_SYMBOL_TAPE_VAR;
1792
5
            ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, flags);
1793
5
            ccv_nnc_tensor_symbol_set_flags(backward_prep->graph, symbol, flags);
1794
5
            ccv_array_push(symbols, &new_symbol);
1795
5
          }
1796
20.2k
        }
1797
      // Outputs from forward function.
1798
38.6k
      for (j = 0; j < forw_exec->output_size; j++)
1799
19.5k
        if (!(back_info->input_bitmasks[(j + back_exec->input_size + forw_exec->input_size) >> 6] & ((uint64_t)1 << (j + back_exec->input_size + forw_exec->input_size))))
1800
14.3k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1801
5.14k
        else {
1802
5.14k
          const ccv_nnc_tensor_symbol_t symbol = {
1803
5.14k
            .d = forw_exec->outputs[j],
1804
5.14k
            .graph = backward_prep->graph
1805
5.14k
          };
1806
5.14k
          if (graph == backward_prep->graph)
1807
5.14k
            ccv_array_push(symbols, &symbol);
1808
2
          else { // Otherwise, create a new symbol, and set its pair to the old symbol.
1809
2
            const ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, tensor_symbol_info[forw_exec->outputs[j]].info, tensor_symbol_info[forw_exec->outputs[j]].name);
1810
2
            ccv_nnc_tensor_symbol_pair_with(graph, new_symbol, symbol);
1811
2
            const int flags = ccv_nnc_tensor_symbol_flags(backward_prep->graph, symbol) | CCV_NNC_TENSOR_SYMBOL_TAPE_VAR;
1812
2
            ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, flags);
1813
2
            ccv_nnc_tensor_symbol_set_flags(backward_prep->graph, symbol, flags);
1814
2
            ccv_array_push(symbols, &new_symbol);
1815
2
          }
1816
5.14k
        }
1817
53.8k
      for (j = 0; j < back_exec->output_size; j++)
1818
34.7k
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1819
26.1k
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol));
1820
8.59k
        else
1821
8.59k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1822
19.1k
      back_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, back_exec->cmd, ccv_array_get(symbols, 0), back_exec->input_size + forw_exec->input_size + forw_exec->output_size, ccv_array_get(symbols, back_exec->input_size + forw_exec->input_size + forw_exec->output_size), back_exec->output_size, 0);
1823
19.1k
      ccv_nnc_graph_exec_symbol_set_hint(graph, back_exec->symbol, exec_symbol_info[i].hint);
1824
19.1k
      ccv_nnc_graph_exec_symbol_pair_with(graph, back_exec->symbol, (ccv_nnc_graph_exec_symbol_t){
1825
19.1k
        .d = i,
1826
19.1k
        .graph = backward_prep->graph,
1827
19.1k
      });
1828
19.1k
    }
1829
19.1k
  }
1830
6.78k
  if (sub_f_symbols)
1831
2
    ccv_array_free(sub_f_symbols);
1832
6.78k
  if (sub_wrt_symbols)
1833
2
    ccv_array_free(sub_wrt_symbols);
1834
6.78k
  if (sub_execs)
1835
1
    ccv_array_free(sub_execs);
1836
6.78k
  ccv_array_t* const sum_or_set_execs = backward_prep->sum_or_set_execs;
1837
11.0k
  for (i = 0; i < sum_or_set_execs->rnum; i++)
1838
4.28k
  {
1839
4.28k
    ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set_exec = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
1840
    // If it has inputs, it is a sum; set nodes have no inputs.
1841
4.28k
    if (sum_or_set_exec->input_size)
1842
4.28k
    {
1843
4.28k
      ccv_array_clear(symbols);
1844
      // This is the sum case.
1845
12.8k
      for (j = 0; j < sum_or_set_exec->input_size; j++)
1846
8.59k
        ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->inputs[j]))->symbol));
1847
4.28k
      ccv_nnc_cmd_t cmd = ccv_nnc_cmd(CCV_NNC_EWSUM_FORWARD, 0, CMD_GENERIC(), 0);
1848
4.28k
      sum_or_set_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, cmd, ccv_array_get(symbols, 0), sum_or_set_exec->input_size, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->output))->symbol), 1, 0);
1849
4.28k
    } else
1850
1
      sum_or_set_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, CMD_SET_FORWARD(sum_or_set_exec->value), 0, 0, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->output))->symbol), 1, 0);
1851
4.28k
  }
1852
6.78k
  ccv_array_free(symbol_map);
1853
6.78k
  ccv_array_free(symbols);
1854
25.9k
  for (i = 0; i < exec_symbol_info_size; i++)
1855
19.1k
  {
1856
    // This is not going to be an interesting node. Skip.
1857
19.1k
    if ((backward_info[i].f_wrt & 0x3) != 0x3)
1858
86
      continue;
1859
19.1k
    ccv_nnc_autograd_graph_exec_symbol_t* const back_exec = autograd_execs + i;
1860
    // If on the same graph, we cannot decide whether it is before or after the forw_exec, so we enforce that it comes after forw_exec.
1861
19.1k
    if (graph == backward_prep->graph)
1862
19.1k
      ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){
1863
19.1k
        .d = i,
1864
19.1k
        .graph = graph
1865
19.1k
      }, back_exec->symbol);
1866
19.1k
    if (back_exec->outgoings)
1867
24.7k
      for (j = 0; j < back_exec->outgoings->rnum; j++)
1868
12.4k
      {
1869
12.4k
        int d = *(int*)ccv_array_get(back_exec->outgoings, j);
1870
12.4k
        if (d < exec_symbol_info_size)
1871
8.08k
          ccv_nnc_graph_exec_symbol_concat(graph, back_exec->symbol, autograd_execs[d].symbol);
1872
4.36k
        else
1873
4.36k
          ccv_nnc_graph_exec_symbol_concat(graph, back_exec->symbol, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, d - exec_symbol_info_size))->symbol);
1874
12.4k
      }
1875
19.1k
  }
1876
11.0k
  for (i = 0; i < sum_or_set_execs->rnum; i++)
1877
4.28k
  {
1878
4.28k
    ccv_nnc_sum_or_set_graph_exec_symbol_t* exec = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
1879
4.28k
    if (exec->outgoings)
1880
8.50k
      for (j = 0; j < exec->outgoings->rnum; j++)
1881
4.25k
      {
1882
4.25k
        int d = *(int*)ccv_array_get(exec->outgoings, j);
1883
4.25k
        if (d < exec_symbol_info_size)
1884
4.25k
          ccv_nnc_graph_exec_symbol_concat(graph, exec->symbol, autograd_execs[d].symbol);
1885
0
        else
1886
0
          ccv_nnc_graph_exec_symbol_concat(graph, exec->symbol, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, d - exec_symbol_info_size))->symbol);
1887
4.25k
      }
1888
4.28k
  }
1889
  // Now that everything is done, set the metadata on the graph so that we can look up backward symbols later.
1890
6.78k
  if (graph->backward.tensor_symbol_idx)
1891
4.40k
    graph->backward.tensor_symbol_idx = (int*)ccrealloc(graph->backward.tensor_symbol_idx, sizeof(int) * (graph->tensor_symbol_info->rnum + tensor_symbol_info_size));
1892
2.37k
  else
1893
2.37k
    graph->backward.tensor_symbol_idx = (int*)ccmalloc(sizeof(int) * (graph->tensor_symbol_info->rnum + tensor_symbol_info_size));
1894
6.78k
  graph->backward.tensor_symbol_size = tensor_symbol_info_size;
1895
6.78k
  graph->backward.exec_symbol_idx = graph->backward.tensor_symbol_idx + tensor_symbol_info_size;
1896
6.78k
  graph->backward.exec_symbol_size = graph->tensor_symbol_info->rnum;
1897
46.4k
  for (i = 0; i < tensor_symbol_info_size; i++)
1898
39.6k
    graph->backward.tensor_symbol_idx[i] = -1;
1899
85.8k
  for (i = 0; i < graph->backward.exec_symbol_size; i++)
1900
79.1k
    graph->backward.exec_symbol_idx[i] = -1;
1901
6.78k
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = backward_prep->autograd_tensor_versions;
1902
  // Assigning for wrt symbols.
1903
16.3k
  for (i = 0; i < wrt_symbol_size; i++)
1904
9.53k
  {
1905
9.53k
    const int d = wrt_symbols[i].d;
1906
9.53k
    if (d < 0)
1907
9
      continue;
1908
9.52k
    assert(d < tensor_symbol_info_size);
1909
9.52k
    const ccv_nnc_tensor_symbol_info_t* const forw_symbol = tensor_symbol_info + d;
1910
9.52k
    ccv_nnc_autograd_tensor_version_t* const tensor_ver = autograd_tensor_versions + ((!forw_symbol->alias_ref) ? d : forw_symbol->alias_ref - 1);
1911
9.52k
    assert(tensor_ver->ref_version);
1912
9.52k
    ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
1913
9.52k
    ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1914
    // If this wrt symbol is an alias, create an extra alias for it.
1915
9.52k
    if (!forw_symbol->alias_ref)
1916
9.52k
      graph->backward.tensor_symbol_idx[d] = autograd_symbol->symbol.d;
1917
1
    else // We create a new alias; it cannot be referenced from exec_symbol_idx because that table's size is limited to the previous tensor symbol count.
1918
1
      graph->backward.tensor_symbol_idx[d] = ccv_nnc_tensor_symbol_alias_new(graph, autograd_symbol->symbol, forw_symbol->ofs, forw_symbol->stride, forw_symbol->info, 0).d;
1919
9.52k
    const int dd = autograd_symbol->symbol.d;
1920
9.52k
    const int x = tensor_ref->x;
1921
9.52k
    if (tensor_ref->exec_registry && tensor_ref->exec_registry->rnum) // Create no-op node.
1922
2
    {
1923
2
      ccv_nnc_graph_exec_symbol_t noop = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(CCV_NNC_NOOP, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, 0);
1924
2
      if (x < exec_symbol_info_size)
1925
2
        ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[x].symbol, noop);
1926
0
      else
1927
0
        ccv_nnc_graph_exec_symbol_concat(graph, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size))->symbol, noop);
1928
6
      for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
1929
4
      {
1930
4
        const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
1931
4
        assert(x >= 0); /* Otherwise, this is an initialization tensor, which cannot be part of a sum. */
1932
4
        assert(x < exec_symbol_info_size); // exec_registry is only used by alias_registry; it simply cannot reference a sum operation.
1933
4
        ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[x].symbol, noop);
1934
4
      }
1935
2
      graph->backward.exec_symbol_idx[dd] = noop.d;
1936
9.52k
    } else {
1937
9.52k
      if (x < exec_symbol_info_size)
1938
9.49k
        graph->backward.exec_symbol_idx[dd] = autograd_execs[x].symbol.d;
1939
33
      else
1940
33
        graph->backward.exec_symbol_idx[dd] = ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size))->symbol.d;
1941
9.52k
    }
1942
9.52k
  }
1943
  // Assigning for f symbols.
1944
13.5k
  for (i = 0; i < f_symbol_size; i++)
1945
6.79k
  {
1946
6.79k
    const int d = f_symbols[i].d;
1947
6.79k
    assert(d >= 0);
1948
6.79k
    assert(d < tensor_symbol_info_size);
1949
6.79k
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = autograd_tensor_versions + d;
1950
6.79k
    if (tensor_ver->ref_version)
1951
6.79k
    {
1952
      // We don't use _ccv_nnc_autograd_tensor_symbol_from_tensor_version because that selects the last version; here we need the first version.
1953
6.79k
      const ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, 0);
1954
6.79k
      const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1955
6.79k
      graph->backward.tensor_symbol_idx[d] = autograd_symbol->symbol.d;
1956
      // We don't record backward exec symbols for f; there could be many of them.
1957
6.79k
    }
1958
6.79k
  }
1959
6.78k
}
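
A note on the metadata block that closes this function: both lookup tables live in a single allocation. exec_symbol_idx is simply a pointer offset tensor_symbol_size ints into the tensor_symbol_idx buffer, and every slot starts at -1, the "no backward symbol recorded" sentinel that ccv_nnc_tensor_symbol_for_backward and ccv_nnc_graph_exec_symbol_for_backward test below. The following is a minimal standalone sketch of that layout; backward_lookup_t and backward_lookup_init are hypothetical names, and plain malloc stands in for ccv's ccmalloc/ccrealloc.

#include <stdlib.h>

typedef struct {
  int* tensor_symbol_idx; // head of the single shared buffer
  int tensor_symbol_size;
  int* exec_symbol_idx; // alias into the same buffer, not a second allocation
  int exec_symbol_size;
} backward_lookup_t;

// Mirrors the layout built above: one buffer, two tables, every entry
// initialized to the -1 sentinel. Returns 0 on success, -1 on failure.
static int backward_lookup_init(backward_lookup_t* const lookup, const int tensor_symbol_size, const int exec_symbol_size)
{
  int i;
  int* const buf = (int*)malloc(sizeof(int) * (tensor_symbol_size + exec_symbol_size));
  if (!buf)
    return -1;
  lookup->tensor_symbol_idx = buf;
  lookup->tensor_symbol_size = tensor_symbol_size;
  lookup->exec_symbol_idx = buf + tensor_symbol_size; // pointer offset, no second malloc
  lookup->exec_symbol_size = exec_symbol_size;
  for (i = 0; i < tensor_symbol_size + exec_symbol_size; i++)
    buf[i] = -1; // -1: no backward symbol recorded for this slot
  return 0;
}
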
1960
1961
void ccv_nnc_symbolic_graph_backward(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
1962
6.77k
{
1963
6.77k
  int i;
1964
  // f symbols cannot be aliases.
1965
13.5k
  for (i = 0; i < f_symbol_size; i++)
1966
6.79k
    if (f_symbols[i].d >= 0)
1967
6.79k
    {
1968
6.79k
      assert(f_symbols[i].graph == graph); // f symbol has to be in the current graph.
1969
6.79k
      assert(!((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, f_symbols[i].d))->alias_ref);
1970
6.79k
    }
1971
16.3k
  for (i = 0; i < wrt_symbol_size; i++)
1972
9.53k
    if (wrt_symbols[i].d >= 0)
1973
9.52k
    {
1974
9.52k
      assert(wrt_symbols[i].graph == graph);
1975
      // This is not an alias, or what it refers to is not an alias.
1976
9.52k
      assert(!((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, wrt_symbols[i].d))->alias_ref || !((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, wrt_symbols[i].d))->alias_ref - 1))->alias_ref);
1977
9.52k
    }
1978
6.77k
  const int exec_symbol_info_size = graph->exec_symbol_info->rnum;
1979
6.77k
  const int tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
1980
6.77k
  assert(exec_symbol_info_size > 0);
1981
6.77k
  assert(tensor_symbol_info_size > 0);
1982
6.77k
  ccv_nnc_symbolic_graph_backward_prep_t backward_prep = _ccv_nnc_symbolic_graph_backward_prep(graph, sources, source_size, destinations, destination_size);
1983
6.77k
  _ccv_nnc_symbolic_graph_backward_prep_prune_ops(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, sources, source_size, destinations, destination_size);
1984
6.77k
  _ccv_nnc_symbolic_graph_backward_prep_gen(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, 0, sources, source_size, destinations, destination_size);
1985
6.77k
  _ccv_nnc_symbolic_graph_backward_gen(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, graph, graph);
1986
6.77k
  _ccv_nnc_symbolic_graph_backward_prep_free(backward_prep);
1987
6.77k
}
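
For orientation, here is a plausible end-to-end call sequence for ccv_nnc_symbolic_graph_backward, sketched against the public API only. It assumes the convenience macros from ccv_nnc_easy.h (CPU_TENSOR_NHWC, TENSOR_SYMBOL_LIST) and the SYMBOLIC_GRAPH_SOURCES / SYMBOLIC_GRAPH_DESTINATIONS helpers; treat it as an illustration of the calling convention, not a test from this repository.

#include "ccv_nnc.h"
#include "ccv_nnc_easy.h"

int main(void)
{
  ccv_nnc_symbolic_graph_t* const graph = ccv_nnc_symbolic_graph_new();
  const ccv_nnc_tensor_symbol_t x = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "x");
  const ccv_nnc_tensor_symbol_t y = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "y");
  const ccv_nnc_tensor_symbol_t z = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "z");
  // z = x + y, using the same EWSUM command this file emits for gradient sums.
  ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(CCV_NNC_EWSUM_FORWARD, 0, CMD_GENERIC(), 0), TENSOR_SYMBOL_LIST(x, y), TENSOR_SYMBOL_LIST(z), "sum");
  ccv_nnc_graph_exec_symbol_autogen(graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
  // Differentiate f = z with respect to x and y, in place on the same graph.
  ccv_nnc_symbolic_graph_backward(graph, TENSOR_SYMBOL_LIST(z), TENSOR_SYMBOL_LIST(x, y), SYMBOLIC_GRAPH_SOURCES(graph), SYMBOLIC_GRAPH_DESTINATIONS(graph));
  // dz is the seed gradient; dx is df/dx, looked up through the metadata set above.
  const ccv_nnc_tensor_symbol_t dz = ccv_nnc_tensor_symbol_for_backward(graph, z);
  const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(graph, x);
  (void)dz;
  (void)dx;
  ccv_nnc_symbolic_graph_free(graph);
  return 0;
}
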
1988
1989
ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_for_backward(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol)
1990
27.6k
{
1991
27.6k
  assert(symbol.d >= 0);
1992
27.6k
  assert(symbol.d < graph->backward.tensor_symbol_size);
1993
27.6k
  if (graph->backward.tensor_symbol_idx[symbol.d] < 0)
1994
10
    return NO_TENSOR_SYMBOL;
1995
27.6k
  ccv_nnc_tensor_symbol_t tensor = {
1996
27.6k
    .d = graph->backward.tensor_symbol_idx[symbol.d],
1997
27.6k
    .graph = graph,
1998
27.6k
  };
1999
27.6k
  return tensor;
2000
27.6k
}
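
Because this returns NO_TENSOR_SYMBOL whenever no gradient flowed to the queried symbol (its idx slot is still -1), callers should check before using the result. A minimal defensive wrapper, assuming NO_TENSOR_SYMBOL carries a negative .d consistent with the < 0 test inside the function; gradient_or_warn is a hypothetical name.

#include <stdio.h>
#include "ccv_nnc.h"

// Fetch the gradient symbol for x, tolerating symbols no gradient reached.
static ccv_nnc_tensor_symbol_t gradient_or_warn(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t x)
{
  const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(graph, x);
  if (dx.d < 0) // NO_TENSOR_SYMBOL: x was not reachable from any f symbol
    fprintf(stderr, "no gradient recorded for symbol %d\n", x.d);
  return dx;
}
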
2001
2002
ccv_nnc_graph_exec_symbol_t ccv_nnc_graph_exec_symbol_for_backward(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol)
2003
17.7k
{
2004
17.7k
  assert(symbol.d >= 0);
2005
17.7k
  assert(symbol.d < graph->tensor_symbol_info->rnum);
2006
17.7k
  int dd = symbol.d;
2007
  // Check if this is an alias. Use the original if it is.
2008
17.7k
  ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, dd);
2009
17.7k
  if (symbol_info->alias_ref)
2010
2
    dd = symbol_info->alias_ref - 1;
2011
17.7k
  assert(dd >= 0);
2012
17.7k
  assert(dd < graph->backward.exec_symbol_size);
2013
17.7k
  if (graph->backward.exec_symbol_idx[dd] < 0)
2014
0
    return (ccv_nnc_graph_exec_symbol_t){
2015
0
      .graph = 0,
2016
0
      .d = CCV_NNC_NO_GRAPH_EXEC_SYMBOL
2017
0
    };
2018
17.7k
  ccv_nnc_graph_exec_symbol_t exec = {
2019
17.7k
    .d = graph->backward.exec_symbol_idx[dd],
2020
17.7k
    .graph = graph
2021
17.7k
  };
2022
17.7k
  return exec;
2023
17.7k
}
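
A common reason to look up the backward exec symbol is ordering: anything consuming a gradient must run after the node that produces it. The sketch below wires that up with ccv_nnc_graph_exec_symbol_concat, as used throughout this file; the CMD_ADD_FORWARD scaled add standing in for a real optimizer step, the learning rate, and schedule_update_after_gradient are illustrative assumptions.

#include "ccv_nnc.h"
#include "ccv_nnc_easy.h"

// Hypothetical sketch: make a parameter update x <- 1 * x + (-0.01) * dx run
// only after the exec that computed dx.
static void schedule_update_after_gradient(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t x)
{
  const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(graph, x);
  const ccv_nnc_graph_exec_symbol_t dx_exec = ccv_nnc_graph_exec_symbol_for_backward(graph, x);
  const ccv_nnc_graph_exec_symbol_t update = ccv_nnc_graph_exec_symbol_new(graph, CMD_ADD_FORWARD(1, -0.01), TENSOR_SYMBOL_LIST(x, dx), TENSOR_SYMBOL_LIST(x), "update");
  ccv_nnc_graph_exec_symbol_concat(graph, dx_exec, update); // gradient first, update second
}
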