Coverage Report

Created: 2022-08-03 23:52

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/ccv_nnc_symbolic_graph_backward.c

Line |  Count | Source
   1 |        | #include "ccv_nnc.h"
   2 |        | #include "ccv_nnc_easy.h"
   3 |        | #include "ccv_nnc_internal.h"
   4 |        | #include "ccv_internal.h"
   5 |        | #include "_ccv_nnc_symbolic_graph.h"
   6 |        |
   7 |        | // MARK - Level-3.5 API
   8 |        |
   9 |        | typedef struct {
  10 |        |   int f_wrt; // Check if both f_symbols and wrt_symbols flow through this node.
  11 |        |   ccv_array_t* outgoings; // backward traverse nodes.
  12 |        |   uint64_t* input_bitmasks;
  13 |        |   int input_bitmask_size;
  14 |        |   uint64_t* output_bitmasks;
  15 |        |   int output_bitmask_size;
  16 |        | } ccv_nnc_graph_backward_info_t;
  17 |        |
  18 |        | typedef struct {
  19 |        |   int input_size;
  20 |        |   int* inputs;
  21 |        |   int output;
  22 |        |   ccv_array_t* outgoings;
  23 |        |   float value;
  24 |        |   ccv_nnc_graph_exec_symbol_t symbol;
  25 |        | } ccv_nnc_sum_or_set_graph_exec_symbol_t;
  26 |        |
  27 |        | typedef struct {
  28 |        |   int input_size;
  29 |        |   int output_size;
  30 |        |   int* inputs;
  31 |        |   int* outputs;
  32 |        |   ccv_array_t* outgoings;
  33 |        |   ccv_nnc_cmd_t cmd;
  34 |        |   ccv_nnc_graph_exec_symbol_t symbol;
  35 |        | } ccv_nnc_autograd_graph_exec_symbol_t;
  36 |        |
  37 |        | typedef struct {
  38 |        |   int d; // The pointer to the forward level object.
  39 |        |   int alias_ref; // The alias ref to itself (autograd_tensor_symbols array).
  40 |        |   int flags; // Flags for this symbol.
  41 |        |   ccv_nnc_tensor_symbol_t symbol;
  42 |        | } ccv_nnc_autograd_tensor_symbol_t;
  43 |        |
  44 |        | typedef struct {
  45 |        |   int d; // The tensor symbol ref.
  46 |        |   int x; // The exec symbol ref.
  47 |        |   ccv_array_t* exec_registry; // Additional exec symbol refs, similar to x, only useful for aliasing.
  48 |        |   ccv_array_t* alias_registry; // ints pointing to all the aliases (if this is not an alias). The alias is the object in autograd_tensor_symbols; you need another level of indirection to get the actual forward level alias.
  49 |        | } ccv_nnc_tensor_ref_t;
  50 |        |
  51 |        | typedef struct {
  52 |        |   int c; // The start non-accumulated version.
  53 |        |   ccv_array_t* ref_version; // tensor refs pointing to the reverse tensor symbol.
  54 |        | } ccv_nnc_autograd_tensor_version_t;
  55 |        |
  56 |        | typedef struct {
  57 |        |   int d;
  58 |        |   int alias_ref;
  59 |        | } ccv_nnc_sum_variable_t;
  60 |        |
  61 |        | // This method tries to figure out if a set of aliases can cover the whole tensor dim.
  62 |        | // This is not a precise implementation though. The requirement is to answer this question
  63 |        | // with a given memory constraint; therefore, it only allows up to 65536 different tensor locations.
  64 |        | // If you have more than that, it will assume that it doesn't have fully assigned aliases,
  65 |        | // and will return 0.
  66 |        |
  67 |        | // Return 1 if inserted successfully.
  68 |        | static inline int _ccv_nnc_try_mix(int* const md, const int ins, const int c)
  69 |     43 | {
  70 |     43 |   if (!c)
  71 |     25 |   {
  72 |     25 |     md[0] = ins;
  73 |     25 |     return 1;
  74 |     25 |   }
  75 |     18 |   int ll = 0, uu = c - 1;
  76 |     18 |   int mm;
  77 |     20 |   do {
  78 |     20 |     mm = ll + ((uu - ll) >> 1);
  79 |     20 |     if (ins == md[mm])
  80 |     16 |       return 0;
  81 |      4 |     else if (ins < md[mm])
  82 |      2 |       uu = mm - 1;
  83 |      2 |     else if (ins > md[mm])
  84 |      2 |       ll = mm + 1;
  85 |     20 |   } while (ll <= uu);
  86 |      2 |   if (ll < c)
  87 |      2 |     memmove(md + ll + 1, md + ll, sizeof(int) * (c - ll));
  88 |      2 |   md[ll] = ins;
  89 |      2 |   return 1;
  90 |     18 | }
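
Note on _ccv_nnc_try_mix (lines 68-90): it keeps md[0..c) sorted and unique, binary-searching for ins and returning 1 only when a new value was inserted, so the caller can grow its length by the return value. A minimal standalone sketch of the same routine (hypothetical driver, not library code):

    #include <stdio.h>
    #include <string.h>

    /* Keep md[0..c) sorted and unique; insert ins if absent.
     * Returns 1 on insertion, 0 if ins was already present. */
    static int try_mix(int* const md, const int ins, const int c)
    {
        if (!c)
        {
            md[0] = ins;
            return 1;
        }
        int ll = 0, uu = c - 1, mm;
        do {
            mm = ll + ((uu - ll) >> 1);
            if (ins == md[mm])
                return 0;
            else if (ins < md[mm])
                uu = mm - 1;
            else
                ll = mm + 1;
        } while (ll <= uu);
        if (ll < c) /* shift the tail to keep the array sorted */
            memmove(md + ll + 1, md + ll, sizeof(int) * (c - ll));
        md[ll] = ins;
        return 1;
    }

    int main(void)
    {
        int md[8], c = 0, i;
        const int xs[] = { 5, 2, 5, 9, 2, 7 };
        for (i = 0; i < 6; i++)
            c += try_mix(md, xs[i], c); /* c grows only on unique inserts */
        for (i = 0; i < c; i++)
            printf("%d ", md[i]); /* prints: 2 5 7 9 */
        printf("\n");
        return 0;
    }

The `len += _ccv_nnc_try_mix(...)` idiom is what builds the sorted boundary list per dimension on lines 222 and 226 below.
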
  91 |        |
  92 |        | static inline int _ccv_nnc_mix_idx(const int* const md, const int ins, const int c)
  93 |     26 | {
  94 |     26 |   if (c <= 1)
  95 |     18 |     return 0;
  96 |      8 |   int ll = 0, uu = c - 1;
  97 |      8 |   int mm;
  98 |     14 |   do {
  99 |     14 |     mm = ll + ((uu - ll) >> 1);
 100 |     14 |     if (ins == md[mm])
 101 |      8 |       return mm;
 102 |      6 |     else if (ins < md[mm])
 103 |      0 |       uu = mm - 1;
 104 |      6 |     else if (ins > md[mm])
 105 |      6 |       ll = mm + 1;
 106 |     14 |   } while (ll <= uu);
 107 |      0 |   assert(0 && "Shouldn't reach here");
 108 |      0 |   return -1;
 109 |      0 | }
 110 |        |
 111 |        | static inline void _ccv_nnc_try_set_pix_0(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset)
 112 |      6 | {
 113 |      6 |   const int s = (ofs[0] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[0], ofs[0], cube_dim[0]) + 1;
 114 |      6 |   const int d = ((ofs[0] + dim[0] == tensor_dim[0]) ? cube_dim[0] : _ccv_nnc_mix_idx(scmd[0], ofs[0] + ccv_max(1, dim[0]), cube_dim[0])) + 1;
 115 |      6 |   assert(s >= 0 && d > s);
 116 |      6 |   int i;
 117 |     12 |   for (i = s; i < d; i++)
 118 |        |     // Fill this pix. I can make this faster by loop through full ones (divided by 8), but too lazy.
 119 |      6 |     cube[(offset + i) >> 5] |= (1u << ((offset + i) & 0x1f));
 120 |      6 | }
 121 |        |
 122 |        | static inline void _ccv_nnc_try_set_pix_1(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset)
 123 |     14 | {
 124 |     14 |   const int s0 = (ofs[0] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[0], ofs[0], cube_dim[0]) + 1;
 125 |     14 |   const int d0 = ((ofs[0] + dim[0] == tensor_dim[0]) ? cube_dim[0] : _ccv_nnc_mix_idx(scmd[0], ofs[0] + ccv_max(1, dim[0]), cube_dim[0])) + 1;
 126 |     14 |   assert(s0 >= 0 && d0 > s0);
 127 |     14 |   const int s1 = (ofs[1] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[1], ofs[1], cube_dim[1]) + 1;
 128 |     14 |   const int d1 = ((ofs[1] + dim[1] == tensor_dim[1]) ? cube_dim[1] : _ccv_nnc_mix_idx(scmd[1], ofs[1] + ccv_max(1, dim[1]), cube_dim[1])) + 1;
 129 |     14 |   assert(s1 >= 0 && d1 > s1);
 130 |     14 |   int i, j;
 131 |     14 |   const int step1 = cube_step[1];
 132 |     14 |   if (step1 == d0 - s0)
 133 |     10 |   {
 134 |        |     // Faster one, we can simply loop through.
 135 |     22 |     for (i = s1 * step1; i < d1 * step1; i++)
 136 |     12 |       cube[(offset + i) >> 5] |= (1u << ((offset + i) & 0x1f));
 137 |     10 |   } else {
 138 |      4 |     offset += s1 * step1;
 139 |        |     // There are gaps, slow one.
 140 |      8 |     for (i = s1; i < d1; i++, offset += step1)
 141 |      8 |       for (j = s0; j < d0; j++)
 142 |      4 |         cube[(offset + j) >> 5] |= (1u << ((offset + j) & 0x1f));
 143 |      4 |   }
 144 |     14 | }
 145 |        |
 146 |        | static inline void _ccv_nnc_try_set_pix(const int* const ofs, const int* const dim, const int* const tensor_dim, int* const* const scmd, const int* const cube_dim, const int* const cube_step, uint32_t* const cube, int offset, const int dim_idx)
 147 |     24 | {
 148 |     24 |   switch (dim_idx)
 149 |     24 |   {
 150 |     14 |     case 1:
 151 |     14 |       _ccv_nnc_try_set_pix_1(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset);
 152 |     14 |       return;
 153 |      6 |     case 0:
 154 |      6 |       _ccv_nnc_try_set_pix_0(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset);
 155 |      6 |       return;
 156 |     24 |   }
 157 |      4 |   int i;
 158 |      4 |   const int s = (ofs[dim_idx] == 0) ? 0 : _ccv_nnc_mix_idx(scmd[dim_idx], ofs[dim_idx], cube_dim[dim_idx]) + 1;
 159 |      4 |   const int d = ((ofs[dim_idx] + dim[dim_idx] == tensor_dim[dim_idx]) ? cube_dim[dim_idx] : _ccv_nnc_mix_idx(scmd[dim_idx], ofs[dim_idx] + ccv_max(1, dim[dim_idx]), cube_dim[dim_idx])) + 1;
 160 |      4 |   assert(s >= 0 && d > s);
 161 |      8 |   for (i = s; i < d; i++)
 162 |      4 |     _ccv_nnc_try_set_pix(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, offset + i * cube_step[dim_idx], dim_idx - 1);
 163 |      4 | }
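
Note on the cube bitset: each compacted coordinate is one bit, and a region is marked with cube[(offset + i) >> 5] |= 1u << ((offset + i) & 0x1f), i.e. plain word/bit splitting of a flat index. A tiny sketch of the fill-then-check cycle for a 1-D case (hypothetical values, not library code):

    #include <assert.h>
    #include <stdint.h>

    int main(void)
    {
        /* A 10-wide tensor covered by aliases [0,5) and [5,10): the only
         * interior boundary is offset 5, so cube_dim[0] = 1 and the
         * compacted cube has cube_size = 2 cells. */
        const int cube_size = 2;
        uint32_t cube[1] = { 0 };
        int cell;
        for (cell = 0; cell < cube_size; cell++) /* each alias fills one cell */
            cube[cell >> 5] |= (1u << (cell & 0x1f));
        /* Same tail check as lines 264-272 below: all cube_size low bits
         * set means the aliases fully cover the tensor. */
        const uint32_t r = (cube_size & 0x1f) ? (1u << cube_size) - 1 : 0xffffffffu;
        assert((cube[0] & r) == r);
        return 0;
    }
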
 164 |        |
 165 |        | static int _ccv_nnc_tensor_ref_fully_assigned_with_aliases(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info)
 166 |  1.15k | {
 167 |        |   // Only work with tensor_ref of aliases.
 168 |  1.15k |   assert(tensor_ref->alias_registry);
 169 |  1.15k |   const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
 170 |  1.15k |   assert(tensor_symbol_info[autograd->d].alias_ref == 0);
 171 |  1.15k |   const int* tensor_dim = tensor_symbol_info[autograd->d].info.dim;
 172 |  1.15k |   const int tensor_count = ccv_nnc_dimension_count(tensor_dim);
 173 |  1.15k |   int i, j;
 174 |  1.18k |   for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
 175 |  1.16k |   {
 176 |  1.16k |     const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
 177 |  1.16k |     assert(d < autograd_tensor_symbols->rnum);
 178 |  1.16k |     const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 179 |  1.16k |     assert(tensor_symbol_info[autograd->d].alias_ref);
 180 |  1.16k |     const int* inc = tensor_symbol_info[autograd->d].inc;
 181 |        |     // If this is just reshaped (i.e., dimension is the same, and inc covers the whole), we have fully assigned.
 182 |  1.16k |     if (memcmp(inc, tensor_symbol_info[autograd->d].info.dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0 &&
 183 |  1.16k |       ccv_nnc_dimension_count(inc) == tensor_count)
 184 |  1.12k |       return 1;
 185 |        |     // Otherwise if inc doesn't match original dim, it is not covered.
 186 |     37 |     if (memcmp(inc, tensor_dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0)
 187 |      0 |       return 0;
 188 |     37 |   }
 189 |        |   /* We need a solid cube (potentially hyper dimensional) to compute if there are overlaps.
 190 |        |    * To make this cube as small as possible, we need to map the actual tensor dimension
 191 |        |    * (therefore, we don't actually allocate the whole tensor to compute overlaps) to a smaller
 192 |        |    * cube given the ofs and dim size of its aliases.
 193 |        |    *
 194 |        |    * The following code generates the dimension mapping (using scratch space) with binary search + insertion,
 195 |        |    * and then we fill the cube with a given tensor alias's dimensional information (ofs, dim).
 196 |        |    * Afterwards, we simply need to check if the cube is totally filled up to know if this tensor
 197 |        |    * is fully assigned with its aliases (if that is the case, we can skip zeroing for this tensor).
 198 |        |    *
 199 |        |    * There are several restrictions though to make this faster: 1). I cannot handle any cube whose side
 200 |        |    * lengths combined are larger than 1023 (scm only has 1024 scratch entries). 2). I cannot handle any cube
 201 |        |    * whose total volume is larger than 2048 * 8 (I only allocate 2K on stack for this).
 202 |        |    * */
 203 |     24 |   int scm[1024]; // Having 1024 int scratch space for mapping dimensions. (Or sparse coordinate mapping).
 204 |     24 |   int cube_dim[CCV_NNC_MAX_DIM_ALLOC] = {}; // Mapping dimension size.
 205 |     24 |   int cube_size = 1;
 206 |     24 |   int* scmptr = scm;
 207 |     40 |   for (i = 0; i < CCV_NNC_MAX_DIM_ALLOC && tensor_dim[i]; i++)
 208 |     32 |   {
 209 |     32 |     int head = 0, tail = 0; // Note that we touched both the head and tail (otherwise this dimension is not fully covered).
 210 |     32 |     int len = 0;
 211 |     89 |     for (j = 0; j < tensor_ref->alias_registry->rnum; j++)
 212 |     57 |     {
 213 |     57 |       const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, j);
 214 |     57 |       assert(d < autograd_tensor_symbols->rnum);
 215 |     57 |       const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 216 |     57 |       assert(tensor_symbol_info[autograd->d].alias_ref);
 217 |     57 |       const int* ofs = tensor_symbol_info[autograd->d].ofs;
 218 |     57 |       const int* dim = tensor_symbol_info[autograd->d].info.dim;
 219 |     57 |       head = head || (ofs[i] == 0);
 220 |     57 |       tail = tail || (ofs[i] + ccv_max(1, dim[i]) == tensor_dim[i]);
 221 |     57 |       if (ofs[i] != 0)
 222 |     14 |         len += _ccv_nnc_try_mix(scmptr, ofs[i], len);
 223 |     57 |       if (scmptr - scm + len >= 1024) // Cannot handle that much, abort.
 224 |      0 |         return 0;
 225 |     57 |       if (ofs[i] + ccv_max(1, dim[i]) < tensor_dim[i])
 226 |     29 |         len += _ccv_nnc_try_mix(scmptr, ofs[i] + ccv_max(1, dim[i]), len);
 227 |     57 |       if (scmptr - scm + len >= 1024) // Cannot handle that much, abort.
 228 |      0 |         return 0;
 229 |     57 |     }
 230 |     32 |     if (!head || !tail)
 231 |     16 |       return 0;
 232 |     16 |     cube_size *= (len + 1);
 233 |     16 |     cube_dim[i] = len;
 234 |     16 |     scmptr += len; // Moving to next level.
 235 |     16 |   }
 236 |        |   // The cube map is too large, cannot do the computation, assume it is not fully assigned.
 237 |      8 |   if (cube_size > 2048 * 8)
 238 |      0 |     return 0;
 239 |        |   // binary map to see if it fills up.
 240 |      8 |   uint32_t cube[(cube_size + 31) >> 5];
 241 |      8 |   memset(cube, 0, sizeof(uint32_t) * ((cube_size + 31) >> 5));
 242 |      8 |   int* scmd[CCV_NNC_MAX_DIM_ALLOC] = {}; // Sparse coordinate map at dimension x.
 243 |      8 |   int cube_step[CCV_NNC_MAX_DIM_ALLOC] = {};
 244 |     22 |   for (i = 0; i < CCV_NNC_MAX_DIM_ALLOC && tensor_dim[i]; i++)
 245 |     14 |   {
 246 |     14 |     cube_step[i] = (i > 0) ? cube_step[i - 1] * (cube_dim[i - 1] + 1) : 1;
 247 |     14 |     scmd[i] = (i > 0) ? scmd[i - 1] + cube_dim[i - 1] : scm;
 248 |     14 |   }
 249 |      8 |   const int max_dim = i;
 250 |     28 |   for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
 251 |     20 |   {
 252 |     20 |     const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
 253 |     20 |     assert(d < autograd_tensor_symbols->rnum);
 254 |     20 |     const ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 255 |     20 |     assert(tensor_symbol_info[autograd->d].alias_ref);
 256 |     20 |     const int* ofs = tensor_symbol_info[autograd->d].ofs;
 257 |     20 |     const int* dim = tensor_symbol_info[autograd->d].info.dim;
 258 |     20 |     _ccv_nnc_try_set_pix(ofs, dim, tensor_dim, scmd, cube_dim, cube_step, cube, 0, max_dim - 1);
 259 |     20 |   }
 260 |        |   // Compare to see now if the binary map filled up. If it filled up, we know it is fully assigned.
 261 |      8 |   for (i = 0; i < (cube_size >> 5); i++)
 262 |      0 |     if (cube[i] < 0xffffffff)
 263 |      0 |       return 0;
 264 |      8 |   if ((cube_size & 0x1f) > 0)
 265 |      8 |   {
 266 |        |     // Fetch the rest.
 267 |      8 |     uint32_t r = 0;
 268 |     28 |     for (i = 0; i < (cube_size & 0x1f); i++)
 269 |     20 |       r |= (1u << i);
 270 |      8 |     assert(cube[((cube_size + 31) >> 5) - 1] <= r);
 271 |      8 |     if (cube[((cube_size + 31) >> 5) - 1] < r)
 272 |      0 |       return 0;
 273 |      8 |   }
 274 |      8 |   return 1;
 275 |      8 | }
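
Note the two early exits in the loop on lines 174-188: with a hypothetical tensor_dim of {2, 4}, an alias whose inc equals its own dim {2, 4} satisfies ccv_nnc_dimension_count(inc) == tensor_count (8), i.e. it is a pure reshape that alone covers the tensor, and the function returns 1 on the spot; an alias with inc {4, 4} (a window into a larger layout) fails the memcmp against tensor_dim and the function returns 0 without ever building the cube.
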
 276 |        |
 277 |        | static int _ccv_nnc_tensor_ref_version_find_init(const ccv_nnc_autograd_tensor_version_t* const tensor_ver)
 278 |      5 | {
 279 |      5 |   int i;
 280 |     10 |   for (i = 0; i < tensor_ver->ref_version->rnum; i++)
 281 |      7 |     if (((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i))->x < 0)
 282 |      2 |       return i;
 283 |      3 |   return -1;
 284 |      5 | }
 285 |        |
 286 |        | static void _ccv_nnc_graph_sum_autograd_tensor_versions(const int idx, const int d, const int exec_symbol_info_size, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, ccv_nnc_autograd_tensor_version_t* const tensor_ver, ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs, ccv_array_t* const autograd_tensor_symbols, ccv_array_t* const sum_or_set_execs)
 287 |  4.26k | {
 288 |  4.26k |   int i, j;
 289 |  4.26k |   assert(tensor_ver->c < tensor_ver->ref_version->rnum);
 290 |  4.26k |   const int input_size = tensor_ver->ref_version->rnum - tensor_ver->c;
 291 |  4.26k |   int* inputs = (int*)ccmalloc(sizeof(int) * input_size);
 292 |  12.8k |   for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
 293 |  8.55k |     inputs[i] = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i))->d;
 294 |  4.26k |   const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
 295 |  4.26k |     .d = d
 296 |  4.26k |   };
 297 |  4.26k |   ccv_array_push(autograd_tensor_symbols, &tensor_sym);
 298 |  4.26k |   ccv_nnc_sum_or_set_graph_exec_symbol_t sum_exec = {
 299 |  4.26k |     .input_size = input_size,
 300 |  4.26k |     .inputs = inputs,
 301 |  4.26k |     .output = autograd_tensor_symbols->rnum - 1
 302 |  4.26k |   };
 303 |  4.26k |   if (idx >= 0)
 304 |  4.24k |   {
 305 |  4.24k |     sum_exec.outgoings = ccv_array_new(sizeof(int), 1, 0);
 306 |  4.24k |     ccv_array_push(sum_exec.outgoings, &idx);
 307 |  4.24k |   }
 308 |  4.26k |   ccv_array_push(sum_or_set_execs, &sum_exec);
 309 |  4.26k |   const int outgoing = exec_symbol_info_size + sum_or_set_execs->rnum - 1;
 310 |  12.8k |   for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
 311 |  8.55k |   {
 312 |  8.55k |     const ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i);
 313 |  8.55k |     const int x = tensor_ref->x;
 314 |  8.55k |     if (x < 0) /* This is an initialization tensor, it has to occur before the execution anyway. */
 315 |      1 |     {
 316 |        |       // No alias.
 317 |      1 |       assert(!tensor_ref->alias_registry);
 318 |        |       // No associated additional execs.
 319 |      1 |       assert(!tensor_ref->exec_registry);
 320 |      1 |       continue;
 321 |      1 |     }
 322 |  8.55k |     if (x < exec_symbol_info_size)
 323 |  8.55k |     {
 324 |  8.55k |       ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
 325 |  8.55k |       if (!back_exec->outgoings)
 326 |     31 |         back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
 327 |  8.55k |       ccv_array_replace_unique_int(back_exec->outgoings, idx, outgoing);
 328 |  8.55k |     } else {
 329 |        |       // This tensor_ref is generated by the sum operation.
 330 |      0 |       ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size);
 331 |      0 |       ccv_array_replace_unique_int(sum_or_set->outgoings, idx, outgoing);
 332 |      0 |     }
 333 |        |     // If this tensor has an associated alias, we need to init it to zeros when it is allocated (we only need to set a flag here);
 334 |        |     // it is handled at compilation phase.
 335 |  8.55k |     if (tensor_ref->alias_registry &&
 336 |        |       // Loop over to see if this tensor is fully occupied to avoid extra zero step.
 337 |  8.55k |       !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info))
 338 |      8 |     {
 339 |      8 |       ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
 340 |        |       // By having alias_registry, what this symbol represents must not be an alias.
 341 |      8 |       assert(tensor_sym->alias_ref == 0);
 342 |      8 |       tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
 343 |      8 |     }
 344 |  8.55k |     if (tensor_ref->exec_registry)
 345 |      4 |       for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
 346 |      2 |       {
 347 |      2 |         const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
 348 |      2 |         assert(x >= 0);
 349 |        |         // The exec_registry can only be generated by alias registry, therefore, it cannot reference a sum operation.
 350 |      2 |         assert(x < exec_symbol_info_size);
 351 |      2 |         ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
 352 |      2 |         if (!back_exec->outgoings)
 353 |      1 |           back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
 354 |      2 |         ccv_array_replace_unique_int(back_exec->outgoings, idx, outgoing);
 355 |      2 |       }
 356 |  8.55k |   }
 357 |  4.26k |   const ccv_nnc_tensor_ref_t tensor_ref = {
 358 |  4.26k |     .d = autograd_tensor_symbols->rnum - 1,
 359 |  4.26k |     .x = outgoing
 360 |  4.26k |   };
 361 |  4.26k |   ccv_array_push(tensor_ver->ref_version, &tensor_ref);
 362 |        |   /* Move the c pointer up to the latest summed result. */
 363 |  4.26k |   tensor_ver->c = tensor_ver->ref_version->rnum - 1;
 364 |  4.26k | }
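
Note on the version bookkeeping above: every write to a tensor appends a ccv_nnc_tensor_ref_t to ref_version, and summation consumes the pending range [c, rnum), appends the summed result, then advances c past it. A toy model of just that arithmetic (hypothetical, not library code):

    #include <assert.h>
    #include <stdio.h>

    /* Toy stand-in for ccv_nnc_autograd_tensor_version_t bookkeeping:
     * versions [c, rnum) are pending; summation consumes them, appends
     * the summed version, and moves c so only the new one is visible. */
    typedef struct {
        int rnum; /* number of versions recorded so far */
        int c;    /* first non-accumulated version */
    } toy_version_t;

    static int toy_sum_versions(toy_version_t* const ver)
    {
        const int consumed = ver->rnum - ver->c; /* inputs of the sum node */
        ver->rnum++;            /* the summed result is appended as a new version */
        ver->c = ver->rnum - 1; /* mirrors tensor_ver->c = rnum - 1 on line 363 */
        return consumed;
    }

    int main(void)
    {
        toy_version_t ver = { .rnum = 3, .c = 0 }; /* three writes, none accumulated */
        const int consumed = toy_sum_versions(&ver);
        assert(consumed == 3 && ver.c == 3 && ver.rnum == 4);
        printf("summed %d versions; c = %d of %d\n", consumed, ver.c, ver.rnum);
        return 0;
    }
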
 365 |        |
 366 |        | static int _ccv_nnc_tensor_ref_version_involve_alias(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
 367 |     80 | {
 368 |     80 |   assert(alias->alias_ref > 0);
 369 |        |   // No alias_registry, must conflict (owns the whole band).
 370 |     80 |   if (!tensor_ref->alias_registry)
 371 |     28 |     return 1;
 372 |     52 |   int i;
 373 |     73 |   for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
 374 |     64 |   {
 375 |     64 |     const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
 376 |     64 |     assert(d < autograd_tensor_symbols->rnum);
 377 |     64 |     ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 378 |     64 |     if (ccv_nnc_over_tensor_symbol_aliases(tensor_symbol_info + autograd->d, alias))
 379 |     43 |       return 1;
 380 |     64 |   }
 381 |        |   // None of the aliases referenced by this ref_version overlaps with the provided one, thus, there is no conflict at all.
 382 |      9 |   return 0;
 383 |     52 | }
 384 |        |
 385 |        | static int _ccv_nnc_tensor_ref_version_find_alias(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
 386 |     27 | {
 387 |     27 |   assert(alias->alias_ref > 0);
 388 |        |   // No alias_registry, thus, cannot find the exact matched alias.
 389 |     27 |   if (!tensor_ref->alias_registry)
 390 |      8 |     return -1;
 391 |     19 |   int i;
 392 |     34 |   for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
 393 |     26 |   {
 394 |     26 |     const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
 395 |     26 |     assert(d < autograd_tensor_symbols->rnum);
 396 |     26 |     ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 397 |        |     // This must reference an alias.
 398 |     26 |     assert(tensor_symbol_info[autograd->d].alias_ref);
 399 |     26 |     const int* inc = tensor_symbol_info[autograd->d].inc;
 400 |     26 |     const int* ofs = tensor_symbol_info[autograd->d].ofs;
 401 |     26 |     const int* dim = tensor_symbol_info[autograd->d].info.dim;
 402 |        |     // If everything matches, this is the required alias.
 403 |     26 |     if (memcmp(inc, alias->inc, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0 &&
 404 |     26 |       memcmp(ofs, alias->ofs, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0 &&
 405 |     26 |       memcmp(dim, alias->info.dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) == 0)
 406 |     11 |       return d;
 407 |     26 |   }
 408 |      8 |   return -1;
 409 |     19 | }
 410 |        |
 411 |        | static int _ccv_nnc_tensor_ref_version_has_this_alias_exclusively(const ccv_nnc_tensor_ref_t* const tensor_ref, const ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias)
 412 |      4 | {
 413 |      4 |   assert(alias->alias_ref > 0);
 414 |        |   // No alias_registry, thus, cannot find the exact matched alias.
 415 |      4 |   if (!tensor_ref->alias_registry)
 416 |      0 |     return 0;
 417 |      4 |   int i;
 418 |      8 |   for (i = 0; i < tensor_ref->alias_registry->rnum; i++)
 419 |      5 |   {
 420 |      5 |     const int d = *(int*)ccv_array_get(tensor_ref->alias_registry, i);
 421 |      5 |     assert(d < autograd_tensor_symbols->rnum);
 422 |      5 |     ccv_nnc_autograd_tensor_symbol_t* autograd = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, d);
 423 |        |     // This must reference an alias.
 424 |      5 |     assert(tensor_symbol_info[autograd->d].alias_ref);
 425 |      5 |     const int* inc = tensor_symbol_info[autograd->d].inc;
 426 |      5 |     const int* ofs = tensor_symbol_info[autograd->d].ofs;
 427 |      5 |     const int* dim = tensor_symbol_info[autograd->d].info.dim;
 428 |      5 |     if (memcmp(inc, alias->inc, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0 ||
 429 |      5 |       memcmp(ofs, alias->ofs, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0 ||
 430 |      5 |       memcmp(dim, alias->info.dim, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0)
 431 |      1 |       return 0;
 432 |      5 |   }
 433 |        |   // If everything matches for every alias in the registry, we can use any of the aliases directly.
 434 |      3 |   return 1;
 435 |      4 | }
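
Note: in both helpers above an alias is identified purely by its (inc, ofs, dim) triple; two views match exactly when all three arrays compare equal over CCV_NNC_MAX_DIM_ALLOC entries. A tiny sketch of that test with hypothetical 2-D views (not library code):

    #include <stdio.h>
    #include <string.h>

    #define DIMS 3 /* stand-in for CCV_NNC_MAX_DIM_ALLOC */

    typedef struct { int inc[DIMS], ofs[DIMS], dim[DIMS]; } view_t;

    /* Same comparison as lines 403-405: equal inc, ofs and dim. */
    static int same_view(const view_t* a, const view_t* b)
    {
        return memcmp(a->inc, b->inc, sizeof(a->inc)) == 0 &&
            memcmp(a->ofs, b->ofs, sizeof(a->ofs)) == 0 &&
            memcmp(a->dim, b->dim, sizeof(a->dim)) == 0;
    }

    int main(void)
    {
        /* Two 2x4 windows into the same 4x4 layout, differing only in ofs. */
        const view_t top = { { 4, 4, 0 }, { 0, 0, 0 }, { 2, 4, 0 } };
        const view_t bottom = { { 4, 4, 0 }, { 2, 0, 0 }, { 2, 4, 0 } };
        printf("%d %d\n", same_view(&top, &top), same_view(&top, &bottom)); /* 1 0 */
        return 0;
    }
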
 436 |        |
 437 |        | static int _ccv_nnc_graph_sum_autograd_tensor_versions_alias(const int idx, const int d, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const int exec_symbol_info_size, const ccv_nnc_tensor_symbol_info_t* const alias, ccv_nnc_autograd_tensor_version_t* const tensor_ver, ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs, ccv_array_t* const autograd_tensor_symbols, ccv_array_t* const sum_or_set_execs)
 438 |     19 | {
 439 |     19 |   assert(tensor_ver->c < tensor_ver->ref_version->rnum);
 440 |     19 |   int i, j = 0;
 441 |     19 |   struct {
 442 |     19 |     int k;
 443 |     19 |     int i;
 444 |     19 |   } kd[tensor_ver->ref_version->rnum - tensor_ver->c];
 445 |     46 |   for (i = tensor_ver->c; i < tensor_ver->ref_version->rnum; i++)
 446 |     27 |   {
 447 |     27 |     ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i);
 448 |     27 |     const int k = _ccv_nnc_tensor_ref_version_find_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias);
 449 |     27 |     if (k >= 0)
 450 |     11 |       kd[j++] = (typeof(kd[0])){
 451 |     11 |         .k = k, .i = i
 452 |     11 |       };
 453 |     16 |     else if (_ccv_nnc_tensor_ref_version_involve_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias))
 454 |     16 |       kd[j++] = (typeof(kd[0])) {
 455 |     16 |         .k = -1, .i = i // It has a dependency to the original tensor (non-alias) now, label this with the highest bit.
 456 |     16 |       };
 457 |     27 |   }
 458 |        |   // Can only find one. This is the easy case, we can simply return that symbol (or its alias).
 459 |     19 |   if (j == 1)
 460 |     14 |   {
 461 |     14 |     if (kd[0].k >= 0)
 462 |      4 |       return kd[0].k; // Only can find one alias, that is the one.
 463 |        |     // Otherwise, need to create a new alias.
 464 |     10 |     ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[0].i);
 465 |     10 |     ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
 466 |        |     // Since we create a new alias, we need to set the referenced one to be allocated with 0s.
 467 |     10 |     if (ref->alias_ref) // If this is an alias, it has to be zero initialized.
 468 |      0 |     {
 469 |      0 |       ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, ref->alias_ref - 1);
 470 |      0 |       assert(ref->alias_ref == 0); // This is original.
 471 |      0 |       ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
 472 |     10 |     } else if (tensor_ref->alias_registry && // Otherwise, to see if this symbol is fully occupied.
 473 |        |         // Loop over to see if this tensor is fully occupied to avoid extra zero step.
 474 |     10 |         !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info)) {
 475 |      1 |       ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
 476 |      1 |     }
 477 |     10 |     ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
 478 |     10 |       .d = d,
 479 |     10 |       .alias_ref = tensor_ref->d + 1
 480 |     10 |     };
 481 |     10 |     ccv_array_push(autograd_tensor_symbols, &tensor_sym);
 482 |     10 |     const int ad = autograd_tensor_symbols->rnum - 1;
 483 |     10 |     if (tensor_ref->alias_registry) // Only push this when it has an alias registry (otherwise it already conflicts with everyone).
 484 |      3 |       ccv_array_push(tensor_ref->alias_registry, &ad);
 485 |        |     // The newly inserted tensor symbol.
 486 |     10 |     return ad;
 487 |     10 |   }
 488 |        |   // Otherwise, we need to create the sum operation out of these.
 489 |      5 |   const int input_size = j;
 490 |      5 |   int has_this_alias_exclusively = 1;
 491 |      5 |   int* inputs = input_size > 0 ? (int*)ccmalloc(sizeof(int) * input_size) : 0;
 492 |     18 |   for (i = 0; i < input_size; i++)
 493 |     13 |   {
 494 |     13 |     ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
 495 |        |     // Can take a fast path if every ref involved has the same alias; our sum operation can be faster (using the alias directly).
 496 |     13 |     if (has_this_alias_exclusively && kd[i].k >= 0 && _ccv_nnc_tensor_ref_version_has_this_alias_exclusively(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, alias))
 497 |      3 |       inputs[i] = *(int*)ccv_array_get(tensor_ref->alias_registry, 0); // Assigning the alias.
 498 |     10 |     else {
 499 |     10 |       if (has_this_alias_exclusively)
 500 |      4 |       {
 501 |      4 |         has_this_alias_exclusively = 0;
 502 |      4 |         for (j = 0; j < i; j++)
 503 |      0 |           inputs[j] = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[j].i))->d;
 504 |      4 |       }
 505 |     10 |       inputs[i] = tensor_ref->d;
 506 |     10 |     }
 507 |     13 |   }
 508 |      5 |   ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
 509 |      5 |     .d = alias->alias_ref - 1
 510 |      5 |   };
 511 |      5 |   ccv_array_push(autograd_tensor_symbols, &tensor_sym);
 512 |      5 |   const int tensor_ref_d = autograd_tensor_symbols->rnum - 1;
 513 |      5 |   tensor_sym.d = d;
 514 |      5 |   tensor_sym.alias_ref = tensor_ref_d + 1;
 515 |      5 |   ccv_array_push(autograd_tensor_symbols, &tensor_sym);
 516 |      5 |   const int ad = autograd_tensor_symbols->rnum - 1;
 517 |      5 |   ccv_nnc_sum_or_set_graph_exec_symbol_t sum_exec = {
 518 |      5 |     .input_size = input_size,
 519 |      5 |     .inputs = inputs,
 520 |      5 |     .output = has_this_alias_exclusively ? ad : tensor_ref_d /* If it has this alias exclusively, the output should be the alias as well. Otherwise the output is the real tensor. */
 521 |      5 |   };
 522 |      5 |   if (idx >= 0)
 523 |      5 |   {
 524 |      5 |     sum_exec.outgoings = ccv_array_new(sizeof(int), 1, 0);
 525 |      5 |     ccv_array_push(sum_exec.outgoings, &idx);
 526 |      5 |   }
 527 |      5 |   ccv_array_push(sum_or_set_execs, &sum_exec);
 528 |      5 |   const int outgoing = exec_symbol_info_size + sum_or_set_execs->rnum - 1;
 529 |      5 |   int no_alias_registry = 0;
 530 |     18 |   for (i = 0; i < input_size; i++)
 531 |     13 |   {
 532 |     13 |     ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
 533 |     13 |     if (!has_this_alias_exclusively)
 534 |     10 |     {
 535 |        |       // If the sum operation is not operating on one alias, I need to zero this tensor out when it is first
 536 |        |       // allocated (see discussions around the flags I use).
 537 |     10 |       ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
 538 |     10 |       if (tensor_sym->alias_ref)
 539 |      0 |       {
 540 |        |         // Find the original tensor_sym and set its flags (I prefer to set flags on its original).
 541 |      0 |         ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_sym->alias_ref - 1);
 542 |      0 |         assert(ref->alias_ref == 0); // This is original.
 543 |      0 |         ref->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
 544 |     10 |       } else if (tensor_ref->alias_registry && // Otherwise, to see if this symbol is fully occupied.
 545 |        |           // Loop over to see if this tensor is fully occupied to avoid extra zero step.
 546 |     10 |           !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info)) {
 547 |      6 |         tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
 548 |      6 |       }
 549 |     10 |     }
 550 |        |     // Check to see if any of these tensors doesn't have an alias.
 551 |     13 |     no_alias_registry |= (!tensor_ref->alias_registry);
 552 |     13 |     const int x = tensor_ref->x;
 553 |     13 |     assert(x >= 0); /* Otherwise, this is an initialization tensor, which is impossible to be summed up by. */
 554 |     13 |     if (x < exec_symbol_info_size)
 555 |     13 |     {
 556 |     13 |       ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
 557 |     13 |       if (!back_exec->outgoings)
 558 |      0 |         back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
 559 |     13 |       ccv_array_push(back_exec->outgoings, &outgoing);
 560 |     13 |     } else {
 561 |      0 |       ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size);
 562 |      0 |       ccv_array_push(sum_or_set->outgoings, &outgoing);
 563 |      0 |     }
 564 |     13 |     if (tensor_ref->exec_registry)
 565 |      6 |       for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
 566 |      3 |       {
 567 |      3 |         const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
 568 |      3 |         assert(x >= 0); /* Otherwise, this is an initialization tensor, which is impossible to be summed up by. */
 569 |      3 |         assert(x < exec_symbol_info_size); // exec_registry is only used by alias_registry, it simply cannot reference a sum operation.
 570 |      3 |         ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + x;
 571 |      3 |         if (!back_exec->outgoings)
 572 |      0 |           back_exec->outgoings = ccv_array_new(sizeof(int), 1, 0);
 573 |      3 |         ccv_array_push(back_exec->outgoings, &outgoing);
 574 |      3 |       }
 575 |     13 |   }
 576 |      5 |   const ccv_nnc_tensor_ref_t tensor_ref = {
 577 |      5 |     .d = tensor_ref_d,
 578 |      5 |     .x = outgoing,
 579 |      5 |     .exec_registry = 0, // I don't need to take execution dependencies because this tensor is generated by sum; therefore, we already take that dependency.
 580 |      5 |     .alias_registry = !no_alias_registry || has_this_alias_exclusively ? ccv_array_new(sizeof(int), 1, 0) : 0
 581 |      5 |   };
 582 |        |   // If there is no alias registry, then we take the whole tensor ref as one.
 583 |      5 |   if (!no_alias_registry || has_this_alias_exclusively)
 584 |      4 |   {
 585 |        |     // If this tensor ref contains multiple different types of alias, have to add them together (otherwise
 586 |        |     // the computation for if there is an empty slot in this tensor ref is not correct without all the
 587 |        |     // occupancy availability information).
 588 |      4 |     if (!has_this_alias_exclusively)
 589 |     10 |       for (i = 0; i < input_size; i++)
 590 |      7 |       {
 591 |      7 |         ccv_nnc_tensor_ref_t* ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i);
 592 |      7 |         assert(ref->alias_registry);
 593 |        |         // It may get duplicates. But whatever, it won't matter for the computation.
 594 |     19 |         for (j = 0; j < ref->alias_registry->rnum; j++)
 595 |     12 |           ccv_array_push(tensor_ref.alias_registry, ccv_array_get(ref->alias_registry, j));
 596 |      7 |       }
 597 |      4 |     ccv_array_push(tensor_ref.alias_registry, &ad);
 598 |      4 |   }
 599 |      5 |   assert(input_size <= tensor_ver->ref_version->rnum - tensor_ver->c);
 600 |      5 |   ccv_nnc_tensor_ref_t x;
 601 |     18 |   for (i = 0; i < input_size; i++)
 602 |        |     // If the current one (i + tensor_ver->c) is smaller than the one referenced to, exchange.
 603 |     13 |     if (kd[i].i > i + tensor_ver->c)
 604 |      0 |       CCV_SWAP(*(ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, i + tensor_ver->c), *(ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, kd[i].i), x);
 605 |      5 |   ccv_array_push(tensor_ver->ref_version, &tensor_ref);
 606 |        |   // We've consumed input_size tensor refs, now move c up to the pointer of non-consumed tensors.
 607 |      5 |   tensor_ver->c += input_size;
 608 |      5 |   return ad;
 609 |      5 | }
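
A worked reading of kd[] above (hypothetical scenario): suppose a tensor has two pending versions, A written exactly through the requested alias and B written through the whole tensor. _ccv_nnc_tensor_ref_version_find_alias returns A's alias symbol (kd[0].k >= 0) while B merely overlaps (kd[1].k == -1), so j == 2 and the function falls through to build a sum node whose inputs are A's alias and B's full tensor; only when j == 1 can it return a symbol directly and skip the sum.
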
 610 |        |
 611 |        | typedef struct ccv_nnc_symbolic_graph_backward_prep_s {
 612 |        |   int exec_symbol_info_size; // Number of graph exec symbols before adding any new symbols related to automatic differentiation.
 613 |        |   int tensor_symbol_info_size; // Number of tensor symbols before adding anything new.
 614 |        |   int sub_prep_size;
 615 |        |   ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info;
 616 |        |   ccv_nnc_tensor_symbol_info_t* tensor_symbol_info;
 617 |        |   ccv_nnc_graph_backward_info_t* backward_info; // Corresponding to forward graph exec symbol info, it is exactly in reverse.
 618 |        |   ccv_nnc_graph_visit_t* forward_visit; // The visitor structure (top sorted index) when doing traversal.
 619 |        |   ccv_nnc_graph_visit_t* backward_visit; // The visitor structure (top sorted index) when doing reverse traversal.
 620 |        |   ccv_nnc_autograd_graph_exec_symbol_t* autograd_execs; // The graph exec symbols we need for automatic differentiation. This is a 1:1 mapping for forward graph exec symbols, however, unlike backward_info, its outgoings may be more complex (may contain outgoing flows to sum nodes).
 621 |        |   ccv_nnc_autograd_tensor_version_t* autograd_tensor_versions; // Corresponding to forward tensor symbols, each may contain multiple versions (due to multi-write).
 622 |        |   ccv_array_t* autograd_tensor_symbols; // The tensor symbols we need for automatic differentiation (it may not be a 1:1 mapping).
 623 |        |   ccv_array_t* sum_or_set_execs; // The sum nodes, because in reverse mode, a tensor could have multiple versions, we need to sum them up before use.
 624 |        |   struct ccv_nnc_symbolic_graph_backward_prep_s* sub_preps; // The preps of its sub-graphs.
 625 |        |   // Pointers not managed by this struct
 626 |        |   ccv_nnc_symbolic_graph_t* graph;
 627 |        | } ccv_nnc_symbolic_graph_backward_prep_t;
 628 |        |
 629 |        | static ccv_nnc_symbolic_graph_backward_prep_t _ccv_nnc_symbolic_graph_backward_prep(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
 630 |  6.72k | {
 631 |  6.72k |   const int exec_symbol_info_size = graph->exec_symbol_info->rnum;
 632 |  6.72k |   assert(exec_symbol_info_size > 0);
 633 |  6.72k |   const int tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
 634 |  6.72k |   assert(tensor_symbol_info_size > 0);
 635 |  6.72k |   ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_info_t) * exec_symbol_info_size);
 636 |  6.72k |   ccv_nnc_tensor_symbol_info_t* tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccmalloc(sizeof(ccv_nnc_tensor_symbol_info_t) * tensor_symbol_info_size);
 637 |  13.4k |   ccv_nnc_graph_visit_t* forward_visit = ccv_nnc_graph_visit_new(graph, (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0), exec_symbol_info_size, sources, source_size, destinations, destination_size, 0);
 638 |      0 |   ccv_nnc_symbolic_graph_symbol_infer(graph, forward_visit, sources, source_size, destinations, destination_size, 0, 0, tensor_symbol_info, exec_symbol_info);
 639 |  13.4k |   int i;
 640 |        |   // Now, for each one of these, find a reverse graph.
 641 |  13.4k |   ccv_nnc_graph_backward_info_t* backward_info = (ccv_nnc_graph_backward_info_t*)cccalloc(exec_symbol_info_size, sizeof(ccv_nnc_graph_backward_info_t));
 642 |  18.0k |   ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, node, idx) {
 643 |  18.0k |     assert(ccv_nnc_cmd_is_forward(node->cmd) || node->cmd.cmd == CCV_NNC_NOOP);
 644 |  18.0k |     if (node->outgoings)
 645 |  22.6k |       for (i = 0; i < node->outgoings->rnum; i++)
 646 |  11.3k |       {
 647 |  11.3k |         int d = *(int*)ccv_array_get(node->outgoings, i);
 648 |  11.3k |         if (backward_info[d].outgoings == 0)
 649 |  11.2k |           backward_info[d].outgoings = ccv_array_new(sizeof(int32_t), 1, 0);
 650 |  11.3k |         ccv_array_push(backward_info[d].outgoings, &idx);
 651 |  11.3k |       }
 652 |  18.0k |   } ccv_nnc_graph_visit_endfor
 653 |        |   // Also mark only the output bits that we use.
 654 |  24.8k |   for (i = 0; i < exec_symbol_info_size; i++)
 655 |  18.1k |   {
 656 |  18.1k |     backward_info[i].input_bitmask_size = ((exec_symbol_info[i].output_size * 2 + exec_symbol_info[i].input_size + 63) >> 6);
 657 |  18.1k |     backward_info[i].output_bitmask_size = ((exec_symbol_info[i].input_size + 63) >> 6);
 658 |        |     // Allocate input / output bitmasks
 659 |  18.1k |     if (backward_info[i].input_bitmask_size + backward_info[i].output_bitmask_size > 0)
 660 |  18.0k |     {
 661 |  18.0k |       backward_info[i].input_bitmasks = (uint64_t*)cccalloc(backward_info[i].input_bitmask_size + backward_info[i].output_bitmask_size, sizeof(uint64_t));
 662 |  18.0k |       if (backward_info[i].output_bitmask_size)
 663 |  18.0k |         backward_info[i].output_bitmasks = backward_info[i].input_bitmasks + backward_info[i].input_bitmask_size;
 664 |  18.0k |     }
 665 |  18.1k |   }
 666 |  6.72k |   ccv_nnc_graph_visit_t* backward_visit = ccv_nnc_graph_visit_new(graph, backward_info, exec_symbol_info_size, destinations, destination_size, sources, source_size, 0);
 667 |  6.72k |   const int sub_prep_size = graph->sub_graphs ? graph->sub_graphs->rnum : 0;
 668 |  6.72k |   ccv_nnc_symbolic_graph_backward_prep_t* sub_preps = sub_prep_size > 0 ? (ccv_nnc_symbolic_graph_backward_prep_t*)cccalloc(sub_prep_size, sizeof(ccv_nnc_symbolic_graph_backward_prep_t)) : 0;
 669 |  6.73k |   for (i = 0; i < sub_prep_size; i++)
 670 |      4 |   {
 671 |      4 |     const ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i);
 672 |      4 |     sub_preps[i] = _ccv_nnc_symbolic_graph_backward_prep(sub_graph, ccv_nnc_symbolic_graph_sources(sub_graph), ccv_nnc_symbolic_graph_source_size(sub_graph), ccv_nnc_symbolic_graph_destinations(sub_graph), ccv_nnc_symbolic_graph_destination_size(sub_graph));
 673 |      4 |   }
 674 |  6.72k |   return (ccv_nnc_symbolic_graph_backward_prep_t){
 675 |  6.72k |     .exec_symbol_info_size = exec_symbol_info_size,
 676 |  6.72k |     .tensor_symbol_info_size = tensor_symbol_info_size,
 677 |  6.72k |     .sub_prep_size = sub_prep_size,
 678 |  6.72k |     .exec_symbol_info = exec_symbol_info,
 679 |  6.72k |     .tensor_symbol_info = tensor_symbol_info,
 680 |  6.72k |     .backward_info = backward_info,
 681 |  6.72k |     .forward_visit = forward_visit,
 682 |  6.72k |     .backward_visit = backward_visit,
 683 |  6.72k |     .sub_preps = sub_preps,
 684 |  6.72k |     .graph = (ccv_nnc_symbolic_graph_t*)graph,
 685 |  6.72k |   };
 686 |  6.72k | }
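
The bitmask sizing on lines 656-657 is 64-bit word rounding: the backward input bitmask carries one bit per output gradient, per forward input, and per forward output (hence output_size * 2 + input_size bits), and the output bitmask one bit per forward input. For a hypothetical node with 3 inputs and 1 output that is (1 * 2 + 3 + 63) >> 6 = 1 input word and (3 + 63) >> 6 = 1 output word; a second word is only needed past 64 bits.
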
 687 |        |
 688 |        | static void _ccv_nnc_symbolic_graph_backward_exec_io(const ccv_nnc_graph_exec_symbol_info_t* const node, int** const back_input_map, int** const back_output_map, int* const back_input_size, int* const back_output_size)
 689 |  18.0k | {
 690 |  18.0k |   int i;
 691 |  18.0k |   if (node->flags & CCV_NNC_GRAPH_EXEC_CASE_OF)
 692 |      7 |   {
 693 |      7 |     *back_input_map = node->outputs;
 694 |      7 |     *back_input_size = node->output_size;
 695 |     14 |     for (i = 0; i < node->case_of.argument.offset; i++)
 696 |      7 |       (*back_output_map)[i] = node->inputs[i];
 697 |      7 |     const int argument_offset = node->case_of.argument.offset;
 698 |      7 |     const int argument_size = node->case_of.argument.size;
 699 |        |     // Skip the argument range.
 700 |      7 |     for (i = argument_offset + argument_size; i < node->input_size; i++)
 701 |      0 |       (*back_output_map)[i - argument_size] = node->inputs[i];
 702 |      7 |     *back_output_size = node->input_size - node->case_of.argument.size;
 703 |  18.0k |   } else { // if (node->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) {
 704 |  18.0k |     *back_input_map = node->outputs;
 705 |  18.0k |     *back_input_size = node->output_size;
 706 |  18.0k |     *back_output_map = node->inputs;
 707 |  18.0k |     *back_output_size = node->input_size;
 708 |  18.0k |   }
 709 |  18.0k | }
 710 |        |
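
Note: for an ordinary node the backward i/o is simply the forward i/o flipped. A hypothetical forward node y = f(x, w) gets back_input_map = { y } (the gradient flows in through what the node wrote) and back_output_map = { x, w } (gradients flow out through what it read); only case-of nodes deviate, carving the argument range out of the output map.
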
 711 |        | static void _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(const ccv_nnc_graph_exec_symbol_info_t* const forw_exec, const ccv_nnc_symbolic_graph_t* const sub_graph, const int graph_ref, const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, const uint64_t* const input_bitmasks, const uint64_t* const output_bitmasks, ccv_array_t* const sub_f_symbols, ccv_array_t* const sub_wrt_symbols)
 712 |      8 | {
 713 |      8 |   int i, j;
 714 |      8 |   ccv_array_clear(sub_wrt_symbols);
 715 |      8 |   int forw_outputs[ccv_max(1, forw_exec->output_size)];
 716 |      8 |   int forw_inputs[ccv_max(1, forw_exec->input_size)];
 717 |      8 |   int* back_input_map = forw_outputs;
 718 |      8 |   int* back_output_map = forw_inputs;
 719 |      8 |   int back_input_size, back_output_size;
 720 |      8 |   _ccv_nnc_symbolic_graph_backward_exec_io(forw_exec, &back_input_map, &back_output_map, &back_input_size, &back_output_size);
 721 |     18 |   for (i = 0; i < back_output_size; i++)
 722 |     10 |     if (output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
 723 |      8 |     {
 724 |      8 |       const int d = back_output_map[i];
 725 |      8 |       const ccv_array_t* const s_refs = tensor_symbol_info[d].s_ref;
 726 |      8 |       const int s_ref = s_refs && s_refs->rnum > graph_ref ? *(int*)ccv_array_get(s_refs, graph_ref) - 1 : -1;
 727 |      8 |       if (s_ref >= 0)
 728 |      4 |       {
 729 |      4 |         ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
 730 |      4 |           .d = s_ref,
 731 |      4 |           .graph = sub_graph,
 732 |      4 |         };
 733 |      4 |         ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
 734 |      4 |       } else
 735 |      4 |         ccv_array_push(sub_wrt_symbols, &NO_TENSOR_SYMBOL);
 736 |      8 |     }
 737 |      8 |   ccv_array_clear(sub_f_symbols);
 738 |     16 |   for (i = 0; i < back_input_size; i++)
 739 |      8 |     if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
 740 |      8 |     {
 741 |      8 |       const int d = back_input_map[i];
 742 |      8 |       ccv_nnc_tensor_symbol_t sub_f_symbol = {
 743 |      8 |         .d = *(int*)ccv_array_get(tensor_symbol_info[d].s_ref, graph_ref) - 1,
 744 |      8 |         .graph = sub_graph,
 745 |      8 |       };
 746 |      8 |       ccv_array_push(sub_f_symbols, &sub_f_symbol);
 747 |      8 |     }
 748 |        |   // Go through all its assignments (parameterized loop), making them either wrt or f.
 749 |        |   // The reason is these must flow through the graph, otherwise we cannot form a full
 750 |        |   // enclosed loop. Also because they are the additional f / wrt symbols, there is
 751 |        |   // no case that we cannot find their corresponding gradients in the backward sub graphs
 752 |        |   // (these gradients have to be parameterized to form an enclosed loop as well).
 753 |     30 |   for (i = 0; i < sub_graph->tensor_symbol_info->rnum; i++)
 754 |     22 |   {
 755 |     22 |     const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, i);
 756 |     22 |     if (tensor_symbol_info->assign_ref)
 757 |      2 |     {
 758 |      2 |       const int assign_ref = tensor_symbol_info->assign_ref - 1;
 759 |        |       // i is the wrt, assign_ref is the f.
 760 |      2 |       int flag = 0;
 761 |      4 |       for (j = 0; !flag && j < sub_wrt_symbols->rnum; j++)
 762 |      2 |         flag = (((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, j))->d == i);
 763 |      2 |       if (!flag)
 764 |      2 |       {
 765 |      2 |         ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
 766 |      2 |           .d = i,
 767 |      2 |           .graph = sub_graph,
 768 |      2 |         };
 769 |      2 |         ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
 770 |      2 |       }
 771 |      2 |       flag = 0;
 772 |      4 |       for (j = 0; !flag && j < sub_f_symbols->rnum; j++)
 773 |      2 |         flag = (((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, j))->d == assign_ref);
 774 |      2 |       if (!flag)
 775 |      0 |       {
 776 |      0 |         ccv_nnc_tensor_symbol_t sub_f_symbol = {
 777 |      0 |           .d = assign_ref,
 778 |      0 |           .graph = sub_graph,
 779 |      0 |         };
 780 |      0 |         ccv_array_push(sub_f_symbols, &sub_f_symbol);
 781 |      0 |       }
 782 |      2 |     }
 783 |     22 |   }
 784 |      8 | }
 785 |        |
786
// Check whether for a given f_symbol, we can compute wrt_symbols at all, if we can, tag the minimal io and ops (some ops can be replaced with noop) required to do so.
787
static int _ccv_nnc_symbolic_graph_backward_prep_prune_ops(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
788
6.72k
{
789
6.72k
  int i, j, p;
790
6.72k
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
791
6.72k
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
792
6.72k
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info =backward_prep->tensor_symbol_info;
793
6.72k
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
794
  // Now, for each one of these, find a reverse graph.
795
6.72k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
796
6.72k
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
797
  // Find the f_symbols, and tag its flows.
798
18.0k
  ccv_nnc_graph_visit_for(backward_visit, backward_info, node, idx) {
799
18.0k
    int f = node->f_wrt & 0x1;
800
25.0k
    for (i = 0; i < exec_symbol_info[idx].output_size && 
!f18.3k
;
i++6.99k
)
801
6.99k
    {
802
6.99k
      int d = exec_symbol_info[idx].outputs[i];
803
6.99k
      if (d < 0)
804
206
        continue;
805
6.78k
      
while (6.78k
tensor_symbol_info[d].alias_ref)
806
3
        d = tensor_symbol_info[d].alias_ref - 1;
807
13.5k
      for (j = 0; j < f_symbol_size && 
!f6.81k
;
j++6.80k
)
808
6.80k
        if (d == f_symbols[j].d)
809
6.74k
          f = 1;
810
6.78k
    }
811
18.0k
    if (f)
812
18.0k
    {
813
18.0k
      node->f_wrt |= f;
814
18.0k
      if (node->outgoings)
815
22.6k
        
for (i = 0; 11.2k
i < node->outgoings->rnum;
i++11.3k
)
816
11.3k
        {
817
11.3k
          int d = *(int*)ccv_array_get(node->outgoings, i);
818
11.3k
          backward_info[d].f_wrt |= f;
819
11.3k
        }
820
18.0k
    }
821
18.0k
  } ccv_nnc_graph_visit_endfor
822
  // Find the wrt_symbols, and tag its flows.
823
18.0k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, node, idx) {
824
18.0k
    int wrt = backward_info[idx].f_wrt & 0x2;
825
28.9k
    for (i = 0; i < node->input_size && 
!wrt26.4k
;
i++10.8k
)
826
10.8k
    {
827
10.8k
      int d = node->inputs[i];
828
10.8k
      if (d < 0)
829
0
        continue;
830
10.8k
      
while (10.8k
tensor_symbol_info[d].alias_ref)
831
6
        d = tensor_symbol_info[d].alias_ref - 1;
832
24.4k
      for (j = 0; j < wrt_symbol_size && 
!wrt13.6k
;
j++13.6k
)
833
13.6k
      {
834
13.6k
        int wrt_d = wrt_symbols[j].d;
835
        // Find the root of this tensor alias.
836
13.6k
        if (tensor_symbol_info[wrt_d].alias_ref)
837
2
          wrt_d = tensor_symbol_info[wrt_d].alias_ref - 1;
838
13.6k
        if (d == wrt_d)
839
6.79k
          wrt = 0x2;
840
13.6k
      }
841
10.8k
    }
842
18.0k
    if (wrt)
843
18.0k
    {
844
18.0k
      backward_info[idx].f_wrt |= wrt;
845
18.0k
      if (node->outgoings)
846
22.6k
        
for (i = 0; 11.3k
i < node->outgoings->rnum;
i++11.3k
)
847
11.3k
        {
848
11.3k
          int d = *(int*)ccv_array_get(node->outgoings, i);
849
11.3k
          backward_info[d].f_wrt |= wrt;
850
11.3k
        }
851
18.0k
    }
852
18.0k
  } ccv_nnc_graph_visit_endfor
853
6.72k
  enum {
854
6.72k
    WRT_SYMBOL_USE = 1,
855
6.72k
    F_SYMBOL_USE = 2
856
6.72k
  };
857
6.72k
  uint8_t* used_grad = (uint8_t*)cccalloc(tensor_symbol_info_size, sizeof(uint8_t));
858
  // First, all f_symbols and wrt_symbols are used.
859
13.4k
  for (i = 0; i < f_symbol_size; 
i++6.74k
)
860
6.74k
    if (f_symbols[i].d >= 0)
861
6.74k
      used_grad[tensor_symbol_info[f_symbols[i].d].alias_ref ? 
tensor_symbol_info[f_symbols[i].d].alias_ref - 10
: f_symbols[i].d] |= F_SYMBOL_USE;
862
16.2k
  for (i = 0; i < wrt_symbol_size; 
i++9.47k
)
863
9.47k
    if (wrt_symbols[i].d >= 0)
864
9.47k
      used_grad[tensor_symbol_info[wrt_symbols[i].d].alias_ref ? 
tensor_symbol_info[wrt_symbols[i].d].alias_ref - 11
:
wrt_symbols[i].d9.47k
] |= WRT_SYMBOL_USE;
865
  // Do optimistic assumption, and then compute used_grad
866
18.0k
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
867
18.0k
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
868
    /* Only interested in the ones on the f / wrt flow */
869
18.0k
    if ((node->f_wrt & 0x3) == 0x3)
870
18.0k
    {
871
18.0k
      const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
872
18.0k
      ccv_nnc_cmd_t cmd = forw_exec->cmd;
873
18.0k
      if (cmd.cmd != CCV_NNC_NOOP)
874
18.0k
        cmd.cmd += 1; /* Backward command is the one after forward command. */
875
18.0k
      assert(ccv_nnc_cmd_is_backward(cmd) || cmd.cmd == CCV_NNC_NOOP);
876
88.5k
      
for (i = 0; 18.0k
i < forw_exec->output_size * 2 + forw_exec->input_size;
i++70.5k
)
877
70.5k
        if (!(i >= forw_exec->output_size && 
i < forw_exec->output_size + forw_exec->input_size52.1k
&&
878
70.5k
          
forw_exec->inputs[i - forw_exec->output_size] < 033.6k
) && // If the input is empty, no need.
879
70.5k
          
!(70.5k
i >= forw_exec->output_size + forw_exec->input_size70.5k
&&
i < forw_exec->output_size * 2 + forw_exec->input_size18.4k
&&
880
70.5k
          
forw_exec->outputs[i - forw_exec->output_size - forw_exec->input_size] < 018.4k
) && // If the output is empty, no need.
881
70.5k
          
!(70.3k
i < forw_exec->output_size70.3k
&&
forw_exec->outputs[i] < 018.4k
)) // If the output is empty for gradient, no need.
882
70.0k
          node->input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
883
51.7k
      for (i = 0; i < forw_exec->input_size; 
i++33.6k
)
884
33.6k
        if (!(forw_exec->inputs[i] < 0)) // If the inputs is empty, no need.
885
33.6k
          node->output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
886
18.0k
      int maybe_noop = 1;
887
22.0k
      for (i = 0; i < forw_exec->input_size; 
i++4.03k
)
888
        /* See if it is used as wrt, if not, no need to run this node at all. */
889
22.0k
        if (forw_exec->inputs[i] >= 0 && 
used_grad[22.0k
tensor_symbol_info[forw_exec->inputs[i]].alias_ref22.0k
?
tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 11.12k
:
forw_exec->inputs[i]20.9k
] & WRT_SYMBOL_USE)
890
18.0k
        {
891
18.0k
          maybe_noop = 0;
892
18.0k
          break;
893
18.0k
        }
894
18.0k
      if (maybe_noop)
895
0
      {
896
0
        for (i = 0; i < node->input_bitmask_size; i++)
897
0
          node->input_bitmasks[i] = 0;
898
0
        for (i = 0; i < node->output_bitmask_size; i++)
899
0
          node->output_bitmasks[i] = 0;
900
0
        node->output_bitmask_size = 0;
901
18.0k
      } else if (cmd.cmd == CCV_NNC_GRAPH_FORWARD || cmd.cmd == CCV_NNC_GRAPH_BACKWARD) {
902
        // Clear out all potential outputs if we think it is not a wrt symbols.
903
6
        for (i = 0; i < forw_exec->input_size; 
i++4
)
904
4
          if ((node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
905
4
            !(used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? 
tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 10
: forw_exec->inputs[i]] & WRT_SYMBOL_USE))
906
1
            node->output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
907
        // But for now, assuming we need all input gradients.
908
        // Clear out all inputs / outputs from forward op.
909
8
        for (i = forw_exec->output_size; i < forw_exec->output_size * 2 + forw_exec->input_size; 
i++6
)
910
6
          node->input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
911
18.0k
      } else if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size)) {
912
15.6k
        int flag; /* Only continue if it changed */
913
29.0k
        do {
914
29.0k
          flag = 0;
915
          /* Check if the output first */
916
87.0k
          for (i = 0; i < forw_exec->input_size; 
i++58.0k
)
917
            /* Only try to eliminate the one that is not used. */
918
58.0k
            if ((node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
919
58.0k
              
!(used_grad[51.4k
tensor_symbol_info[forw_exec->inputs[i]].alias_ref51.4k
?
tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 1327
:
forw_exec->inputs[i]51.1k
] & WRT_SYMBOL_USE))
920
10.6k
            {
921
10.6k
              node->output_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
922
              /* If it worked, mark it as flagged. */
923
10.6k
              if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size))
924
6.54k
                flag = 1;
925
4.06k
              else /* Refit this with the bit back again. */
926
4.06k
                node->output_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
927
10.6k
            }
928
146k
          for (i = 0; i < forw_exec->output_size * 2 + forw_exec->input_size; 
i++117k
)
929
117k
            if ((node->input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63))) &&
930
117k
              
(92.0k
i >= forw_exec->output_size92.0k
||
931
92.0k
               
!(used_grad[29.2k
tensor_symbol_info[forw_exec->outputs[i]].alias_ref29.2k
?
tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 137
:
forw_exec->outputs[i]29.1k
] & F_SYMBOL_USE)))
932
78.7k
            { /* Try to eliminate one of the input. */
933
78.7k
              node->input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
934
              /* If it worked, mark it as flagged. */
935
78.7k
              if (ccv_nnc_cmd_bitmask(cmd, forw_exec->output_size * 2 + forw_exec->input_size, forw_exec->input_size, node->input_bitmasks, node->input_bitmask_size, node->output_bitmasks, node->output_bitmask_size))
936
24.7k
                flag = 1;
937
53.9k
              else /* Refit this with the bit back again. */
938
53.9k
                node->input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
939
78.7k
            }
940
29.0k
        } while (flag);
941
15.6k
      }
942
36.4k
      for (i = 0; i < forw_exec->output_size; 
i++18.4k
)
943
18.4k
        if (node->input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
944
          /* Mark it is used as wrt. */
945
18.0k
          used_grad[tensor_symbol_info[forw_exec->outputs[i]].alias_ref ? 
tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 119
:
forw_exec->outputs[i]18.0k
] |= WRT_SYMBOL_USE;
946
51.7k
      for (i = 0; i < forw_exec->input_size; 
i++33.6k
)
947
          /* Mark it is used as f. */
948
33.6k
        if (node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
949
27.1k
          used_grad[tensor_symbol_info[forw_exec->inputs[i]].alias_ref ? 
tensor_symbol_info[forw_exec->inputs[i]].alias_ref - 11.16k
:
forw_exec->inputs[i]25.9k
] |= F_SYMBOL_USE;
950
18.0k
    }
951
18.0k
  } ccv_nnc_graph_visit_endfor
  ccv_array_t* sub_f_symbols = 0;
  ccv_array_t* sub_wrt_symbols = 0;
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
    const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
    /* Only interested in the ones on the f / wrt flow */
    if ((node->f_wrt & 0x3) == 0x3 && forw_exec->graph_ref_size > 0)
    {
      uint64_t stack_input_bitmasks1[node->input_bitmask_size];
      uint64_t stack_input_bitmasks2[node->input_bitmask_size];
      uint64_t* const input_bitmasks = forw_exec->graph_ref_size > 1 ? stack_input_bitmasks1 : node->input_bitmasks;
      // We collect input masks into this location.
      if (forw_exec->graph_ref_size > 1)
        memset(stack_input_bitmasks2, 0, sizeof(uint64_t) * node->input_bitmask_size);
      for (p = 0; p < forw_exec->graph_ref_size; p++)
      {
        // Reset the stack input bitmasks.
        if (forw_exec->graph_ref_size > 1)
          memcpy(stack_input_bitmasks1, node->input_bitmasks, sizeof(uint64_t) * node->input_bitmask_size);
        // Now call it recursively until we are sure no f_symbols can be removed.
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[p] - 1;
        ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep = backward_prep->sub_preps + graph_ref;
        if (!sub_wrt_symbols)
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
        else
          ccv_array_clear(sub_wrt_symbols);
        for (i = 0; i < forw_exec->input_size; i++)
          if (node->output_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
          {
            const ccv_array_t* const s_refs = tensor_symbol_info[forw_exec->inputs[i]].s_ref;
            const int s_ref = s_refs && s_refs->rnum > graph_ref ? *(int*)ccv_array_get(s_refs, graph_ref) - 1 : -1;
            if (s_ref >= 0)
            {
              ccv_nnc_tensor_symbol_t sub_wrt_symbol = {
                .d = s_ref,
                .graph = sub_prep->graph,
              };
              ccv_array_push(sub_wrt_symbols, &sub_wrt_symbol);
            }
          }
        int flag; // Only continue if it changed.
        do {
          flag = 0;
          for (i = 0; i < forw_exec->output_size; i++)
            // Try to reduce the number of inputs for the backward graph. If it is not tagged as F_SYMBOL_USE, we can reduce it.
            // It is reducible because this sub graph may have multiple computation paths, therefore, some of these may not
            // involve our wrt symbols at all.
            if (!(used_grad[tensor_symbol_info[forw_exec->outputs[i]].alias_ref ? tensor_symbol_info[forw_exec->outputs[i]].alias_ref - 1 : forw_exec->outputs[i]] & F_SYMBOL_USE) &&
              input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
            { /* Try to eliminate one of the inputs. */
              input_bitmasks[i >> 6] &= ~((uint64_t)1 << (i & 63));
              if (!sub_f_symbols)
                sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
              else
                ccv_array_clear(sub_f_symbols);
              for (j = 0; j < forw_exec->output_size; j++)
                if (node->input_bitmasks[j >> 6] & ((uint64_t)1 << (j & 63)))
                {
                  const int s_ref = *(int*)ccv_array_get(tensor_symbol_info[forw_exec->outputs[j]].s_ref, graph_ref) - 1;
                  assert(s_ref >= 0);
                  ccv_nnc_tensor_symbol_t sub_f_symbol = {
                    .d = s_ref,
                    .graph = sub_prep->graph,
                  };
                  ccv_array_push(sub_f_symbols, &sub_f_symbol);
                }
              if (_ccv_nnc_symbolic_graph_backward_prep_prune_ops(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph)))
                flag = 1;
              else /* Refit this with the bit back again. */
                input_bitmasks[i >> 6] |= ((uint64_t)1 << (i & 63));
            }
        } while (flag);
        // I am done; need to redo the above for sub_prep, and it has to be successful now.
        if (!sub_f_symbols)
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
        else
          ccv_array_clear(sub_f_symbols);
        for (i = 0; i < forw_exec->output_size; i++)
          if (input_bitmasks[i >> 6] & ((uint64_t)1 << (i & 63)))
          {
            const int s_ref = *(int*)ccv_array_get(tensor_symbol_info[forw_exec->outputs[i]].s_ref, graph_ref) - 1;
            assert(s_ref >= 0);
            ccv_nnc_tensor_symbol_t sub_f_symbol = {
              .d = s_ref,
              .graph = sub_prep->graph,
            };
            ccv_array_push(sub_f_symbols, &sub_f_symbol);
          }
        _ccv_nnc_symbolic_graph_backward_prep_prune_ops(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph));
        if (forw_exec->graph_ref_size > 1)
          for (i = 0; i < node->input_bitmask_size; i++)
            stack_input_bitmasks2[i] |= input_bitmasks[i];
      }
      if (forw_exec->graph_ref_size > 1)
        memcpy(node->input_bitmasks, stack_input_bitmasks2, sizeof(uint64_t) * node->input_bitmask_size);
    }
  } ccv_nnc_graph_visit_endfor
  if (sub_f_symbols)
    ccv_array_free(sub_f_symbols);
  if (sub_wrt_symbols)
    ccv_array_free(sub_wrt_symbols);
  int flag = 1;
  for (i = 0; i < f_symbol_size && flag; i++)
    flag = (used_grad[tensor_symbol_info[f_symbols[i].d].alias_ref ? tensor_symbol_info[f_symbols[i].d].alias_ref - 1 : f_symbols[i].d] & WRT_SYMBOL_USE);
  ccfree(used_grad);
  return flag;
}
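// Editorial sketch (not part of the original file): every mask test, clear, and set in
// the function above uses the word-packed bitmask idiom, with bit i stored in word
// i >> 6 at bit position i & 63. A minimal, self-contained illustration of the
// tentative-clear / refit pattern; the helper name is hypothetical and unused:
static inline int _ccv_nnc_bitmask_tentative_clear_sketch(uint64_t* const mask, const int i, const int accepted)
{
  const uint64_t bit = (uint64_t)1 << (i & 63);
  if (!(mask[i >> 6] & bit)) // Nothing to eliminate at position i.
    return 0;
  mask[i >> 6] &= ~bit; // Tentatively eliminate position i.
  if (!accepted) // Rejected by the command bitmask check: refit the bit back again.
    mask[i >> 6] |= bit;
  return accepted;
}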

static void _ccv_nnc_symbolic_graph_backward_prep_gen(ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const int is_while, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
{
  const int exec_symbol_info_size = backward_prep->exec_symbol_info_size;
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info;
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
  // Now, for each one of these, find a reverse graph.
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
  int i, j;
  // Now, only the flow from f_symbols back to wrt_symbols is of interest to us.
  // Visit the graph in reverse order, build the AD nodes.
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = (ccv_nnc_autograd_graph_exec_symbol_t*)cccalloc(exec_symbol_info_size, sizeof(ccv_nnc_autograd_graph_exec_symbol_t));
  int max_forw_input_size = 0, max_forw_output_size = 0;
  for (i = 0; i < exec_symbol_info_size; i++)
    if ((backward_info[i].f_wrt & 0x3) == 0x3)
    {
      max_forw_input_size = ccv_max(max_forw_input_size, exec_symbol_info[i].input_size);
      max_forw_output_size = ccv_max(max_forw_output_size, exec_symbol_info[i].output_size);
      if (backward_info[i].outgoings)
      {
        // Copy over the outgoing bits.
        autograd_execs[i].outgoings = ccv_array_new(sizeof(int), backward_info[i].outgoings->rnum, 0);
        for (j = 0; j < backward_info[i].outgoings->rnum; j++)
        {
          const int d = *(int*)ccv_array_get(backward_info[i].outgoings, j);
          // Only push the outgoing node if it is in the f_wrt path.
          if ((backward_info[d].f_wrt & 0x3) == 0x3)
            ccv_array_push(autograd_execs[i].outgoings, &d);
        }
      }
    }
  int max_forw_inputs[ccv_max(1, max_forw_input_size)];
  int max_forw_outputs[ccv_max(1, max_forw_output_size)];
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = (ccv_nnc_autograd_tensor_version_t*)cccalloc(tensor_symbol_info_size, sizeof(ccv_nnc_autograd_tensor_version_t));
  ccv_array_t* autograd_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_autograd_tensor_symbol_t), tensor_symbol_info_size, 0);
  ccv_array_t* sum_or_set_execs = ccv_array_new(sizeof(ccv_nnc_sum_or_set_graph_exec_symbol_t), 0, 0);
  ccv_nnc_graph_visit_for(backward_visit, backward_info, back_info_node, idx) {
    /* This is required by both the f flow and the wrt flow, therefore, of interest to us */
    if ((back_info_node->f_wrt & 0x3) == 0x3)
    {
      const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
      ccv_nnc_autograd_graph_exec_symbol_t* back_exec = autograd_execs + idx;
      back_exec->cmd = forw_exec->cmd;
      if (back_exec->cmd.cmd != CCV_NNC_NOOP)
        back_exec->cmd.cmd += 1; /* The backward command is the one after the forward command. */
      assert(ccv_nnc_cmd_is_backward(back_exec->cmd) || back_exec->cmd.cmd == CCV_NNC_NOOP);
      if (!back_info_node->output_bitmask_size) /* This has no output, can be a noop. */
        back_exec->cmd.cmd = CCV_NNC_NOOP;
      else {
        int* back_input_map = max_forw_outputs;
        int* back_output_map = max_forw_inputs;
        _ccv_nnc_symbolic_graph_backward_exec_io(forw_exec, &back_input_map, &back_output_map, &back_exec->input_size, &back_exec->output_size);
        back_exec->inputs = ccmalloc(sizeof(int) * (back_exec->input_size + back_exec->output_size));
        back_exec->outputs = back_exec->inputs + back_exec->input_size;
        /* Need to compute inputs before we compute outputs */
        for (i = 0; i < back_exec->input_size; i++)
        {
          /* If we can skip this input, do that. */
          if (!(back_info_node->input_bitmasks[i >> 6] & ((uint64_t)1 << i)))
            continue;
          const int d = back_input_map[i];
          const int alias_ref = tensor_symbol_info[d].alias_ref;
          ccv_nnc_autograd_tensor_version_t* tensor_ver = alias_ref ? autograd_tensor_versions + (alias_ref - 1) : autograd_tensor_versions + d;
          /* Initialization tensor, should correspond to f symbols */
          if (!tensor_ver->ref_version)
          {
            ccv_nnc_autograd_tensor_symbol_t tensor_sym = {};
            if (!alias_ref)
            {
              tensor_sym.d = d;
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
              const ccv_nnc_tensor_ref_t tensor_ref = {
                .d = autograd_tensor_symbols->rnum - 1,
                .x = idx,
                .alias_registry = 0
              };
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
            } else {
              tensor_sym.d = alias_ref - 1;
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
              const ccv_nnc_tensor_ref_t tensor_ref = {
                .d = autograd_tensor_symbols->rnum - 1,
                .x = idx,
                .alias_registry = ccv_array_new(sizeof(int), 1, 0)
              };
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
              tensor_sym.d = d; /* set back */
              tensor_sym.alias_ref = tensor_ref.d + 1;
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
              const int ad = autograd_tensor_symbols->rnum - 1;
              ccv_array_push(tensor_ref.alias_registry, &ad);
            }
          }
          /* The simplest case (most common): it is not an alias. */
          if (!alias_ref)
          {
            /* Even simpler, this only has one reference tensor, thus, pass this as input. */
            if (tensor_ver->c == tensor_ver->ref_version->rnum - 1)
            {
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
              /* If there are aliases associated with this tensor ref, zero it out when this tensor is allocated. */
              /* This is required. Consider the case that we have an alias of this tensor used somewhere */
              /* on the forward pass; when we compute backward, we have that alias computed first, however, its */
              /* underlying tensor is not zero initialized, and we will end up with garbage values here. */
              if (tensor_ref->alias_registry &&
                /* Loop over to see if this tensor is fully occupied, to avoid an extra zero step. */
                !_ccv_nnc_tensor_ref_fully_assigned_with_aliases(tensor_ref, autograd_tensor_symbols, tensor_symbol_info))
              {
                ccv_nnc_autograd_tensor_symbol_t* tensor_sym = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
                assert(tensor_sym->alias_ref == 0);
                tensor_sym->flags = CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS;
              }
              back_exec->inputs[i] = tensor_ref->d;
            } else {
              /* Otherwise, we need to sum them up, and then pass the summed result to the computation. */
              _ccv_nnc_graph_sum_autograd_tensor_versions(idx, d, exec_symbol_info_size, tensor_symbol_info, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
              back_exec->inputs[i] = tensor_ref->d;
            }
          } else
            /* If this is an alias, go through all available tensor ref versions */
            back_exec->inputs[i] = _ccv_nnc_graph_sum_autograd_tensor_versions_alias(idx, d, tensor_symbol_info, exec_symbol_info_size, tensor_symbol_info + d, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
        }
        for (i = 0; i < back_exec->output_size; i++)
        {
          /* If we can skip this output, do that. */
          if (!(back_info_node->output_bitmasks[i >> 6] & ((uint64_t)1 << i)))
            continue;
          const int d = back_output_map[i];
          const int alias_ref = tensor_symbol_info[d].alias_ref;
          ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
            .d = d
          };
          /* The simplest case (most common): it is not an alias. */
          if (!alias_ref)
          {
            ccv_array_push(autograd_tensor_symbols, &tensor_sym);
            const ccv_nnc_tensor_ref_t tensor_ref = {
              .d = autograd_tensor_symbols->rnum - 1,
              .x = idx,
              .exec_registry = 0,
              .alias_registry = 0
            };
            ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + d;
            if (!tensor_ver->ref_version)
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
            ccv_array_push(tensor_ver->ref_version, &tensor_ref);
            back_exec->outputs[i] = tensor_ref.d;
          } else {
            /* Otherwise, in case this is an alias, we try to find an existing one (in tensor_ver)
             * and see if it can meet the need (thus, for the tensor info / ofs, it fits). */
            ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + (alias_ref - 1);
            if (!tensor_ver->ref_version)
              tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
            /* If a ref version already exists, check whether any of these not-yet-sealed tensors has free space. */
            int found = 0;
            for (j = tensor_ver->c; !found && j < tensor_ver->ref_version->rnum; j++)
            {
              ccv_nnc_tensor_ref_t* tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, j);
              if (!_ccv_nnc_tensor_ref_version_involve_alias(tensor_ref, autograd_tensor_symbols, tensor_symbol_info, tensor_symbol_info + d))
              {
                tensor_sym.alias_ref = tensor_ref->d + 1;
                ccv_array_push(autograd_tensor_symbols, &tensor_sym);
                const int ad = autograd_tensor_symbols->rnum - 1;
                ccv_array_push(tensor_ref->alias_registry, &ad);
                if (!tensor_ref->exec_registry)
                  tensor_ref->exec_registry = ccv_array_new(sizeof(int), 1, 0);
                ccv_array_push(tensor_ref->exec_registry, &idx);
                back_exec->outputs[i] = ad;
                found = 1;
              }
            }
            if (!found) /* Cannot find a tensor ref to insert into; create one first */
            {
              tensor_sym.d = alias_ref - 1; /* Reference back to the non-alias. */
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
              const ccv_nnc_tensor_ref_t tensor_ref = {
                .d = autograd_tensor_symbols->rnum - 1,
                .x = idx,
                .exec_registry = 0,
                .alias_registry = ccv_array_new(sizeof(int), 1, 0)
              };
              ccv_array_push(tensor_ver->ref_version, &tensor_ref);
              tensor_sym.d = d; /* set back */
              tensor_sym.alias_ref = tensor_ref.d + 1;
              ccv_array_push(autograd_tensor_symbols, &tensor_sym);
              const int ad = autograd_tensor_symbols->rnum - 1;
              ccv_array_push(tensor_ref.alias_registry, &ad);
              back_exec->outputs[i] = ad;
            }
          }
        }
      }
    }
  } ccv_nnc_graph_visit_endfor
  // Find all relevant wrt symbols, generate sums for them if needed.
  for (i = 0; i < wrt_symbol_size; i++)
  {
    const int d = wrt_symbols[i].d;
    if (d < 0)
      continue;
    const int ref_d = (!tensor_symbol_info[d].alias_ref) ? d : tensor_symbol_info[d].alias_ref - 1;
    ccv_nnc_autograd_tensor_version_t* tensor_ver = autograd_tensor_versions + ref_d;
    if (!tensor_ver->ref_version)
    {
      // This wrt symbol is not available at all; in this case, we set its flag to init zero.
      const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
        .d = ref_d
      };
      ccv_array_push(autograd_tensor_symbols, &tensor_sym);
      ccv_nnc_sum_or_set_graph_exec_symbol_t set_exec = {
        .value = 0,
        .output = autograd_tensor_symbols->rnum - 1,
      };
      ccv_array_push(sum_or_set_execs, &set_exec);
      // Insert the one to be set to zero.
      const ccv_nnc_tensor_ref_t tensor_ref = {
        .d = autograd_tensor_symbols->rnum - 1,
        .x = exec_symbol_info_size + sum_or_set_execs->rnum - 1,
      };
      tensor_ver->ref_version = ccv_array_new(sizeof(ccv_nnc_tensor_ref_t), 1, 0);
      ccv_array_push(tensor_ver->ref_version, &tensor_ref);
      continue;
    }
    // If it is a while loop, we need to insert an accumulator into the graph (this is expressed as an initialization tensor summed with the existing results).
    // First, insert the initialization tensor if this wrt result is not used directly in the next while loop (if it is, it participates in the computation, therefore there is no need to accumulate).
    if (is_while && !tensor_symbol_info[ref_d].assign_ref &&
      _ccv_nnc_tensor_ref_version_find_init(tensor_ver) < 0) // If the initialization tensor is not inserted yet.
    {
      const ccv_nnc_autograd_tensor_symbol_t tensor_sym = {
        .d = ref_d
      };
      ccv_array_push(autograd_tensor_symbols, &tensor_sym);
      // Insert the one to be summed.
      const ccv_nnc_tensor_ref_t tensor_ref = {
        .d = autograd_tensor_symbols->rnum - 1,
        .x = -1, // This denotes it is an initialization vector.
      };
      ccv_array_push(tensor_ver->ref_version, &tensor_ref);
    }
    // If there is more than one tensor in the list, it is possible to sum them up.
    if (tensor_ver->c < tensor_ver->ref_version->rnum - 1)
      _ccv_nnc_graph_sum_autograd_tensor_versions(-1, ref_d, exec_symbol_info_size, tensor_symbol_info, tensor_ver, autograd_execs, autograd_tensor_symbols, sum_or_set_execs);
    // The tensor version should have ref_version, and only one now (after summing up).
    assert(tensor_ver->c == tensor_ver->ref_version->rnum - 1);
  }
  // Add the additional fields to backward_prep now.
  backward_prep->autograd_execs = autograd_execs;
  backward_prep->autograd_tensor_versions = autograd_tensor_versions;
  backward_prep->autograd_tensor_symbols = autograd_tensor_symbols;
  backward_prep->sum_or_set_execs = sum_or_set_execs;
  ccv_array_t* sub_f_symbols = 0;
  ccv_array_t* sub_wrt_symbols = 0;
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, _, idx) {
    ccv_nnc_graph_backward_info_t* node = backward_info + idx;
    const ccv_nnc_graph_exec_symbol_info_t* forw_exec = exec_symbol_info + idx;
    /* Only interested in the ones on the f / wrt flow */
    if ((node->f_wrt & 0x3) == 0x3)
    {
      const int is_while = (forw_exec->flags & CCV_NNC_GRAPH_EXEC_P_WHILE);
      for (i = 0; i < forw_exec->graph_ref_size; i++)
      {
        // Now call it recursively for the sub-graph.
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[i] - 1;
        ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep = backward_prep->sub_preps + graph_ref;
        if (!sub_wrt_symbols)
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
        if (!sub_f_symbols)
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
        _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, node->input_bitmasks, node->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
        _ccv_nnc_symbolic_graph_backward_prep_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, is_while, ccv_nnc_symbolic_graph_sources(sub_prep->graph), ccv_nnc_symbolic_graph_source_size(sub_prep->graph), ccv_nnc_symbolic_graph_destinations(sub_prep->graph), ccv_nnc_symbolic_graph_destination_size(sub_prep->graph));
      }
    }
  } ccv_nnc_graph_visit_endfor
  if (sub_f_symbols)
    ccv_array_free(sub_f_symbols);
  if (sub_wrt_symbols)
    ccv_array_free(sub_wrt_symbols);
}
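// Editorial note (not in the original source): the generation pass above derives each
// backward command by incrementing the forward command identifier
// (back_exec->cmd.cmd += 1), i.e. it assumes backward opcodes are registered
// immediately after their forward counterparts, e.g., hypothetically:
//   assert(CCV_NNC_EWSUM_BACKWARD == CCV_NNC_EWSUM_FORWARD + 1);
// The ccv_nnc_cmd_is_backward() assert guards that assumption.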

static void _ccv_nnc_symbolic_graph_backward_prep_free(const ccv_nnc_symbolic_graph_backward_prep_t backward_prep)
{
  int i, j;
  const int exec_symbol_info_size = backward_prep.exec_symbol_info_size;
  const int tensor_symbol_info_size = backward_prep.tensor_symbol_info_size;
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep.autograd_execs;
  if (autograd_execs)
  {
    for (i = 0; i < exec_symbol_info_size; i++)
    {
      if (autograd_execs[i].inputs)
        ccfree(autograd_execs[i].inputs);
      if (autograd_execs[i].outgoings)
        ccv_array_free(autograd_execs[i].outgoings);
    }
    ccfree(autograd_execs);
  }
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = backward_prep.autograd_tensor_versions;
  if (autograd_tensor_versions)
  {
    for (i = 0; i < tensor_symbol_info_size; i++)
    {
      if (autograd_tensor_versions[i].ref_version)
      {
        for (j = 0; j < autograd_tensor_versions[i].ref_version->rnum; j++)
        {
          ccv_nnc_tensor_ref_t* ref_version = (ccv_nnc_tensor_ref_t*)ccv_array_get(autograd_tensor_versions[i].ref_version, j);
          if (ref_version->exec_registry)
            ccv_array_free(ref_version->exec_registry);
          if (ref_version->alias_registry)
            ccv_array_free(ref_version->alias_registry);
        }
        ccv_array_free(autograd_tensor_versions[i].ref_version);
      }
    }
    ccfree(autograd_tensor_versions);
  }
  if (backward_prep.autograd_tensor_symbols)
    ccv_array_free(backward_prep.autograd_tensor_symbols);
  ccv_array_t* const sum_or_set_execs = backward_prep.sum_or_set_execs;
  if (sum_or_set_execs)
  {
    for (i = 0; i < sum_or_set_execs->rnum; i++)
    {
      ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
      if (sum_or_set->inputs)
        ccfree(sum_or_set->inputs);
      if (sum_or_set->outgoings)
        ccv_array_free(sum_or_set->outgoings);
    }
    ccv_array_free(sum_or_set_execs);
  }
  // From here on, these are mandatory.
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep.backward_info;
  for (i = 0; i < exec_symbol_info_size; i++)
  {
    if (backward_info[i].outgoings)
      ccv_array_free(backward_info[i].outgoings);
    if (backward_info[i].input_bitmasks)
      ccfree(backward_info[i].input_bitmasks);
  }
  ccfree(backward_info);
  ccv_nnc_graph_visit_free(backward_prep.backward_visit);
  ccv_nnc_graph_visit_free(backward_prep.forward_visit);
  ccfree(backward_prep.exec_symbol_info);
  ccfree(backward_prep.tensor_symbol_info);
  for (i = 0; i < backward_prep.sub_prep_size; i++)
    _ccv_nnc_symbolic_graph_backward_prep_free(backward_prep.sub_preps[i]);
  if (backward_prep.sub_preps)
    ccfree(backward_prep.sub_preps);
}
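// Editorial note (not in the original source): autograd_execs[i].outputs is intentionally
// not freed above; it points into the same ccmalloc block as autograd_execs[i].inputs
// (outputs = inputs + input_size), so freeing inputs releases both.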

static void _ccv_nnc_add_backward_breakpoint_for_symbol(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_graph_exec_symbol_t breakpoint, ccv_nnc_symbolic_graph_t* const graph, ccv_array_t* const sub_breakpoints)
{
  const ccv_nnc_graph_exec_symbol_t noop = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(CCV_NNC_NOOP, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, 0);
  ccv_array_push(sub_breakpoints, &noop);
  // Now need to hook this up to the graph.
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
  const ccv_nnc_graph_visit_t* const forward_visit = backward_prep->forward_visit;
  // Now, for each one of these, find a reverse graph.
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
  int i;
  // Clean up the high bit.
  for (i = 0; i < backward_prep->exec_symbol_info_size; i++)
    backward_info[i].f_wrt &= ~0x4;
  assert((backward_info[breakpoint.d].f_wrt & 0x3) != 0x3);
  backward_info[breakpoint.d].f_wrt |= 0x4;
  const ccv_nnc_graph_visit_t* const backward_visit = backward_prep->backward_visit;
  const ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep->autograd_execs;
  // Go forward to find whether this breakpoint is a source node to some f_wrt nodes.
  ccv_nnc_graph_visit_for(forward_visit, exec_symbol_info, forw_exec, idx) {
    ccv_nnc_graph_backward_info_t* const node = backward_info + idx;
    // If it is tagged on the breakpoint flow, but not as both f and wrt, flow through it.
    if ((node->f_wrt & 0x4) && (node->f_wrt & 0x3) != 0x3)
      for (i = 0; forw_exec->outgoings && i < forw_exec->outgoings->rnum; i++)
      {
        const int outgoing_idx = *(int*)ccv_array_get(forw_exec->outgoings, i);
        ccv_nnc_graph_backward_info_t* const outgoing_node = backward_info + outgoing_idx;
        // If this is an f_wrt node, concatenate.
        if (!(outgoing_node->f_wrt & 0x4) && (outgoing_node->f_wrt & 0x3) == 0x3)
          ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[outgoing_idx].symbol, noop);
        outgoing_node->f_wrt |= 0x4;
      }
  } ccv_nnc_graph_visit_endfor
  // Go backward to find whether this breakpoint is a destination node for some f_wrt nodes.
  ccv_nnc_graph_visit_for(backward_visit, backward_info, node, idx) {
    if ((node->f_wrt & 0x4) && (node->f_wrt & 0x3) != 0x3)
      for (i = 0; node->outgoings && i < node->outgoings->rnum; i++)
      {
        const int outgoing_idx = *(int*)ccv_array_get(node->outgoings, i);
        ccv_nnc_graph_backward_info_t* const outgoing_node = backward_info + outgoing_idx;
        // If this is an f_wrt node, concatenate.
        if (!(outgoing_node->f_wrt & 0x4) && (outgoing_node->f_wrt & 0x3) == 0x3)
          ccv_nnc_graph_exec_symbol_concat(graph, noop, autograd_execs[outgoing_idx].symbol);
        outgoing_node->f_wrt |= 0x4;
      }
  } ccv_nnc_graph_visit_endfor
}
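// Editorial note (not in the original source): f_wrt appears to be a small bit set, with
// the low two bits tracking reachability from the f symbols and to the wrt symbols
// (hence the (f_wrt & 0x3) == 0x3 checks elsewhere), while the function above borrows
// bit 2 (0x4) as a scratch "visited on the breakpoint flow" marker, clearing it on entry.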

static ccv_nnc_autograd_tensor_symbol_t* _ccv_nnc_autograd_tensor_symbol_from_tensor_version(ccv_array_t* const autograd_tensor_symbols, const ccv_nnc_autograd_tensor_version_t* const tensor_ver)
{
  assert(tensor_ver->ref_version);
  const ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
  return (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
}

static void _ccv_nnc_symbolic_graph_set_backward_carry_overs(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph)
{
  int i;
  for (i = 0; i < backward_prep->graph->tensor_symbol_info->rnum; i++)
  {
    const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info + i;
    if (tensor_symbol_info->assign_ref)
    {
      const int assign_ref = tensor_symbol_info->assign_ref - 1;
      ccv_nnc_autograd_tensor_symbol_t* const destination_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + assign_ref);
      ccv_nnc_autograd_tensor_symbol_t* const source_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + i);
      ccv_nnc_symbolic_graph_set_carry_overs(graph, (ccv_nnc_tensor_symbol_map_t []){
        { .source = source_autograd_symbol->symbol, .destination = destination_autograd_symbol->symbol }
      }, 1);
    }
  }
  for (i = 0; i < wrt_symbol_size; i++)
  {
    const int d = wrt_symbols[i].d;
    if (d < 0)
      continue;
    const int ref_d = (!backward_prep->tensor_symbol_info[d].alias_ref) ? d : backward_prep->tensor_symbol_info[d].alias_ref - 1;
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = backward_prep->autograd_tensor_versions + ref_d;
    const int init_ref_ver = _ccv_nnc_tensor_ref_version_find_init(tensor_ver);
    if (init_ref_ver >= 0)
    {
      const int init_d = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, init_ref_ver))->d;
      ccv_nnc_autograd_tensor_symbol_t* const destination_autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(backward_prep->autograd_tensor_symbols, init_d);
      ccv_nnc_autograd_tensor_symbol_t* const source_autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(backward_prep->autograd_tensor_symbols, backward_prep->autograd_tensor_versions + ref_d);
      ccv_nnc_symbolic_graph_set_carry_overs(graph, (ccv_nnc_tensor_symbol_map_t []){
        { .source = source_autograd_symbol->symbol, .destination = destination_autograd_symbol->symbol }
      }, 1);
    }
  }
}
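// Editorial note (not in the original source): in the forward graph, assign_ref marks a
// loop-carried variable of a while loop. The function above wires the corresponding
// backward symbols as carry-overs in the reverse direction, and also carries each wrt
// gradient back onto its initialization tensor so it accumulates across iterations.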

static void _ccv_nnc_symbolic_graph_add_init_zeros(const ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const sub_graph, ccv_array_t* const symbols)
{
  int i;
  for (i = 0; i < wrt_symbol_size; i++)
  {
    const int d = wrt_symbols[i].d;
    if (d < 0)
      continue;
    const int ref_d = (!sub_prep->tensor_symbol_info[d].alias_ref) ? d : sub_prep->tensor_symbol_info[d].alias_ref - 1;
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = sub_prep->autograd_tensor_versions + ref_d;
    const int init_ref_ver = _ccv_nnc_tensor_ref_version_find_init(tensor_ver);
    if (init_ref_ver >= 0)
    {
      // Need de-dup logic.
      const int init_d = ((ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, init_ref_ver))->d;
      ccv_nnc_autograd_tensor_symbol_t* const init_autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(sub_prep->autograd_tensor_symbols, init_d);
      const ccv_nnc_tensor_symbol_info_t* const sub_init_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, init_autograd_symbol->symbol.d);
      // If it doesn't have a parent ref yet, create one.
      if (!sub_init_symbol_info->p_ref)
      {
        ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, sub_prep->tensor_symbol_info[ref_d].info, 0);
        ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS);
        ccv_array_push(symbols, &new_symbol);
        ccv_nnc_tensor_symbol_hookup(graph, sub_graph, new_symbol, init_autograd_symbol->symbol);
      }
    }
  }
}
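// Editorial note (not in the original source): the helper above hoists a zero-initialized
// accumulator into the parent graph; the CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS symbol created
// on the parent is hooked up to the sub-graph's initialization symbol, so the first loop
// iteration sums against zeros rather than uninitialized memory.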

static void _ccv_nnc_symbolic_graph_add_tape_vars(const ccv_nnc_symbolic_graph_backward_prep_t* const sub_prep, ccv_nnc_symbolic_graph_t* const root, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const sub_graph, ccv_array_t* const symbols)
{
  int i;
  for (i = 0; i < sub_graph->tensor_symbol_info->rnum; i++)
  {
    const ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, i);
    if ((symbol_info->flags & CCV_NNC_TENSOR_SYMBOL_TAPE_VAR) && symbol_info->pair_ref)
    {
      const int pair_ref = symbol_info->pair_ref - 1;
      const ccv_nnc_tensor_symbol_t root_symbol = ccv_nnc_tensor_symbol_resolve(root, (ccv_nnc_tensor_symbol_t){
        .d = pair_ref,
        .graph = sub_prep->graph,
      });
      if (root_symbol.d >= 0)
      {
        ccv_nnc_tensor_symbol_hookup(root, sub_graph, root_symbol, (ccv_nnc_tensor_symbol_t){
          .d = i,
          .graph = sub_graph,
        });
        if (symbols)
        {
          const ccv_nnc_tensor_symbol_t p_symbol = ccv_nnc_tensor_symbol_resolve(graph, (ccv_nnc_tensor_symbol_t){
            .d = i,
            .graph = sub_graph,
          });
          ccv_array_push(symbols, &p_symbol);
        }
      }
    }
  }
}
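// Editorial note (not in the original source): tape variables recorded on the forward
// pass are matched through pair_ref and hooked up across the graph hierarchy (root,
// graph, sub_graph), so the backward sub-graph can read the values that the forward
// while / case..of body produced on each iteration.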

static void _ccv_nnc_symbolic_graph_backward_gen(const ccv_nnc_symbolic_graph_backward_prep_t* const backward_prep, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const root)
{
  assert(graph == backward_prep->graph || graph->pair == backward_prep->graph);
  const int exec_symbol_info_size = backward_prep->exec_symbol_info_size;
  const int tensor_symbol_info_size = backward_prep->tensor_symbol_info_size;
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = backward_prep->exec_symbol_info;
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = backward_prep->tensor_symbol_info;
  int i, j, k, p;
  ccv_array_t* const autograd_tensor_symbols = backward_prep->autograd_tensor_symbols;
  // Generate required symbols based on the information gathered above.
  for (i = 0; i < autograd_tensor_symbols->rnum; i++)
  {
    ccv_nnc_autograd_tensor_symbol_t* symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, i);
    assert(symbol->d >= 0);
    assert(symbol->d < tensor_symbol_info_size);
    const ccv_nnc_tensor_symbol_info_t* const forw_symbol = tensor_symbol_info + symbol->d;
    if (!symbol->alias_ref)
    {
      assert(!forw_symbol->alias_ref);
      symbol->symbol = ccv_nnc_tensor_symbol_new(graph, forw_symbol->info, 0);
      ccv_nnc_tensor_symbol_set_flags(graph, symbol->symbol, symbol->flags);
    } else {
      assert(forw_symbol->alias_ref);
      assert(symbol->flags == 0); // We don't set flags on aliases.
      // Due to our generation order, this must be after the original symbol is created.
      ccv_nnc_autograd_tensor_symbol_t* ref = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, symbol->alias_ref - 1);
      symbol->symbol = ccv_nnc_tensor_symbol_alias_new(graph, ref->symbol, forw_symbol->ofs, forw_symbol->inc, forw_symbol->info, 0);
    }
  }
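  /* Editorial note (not in the original source): the loop above relies on the generation
   * order guaranteed earlier: an alias entry is always appended to autograd_tensor_symbols
   * after the symbol it references, so ref->symbol already exists when the alias is
   * materialized via ccv_nnc_tensor_symbol_alias_new. */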
1596
6.72k
  ccv_nnc_graph_backward_info_t* const backward_info = backward_prep->backward_info;
1597
6.72k
  ccv_nnc_autograd_graph_exec_symbol_t* const autograd_execs = backward_prep->autograd_execs;
1598
6.72k
  ccv_array_t* symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1599
6.72k
  ccv_array_t* symbol_map = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_map_t), 0, 0);
1600
6.72k
  ccv_array_t* sub_f_symbols = 0;
1601
6.72k
  ccv_array_t* sub_wrt_symbols = 0;
1602
6.72k
  ccv_array_t* sub_execs = 0;
1603
24.8k
  for (i = 0; i < exec_symbol_info_size; 
i++18.1k
)
1604
18.1k
  {
1605
    // This is not going to be an interesting node. Skip.
1606
18.1k
    if ((backward_info[i].f_wrt & 0x3) != 0x3)
1607
76
      continue;
1608
18.0k
    ccv_nnc_graph_backward_info_t* const back_info = backward_info + i;
1609
18.0k
    ccv_nnc_autograd_graph_exec_symbol_t* const back_exec = autograd_execs + i;
1610
18.0k
    if (back_exec->cmd.cmd == CCV_NNC_NOOP)
1611
1
    {
1612
1
      back_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, back_exec->cmd, 0, 0, 0, 0, 0);
1613
1
      continue;
1614
1
    }
1615
18.0k
    const ccv_nnc_graph_exec_symbol_info_t* const forw_exec = exec_symbol_info + i;
1616
18.0k
    if (forw_exec->flags & CCV_NNC_GRAPH_EXEC_P_WHILE)
1617
1
    {
1618
1
      ccv_array_clear(symbols);
1619
1
      const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[0] - 1;
1620
1
      ccv_nnc_symbolic_graph_backward_prep_t* sub_prep = backward_prep->sub_preps + graph_ref;
1621
1
      ccv_nnc_symbolic_graph_t* sub_graph = ccv_nnc_symbolic_graph_new();
1622
1
      sub_graph->pair = sub_prep->graph;
1623
1
      if (!sub_wrt_symbols)
1624
1
        sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1625
      // I am done, need to redo above for sub_prep, and it has to be successful now.
1626
1
      if (!sub_f_symbols)
1627
1
        sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1628
1
      _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, back_info->input_bitmasks, back_info->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
1629
1
      _ccv_nnc_symbolic_graph_backward_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph, root);
1630
1
      back_exec->symbol = ccv_nnc_symbolic_graph_while(graph, back_exec->cmd.cmd, sub_graph, forw_exec->name);
1631
1
      if (!sub_execs)
1632
1
        sub_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
1633
1
      ccv_array_clear(sub_execs);
1634
      // Find the breakpoints in forward graph, creating the reverse one.
1635
2
      for (j = 0; j < sub_prep->graph->breakpoint_size; 
j++1
)
1636
1
      {
1637
1
        const int d = sub_prep->graph->breakpoints[j].d;
1638
1
        if (sub_prep->autograd_execs[d].symbol.graph)
1639
0
          ccv_array_push(sub_execs, &sub_prep->autograd_execs[d].symbol);
1640
1
        else
1641
1
          _ccv_nnc_add_backward_breakpoint_for_symbol(sub_prep, sub_prep->graph->breakpoints[j], sub_graph, sub_execs);
1642
1
      }
1643
1
      ccv_nnc_symbolic_graph_set_while_expr(sub_graph, NOOP_GRAPH_WHILE_EXPR, 0, 0, 0, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(sub_execs, 0), sub_execs->rnum);
1644
1
      ccv_nnc_graph_exec_symbol_autogen(sub_graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
1645
1
      _ccv_nnc_symbolic_graph_set_backward_carry_overs(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph);
1646
2
      for (j = 0; j < back_exec->input_size; 
j++1
)
1647
1
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1648
1
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol));
1649
      // Find whether in the wrt symbols, anything we need to init to zero, if there are, these need to be inputs here too.
1650
1
      _ccv_nnc_symbolic_graph_add_init_zeros(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, graph, sub_graph, symbols);
1651
1
      _ccv_nnc_symbolic_graph_add_tape_vars(sub_prep, root, graph, sub_graph, symbols);
1652
      // input_size at this point, may be different from the back_exec->input_size, the reason is because we may added zeroing tensors as input tensors.
1653
1
      const int input_size = symbols->rnum;
1654
3
      for (j = 0; j < back_exec->output_size; 
j++2
)
1655
2
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1656
1
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol));
1657
1
      const int output_size = symbols->rnum - input_size;
1658
1
      const int p_idx = sub_prep->graph->p_idx - 1;
1659
1
      assert(back_exec->input_size == forw_exec->output_size);
1660
1
      k = 0;
1661
2
      for (j = 0; j < back_exec->input_size; 
j++1
)
1662
1
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1663
1
        {
1664
1
          const ccv_nnc_tensor_symbol_info_t* const info = tensor_symbol_info + forw_exec->outputs[j];
1665
1
          const int s_idx = *(int*)ccv_array_get(info->s_ref, p_idx) - 1;
1666
1
          assert(s_idx >= 0);
1667
1
          const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + s_idx);
1668
1
          ccv_nnc_tensor_symbol_hookup(graph, sub_graph, *(ccv_nnc_tensor_symbol_t*)ccv_array_get(symbols, k), autograd_symbol->symbol);
1669
1
          ++k;
1670
1
        }
1671
1
      k = input_size; // Reset k, the symbol pass already set up by add_init_zeros.
1672
1
      assert(back_exec->output_size == forw_exec->input_size);
1673
3
      
for (j = 0; 1
j < back_exec->output_size;
j++2
)
1674
2
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1675
1
        {
1676
1
          const ccv_nnc_tensor_symbol_info_t* const info = tensor_symbol_info + forw_exec->inputs[j];
1677
1
          const int s_idx = *(int*)ccv_array_get(info->s_ref, p_idx) - 1;
1678
1
          assert(s_idx >= 0);
1679
1
          const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + s_idx);
1680
1
          ccv_nnc_tensor_symbol_hookup(graph, sub_graph, *(ccv_nnc_tensor_symbol_t*)ccv_array_get(symbols, k), autograd_symbol->symbol);
1681
1
          ++k;
1682
1
        }
1683
1
      ccv_nnc_graph_exec_symbol_set_io(graph, back_exec->symbol, ccv_array_get(symbols, 0), input_size, ccv_array_get(symbols, input_size), output_size);
1684
18.0k
    } else if (forw_exec->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) {
1685
1
      ccv_array_clear(symbol_map);
1686
2
      for (j = 0; j < back_exec->output_size; 
j++1
)
1687
1
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1688
1
        {
1689
1
          ccv_nnc_tensor_symbol_map_t symbol = {
1690
1
            .source = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol,
1691
1
            .destination = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol,
1692
1
          };
1693
1
          ccv_array_push(symbol_map, &symbol);
1694
1
        }
1695
1
      const int symbol_map_size = symbol_map->rnum;
1696
1
      back_exec->symbol = ccv_nnc_symbolic_graph_case_of_new(graph, back_exec->cmd.cmd, 0, 0, ccv_array_get(symbol_map, 0), symbol_map_size, forw_exec->name);
1697
1
      ccv_nnc_symbolic_graph_set_case_of_expr(graph, back_exec->symbol, NOOP_GRAPH_CASE_OF_EXPR, 0);
1698
4
      for (p = 0; p < forw_exec->graph_ref_size; 
p++3
)
1699
3
      {
1700
3
        const int graph_ref = CCV_NNC_GRAPH_REF(forw_exec)[p] - 1;
1701
3
        ccv_nnc_symbolic_graph_backward_prep_t* sub_prep = backward_prep->sub_preps + graph_ref;
1702
3
        ccv_nnc_symbolic_graph_t* sub_graph = ccv_nnc_symbolic_graph_new();
1703
3
        sub_graph->pair = sub_prep->graph;
1704
3
        if (!sub_wrt_symbols)
1705
1
          sub_wrt_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1706
        // I am done, need to redo above for sub_prep, and it has to be successful now.
1707
3
        if (!sub_f_symbols)
1708
1
          sub_f_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
1709
3
        _ccv_nnc_symbolic_graph_backward_prep_sub_f_wrt_symbols(forw_exec, sub_prep->graph, graph_ref, tensor_symbol_info, back_info->input_bitmasks, back_info->output_bitmasks, sub_f_symbols, sub_wrt_symbols);
1710
3
        _ccv_nnc_symbolic_graph_backward_gen(sub_prep, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, 0), sub_f_symbols->rnum, (ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, 0), sub_wrt_symbols->rnum, sub_graph, root);
1711
3
        ccv_array_clear(symbol_map);
1712
3
        k = 0;
1713
6
        for (j = 0; j < back_exec->output_size; 
j++3
)
1714
3
          if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1715
3
          {
1716
3
            const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_wrt_symbols, k))->d;
1717
3
            if (d >= 0)
1718
1
            {
1719
1
              const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + d);
1720
1
              ccv_nnc_tensor_symbol_map_t symbol = {
1721
1
                .source = autograd_symbol->symbol,
1722
1
                .destination = ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol,
1723
1
              };
1724
1
              ccv_array_push(symbol_map, &symbol);
1725
2
            } else {
1726
              // Create a new tensor in sub-graph and set it to be 0.
1727
2
              const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]);
1728
              // autograd_symbol->d points to the corresponding forward tensor.
1729
2
              ccv_nnc_tensor_symbol_t zero_symbol = ccv_nnc_tensor_symbol_new(sub_graph, tensor_symbol_info[autograd_symbol->d].info, 0);
1730
2
              ccv_nnc_graph_exec_symbol_new(sub_graph, CMD_SET_FORWARD(0), 0, 0, &zero_symbol, 1, 0);
1731
2
              ccv_nnc_tensor_symbol_map_t symbol = {
1732
2
                .source = zero_symbol,
1733
2
                .destination = autograd_symbol->symbol,
1734
2
              };
1735
2
              ccv_array_push(symbol_map, &symbol);
1736
2
            }
1737
3
            ++k;
1738
3
          }
1739
3
        ccv_nnc_graph_exec_symbol_autogen(sub_graph, 0, 0, CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
1740
3
        const int symbol_map_size = symbol_map->rnum;
1741
3
        ccv_nnc_symbolic_graph_set_case_of(graph, back_exec->symbol, sub_graph, p, ccv_array_get(symbol_map, 0), symbol_map_size);
1742
        // Hookup input only after this becomes a sub graph of the graph.
1743
3
        k = 0;
1744
6
        for (j = 0; j < back_exec->input_size; 
j++3
)
1745
3
          if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1746
3
          {
1747
3
            const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(sub_f_symbols, k))->d;
1748
3
            assert(d >= 0);
1749
            // No corresponding sub tensors allocated. Skip.
1750
3
            if (!sub_prep->autograd_tensor_versions[d].ref_version ||
1751
3
              
!sub_prep->autograd_tensor_versions[d].ref_version->rnum1
)
1752
2
              continue;
1753
1
            const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = _ccv_nnc_autograd_tensor_symbol_from_tensor_version(sub_prep->autograd_tensor_symbols, sub_prep->autograd_tensor_versions + d);
1754
1
            ccv_nnc_tensor_symbol_hookup(graph, sub_graph, ((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol, autograd_symbol->symbol);
1755
1
            ++k;
1756
1
          }
1757
        // Need to make sure tape vars are hooked up.
1758
3
        _ccv_nnc_symbolic_graph_add_tape_vars(sub_prep, root, graph, sub_graph, 0);
1759
3
      }
1760
18.0k
    } else {
1761
18.0k
      ccv_array_clear(symbols);
1762
      // Gradient inputs.
1763
36.4k
      for (j = 0; j < back_exec->input_size; 
j++18.4k
)
1764
18.4k
        if (back_info->input_bitmasks[j >> 6] & ((uint64_t)1 << j))
1765
18.0k
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->inputs[j]))->symbol));
1766
397
        else
1767
397
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1768
      // Inputs from forward function.
1769
51.6k
      for (j = 0; j < forw_exec->input_size; 
j++33.6k
)
1770
33.6k
        if (!(back_info->input_bitmasks[(j + back_exec->input_size) >> 6] & ((uint64_t)1 << (j + back_exec->input_size))))
1771
11.4k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1772
22.2k
        else {
1773
22.2k
          const ccv_nnc_tensor_symbol_t symbol = {
1774
22.2k
            .d = forw_exec->inputs[j],
1775
22.2k
            .graph = backward_prep->graph
1776
22.2k
          };
1777
22.2k
          if (graph == backward_prep->graph)
1778
22.2k
            ccv_array_push(symbols, &symbol);
1779
5
          else { // Otherwise, create a new symbol, and set its pair to the old symbol.
1780
5
            const ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, tensor_symbol_info[forw_exec->inputs[j]].info, tensor_symbol_info[forw_exec->inputs[j]].name);
1781
5
            ccv_nnc_tensor_symbol_pair_with(graph, new_symbol, symbol);
1782
5
            const int flags = ccv_nnc_tensor_symbol_flags(backward_prep->graph, symbol) | CCV_NNC_TENSOR_SYMBOL_TAPE_VAR;
1783
5
            ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, flags);
1784
5
            ccv_nnc_tensor_symbol_set_flags(backward_prep->graph, symbol, flags);
1785
5
            ccv_array_push(symbols, &new_symbol);
1786
5
          }
1787
22.2k
        }
1788
      // Outputs from forward function.
1789
36.4k
      for (j = 0; j < forw_exec->output_size; 
j++18.4k
)
1790
18.4k
        if (!(back_info->input_bitmasks[(j + back_exec->input_size + forw_exec->input_size) >> 6] & ((uint64_t)1 << (j + back_exec->input_size + forw_exec->input_size))))
1791
13.3k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1792
5.09k
        else {
1793
5.09k
          const ccv_nnc_tensor_symbol_t symbol = {
1794
5.09k
            .d = forw_exec->outputs[j],
1795
5.09k
            .graph = backward_prep->graph
1796
5.09k
          };
1797
5.09k
          if (graph == backward_prep->graph)
1798
5.09k
            ccv_array_push(symbols, &symbol);
1799
2
          else { // Otherwise, create a new symbol, and set its pair to the old symbol.
1800
2
            const ccv_nnc_tensor_symbol_t new_symbol = ccv_nnc_tensor_symbol_new(graph, tensor_symbol_info[forw_exec->outputs[j]].info, tensor_symbol_info[forw_exec->outputs[j]].name);
1801
2
            ccv_nnc_tensor_symbol_pair_with(graph, new_symbol, symbol);
1802
2
            const int flags = ccv_nnc_tensor_symbol_flags(backward_prep->graph, symbol) | CCV_NNC_TENSOR_SYMBOL_TAPE_VAR;
1803
2
            ccv_nnc_tensor_symbol_set_flags(graph, new_symbol, flags);
1804
2
            ccv_nnc_tensor_symbol_set_flags(backward_prep->graph, symbol, flags);
1805
2
            ccv_array_push(symbols, &new_symbol);
1806
2
          }
1807
5.09k
        }
1808
51.6k
      for (j = 0; j < back_exec->output_size; j++)
1809
33.6k
        if (back_info->output_bitmasks[j >> 6] & ((uint64_t)1 << j))
1810
27.1k
          ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, back_exec->outputs[j]))->symbol));
1811
6.56k
        else
1812
6.56k
          ccv_array_push(symbols, &NO_TENSOR_SYMBOL);
1813
18.0k
      back_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, back_exec->cmd, ccv_array_get(symbols, 0), back_exec->input_size + forw_exec->input_size + forw_exec->output_size, ccv_array_get(symbols, back_exec->input_size + forw_exec->input_size + forw_exec->output_size), back_exec->output_size, 0);
1814
18.0k
      ccv_nnc_graph_exec_symbol_set_hint(graph, back_exec->symbol, exec_symbol_info[i].hint);
1815
18.0k
      ccv_nnc_graph_exec_symbol_pair_with(graph, back_exec->symbol, (ccv_nnc_graph_exec_symbol_t){
1816
18.0k
        .d = i,
1817
18.0k
        .graph = backward_prep->graph,
1818
18.0k
      });
1819
18.0k
    }
1820
18.0k
  }
1821
6.72k
  if (sub_f_symbols)
1822
2
    ccv_array_free(sub_f_symbols);
1823
6.72k
  if (sub_wrt_symbols)
1824
2
    ccv_array_free(sub_wrt_symbols);
1825
6.72k
  if (sub_execs)
1826
1
    ccv_array_free(sub_execs);
1827
6.72k
  ccv_array_t* const sum_or_set_execs = backward_prep->sum_or_set_execs;
1828
11.0k
  for (i = 0; i < sum_or_set_execs->rnum; i++)
1829
4.27k
  {
1830
4.27k
    ccv_nnc_sum_or_set_graph_exec_symbol_t* sum_or_set_exec = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
1831
    // It is a sum; a set exec has no inputs.
1832
4.27k
    if (sum_or_set_exec->input_size)
1833
4.27k
    {
1834
4.27k
      ccv_array_clear(symbols);
1835
      // These are the inputs to sum over.
1836
12.8k
      for (j = 0; j < sum_or_set_exec->input_size; j++)
1837
8.57k
        ccv_array_push(symbols, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->inputs[j]))->symbol));
1838
4.27k
      ccv_nnc_cmd_t cmd = ccv_nnc_cmd(CCV_NNC_EWSUM_FORWARD, 0, CMD_GENERIC(), 0);
1839
4.27k
      sum_or_set_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, cmd, ccv_array_get(symbols, 0), sum_or_set_exec->input_size, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->output))->symbol), 1, 0);
1840
4.27k
    } else
1841
1
      sum_or_set_exec->symbol = ccv_nnc_graph_exec_symbol_new(graph, CMD_SET_FORWARD(sum_or_set_exec->value), 0, 0, &(((ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, sum_or_set_exec->output))->symbol), 1, 0);
1842
4.27k
  }
1843
6.72k
  ccv_array_free(symbol_map);
1844
6.72k
  ccv_array_free(symbols);
1845
24.8k
  for (i = 0; i < exec_symbol_info_size; i++)
1846
18.1k
  {
1847
    // This is not going to be an interesting node. Skip.
1848
18.1k
    if ((backward_info[i].f_wrt & 0x3) != 0x3)
1849
76
      continue;
1850
18.0k
    ccv_nnc_autograd_graph_exec_symbol_t* const back_exec = autograd_execs + i;
1851
    // If on the same graph, we cannot decide whether it comes before or after the forw_exec, so enforce that it comes after forw_exec.
1852
18.0k
    if (graph == backward_prep->graph)
1853
18.0k
      ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){
1854
18.0k
        .d = i,
1855
18.0k
        .graph = graph
1856
18.0k
      }, back_exec->symbol);
1857
18.0k
    if (back_exec->outgoings)
1858
22.7k
      for (j = 0; j < back_exec->outgoings->rnum; j++)
1859
11.4k
      {
1860
11.4k
        int d = *(int*)ccv_array_get(back_exec->outgoings, j);
1861
11.4k
        if (d < exec_symbol_info_size)
1862
7.05k
          ccv_nnc_graph_exec_symbol_concat(graph, back_exec->symbol, autograd_execs[d].symbol);
1863
4.34k
        else
1864
4.34k
          ccv_nnc_graph_exec_symbol_concat(graph, back_exec->symbol, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, d - exec_symbol_info_size))->symbol);
1865
11.4k
      }
1866
18.0k
  }
1867
11.0k
  for (i = 0; i < sum_or_set_execs->rnum; i++)
1868
4.27k
  {
1869
4.27k
    ccv_nnc_sum_or_set_graph_exec_symbol_t* exec = (ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, i);
1870
4.27k
    if (exec->outgoings)
1871
8.49k
      for (j = 0; j < exec->outgoings->rnum; j++)
1872
4.24k
      {
1873
4.24k
        int d = *(int*)ccv_array_get(exec->outgoings, j);
1874
4.24k
        if (d < exec_symbol_info_size)
1875
4.24k
          ccv_nnc_graph_exec_symbol_concat(graph, exec->symbol, autograd_execs[d].symbol);
1876
0
        else
1877
0
          ccv_nnc_graph_exec_symbol_concat(graph, exec->symbol, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, d - exec_symbol_info_size))->symbol);
1878
4.24k
      }
1879
4.27k
  }
1880
  // Now that everything is done, set the metadata on the graph so that we can look up backward symbols later.
1881
6.72k
  if (graph->backward.tensor_symbol_idx)
1882
4.40k
    graph->backward.tensor_symbol_idx = (int*)ccrealloc(graph->backward.tensor_symbol_idx, sizeof(int) * (graph->tensor_symbol_info->rnum + tensor_symbol_info_size));
1883
2.32k
  else
1884
2.32k
    graph->backward.tensor_symbol_idx = (int*)ccmalloc(sizeof(int) * (graph->tensor_symbol_info->rnum + tensor_symbol_info_size));
1885
6.72k
  graph->backward.tensor_symbol_size = tensor_symbol_info_size;
1886
6.72k
  graph->backward.exec_symbol_idx = graph->backward.tensor_symbol_idx + tensor_symbol_info_size;
1887
6.72k
  graph->backward.exec_symbol_size = graph->tensor_symbol_info->rnum;
1888
45.2k
  for (i = 0; i < tensor_symbol_info_size; i++)
1889
38.5k
    graph->backward.tensor_symbol_idx[i] = -1;
1890
84.5k
  for (i = 0; i < graph->backward.exec_symbol_size; i++)
1891
77.8k
    graph->backward.exec_symbol_idx[i] = -1;
1892
6.72k
  ccv_nnc_autograd_tensor_version_t* const autograd_tensor_versions = backward_prep->autograd_tensor_versions;
1893
  // Assigning for wrt symbols.
1894
16.2k
  for (i = 0; i < wrt_symbol_size; i++)
1895
9.48k
  {
1896
9.48k
    const int d = wrt_symbols[i].d;
1897
9.48k
    if (d < 0)
1898
2
      continue;
1899
9.47k
    assert(d < tensor_symbol_info_size);
1900
9.47k
    const ccv_nnc_tensor_symbol_info_t* const forw_symbol = tensor_symbol_info + d;
1901
9.47k
    ccv_nnc_autograd_tensor_version_t* const tensor_ver = autograd_tensor_versions + ((!forw_symbol->alias_ref) ? d : forw_symbol->alias_ref - 1);
1902
9.47k
    assert(tensor_ver->ref_version);
1903
9.47k
    ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, tensor_ver->c);
1904
9.47k
    ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1905
    // If this wrt symbol is an alias, create an extra alias for it.
1906
9.47k
    if (!forw_symbol->alias_ref)
1907
9.47k
      graph->backward.tensor_symbol_idx[d] = autograd_symbol->symbol.d;
1908
1
    else // We create a new alias; it cannot be referenced from exec_symbol_idx because that array's size is limited to the previous tensor symbol count.
1909
1
      graph->backward.tensor_symbol_idx[d] = ccv_nnc_tensor_symbol_alias_new(graph, autograd_symbol->symbol, forw_symbol->ofs, forw_symbol->inc, forw_symbol->info, 0).d;
1910
9.47k
    const int dd = autograd_symbol->symbol.d;
1911
9.47k
    const int x = tensor_ref->x;
1912
9.47k
    if (tensor_ref->exec_registry && tensor_ref->exec_registry->rnum) // Create no-op node.
1913
2
    {
1914
2
      ccv_nnc_graph_exec_symbol_t noop = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(CCV_NNC_NOOP, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, 0);
1915
2
      if (x < exec_symbol_info_size)
1916
2
        ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[x].symbol, noop);
1917
0
      else
1918
0
        ccv_nnc_graph_exec_symbol_concat(graph, ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size))->symbol, noop);
1919
6
      for (j = 0; j < tensor_ref->exec_registry->rnum; j++)
1920
4
      {
1921
4
        const int x = *(int*)ccv_array_get(tensor_ref->exec_registry, j);
1922
4
        assert(x >= 0); /* Otherwise, this is an initialization tensor, which cannot be an input to the sum. */
1923
4
        assert(x < exec_symbol_info_size); // exec_registry is only used by alias_registry; it simply cannot reference a sum operation.
1924
4
        ccv_nnc_graph_exec_symbol_concat(graph, autograd_execs[x].symbol, noop);
1925
4
      }
1926
2
      graph->backward.exec_symbol_idx[dd] = noop.d;
1927
9.47k
    } else {
1928
9.47k
      if (x < exec_symbol_info_size)
1929
9.45k
        graph->backward.exec_symbol_idx[dd] = autograd_execs[x].symbol.d;
1930
26
      else
1931
26
        graph->backward.exec_symbol_idx[dd] = ((ccv_nnc_sum_or_set_graph_exec_symbol_t*)ccv_array_get(sum_or_set_execs, x - exec_symbol_info_size))->symbol.d;
1932
9.47k
    }
1933
9.47k
  }
1934
  // Assigning for f symbols.
1935
13.4k
  for (i = 0; i < f_symbol_size; i++)
1936
6.74k
  {
1937
6.74k
    const int d = f_symbols[i].d;
1938
6.74k
    assert(d >= 0);
1939
6.74k
    assert(d < tensor_symbol_info_size);
1940
6.74k
    const ccv_nnc_autograd_tensor_version_t* const tensor_ver = autograd_tensor_versions + d;
1941
6.74k
    if (tensor_ver->ref_version)
1942
6.73k
    {
1943
      // We don't use _ccv_nnc_autograd_tensor_symbol_from_tensor_version because that selects the last version; here we need the first version.
1944
6.73k
      const ccv_nnc_tensor_ref_t* const tensor_ref = (ccv_nnc_tensor_ref_t*)ccv_array_get(tensor_ver->ref_version, 0);
1945
6.73k
      const ccv_nnc_autograd_tensor_symbol_t* const autograd_symbol = (ccv_nnc_autograd_tensor_symbol_t*)ccv_array_get(autograd_tensor_symbols, tensor_ref->d);
1946
6.73k
      graph->backward.tensor_symbol_idx[d] = autograd_symbol->symbol.d;
1947
      // We cannot record a single relevant backward exec symbol for f; there could be many.
1948
6.73k
    }
1949
6.74k
  }
1950
6.72k
}
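Throughout the generation pass above, a backward exec's inputs are packed as [gradient inputs][forward inputs][forward outputs], and input_bitmasks/output_bitmasks decide which slot receives a real symbol versus NO_TENSOR_SYMBOL. Below is a minimal, self-contained sketch of that bitmask convention; the helper name and the sizes are illustrative, not part of the library, and the sketch masks the shift with (j & 63), which the inline tests above can omit while slot counts stay below 64.

#include <stdint.h>
#include <stdio.h>

// Illustrative helper (not in the library): slot j of the packed input
// array is live iff bit j of the bitmask is set.
static inline int slot_is_live(const uint64_t* const bitmasks, const int j)
{
  return !!(bitmasks[j >> 6] & ((uint64_t)1 << (j & 63)));
}

int main(void)
{
  // Hypothetical backward exec: 2 gradient inputs, 3 forward inputs, 1 forward output.
  const int back_input_size = 2, forw_input_size = 3;
  uint64_t bitmasks[1] = { 0 };
  bitmasks[0] |= (uint64_t)1 << 0; // Gradient input 0 is used.
  bitmasks[0] |= (uint64_t)1 << (1 + back_input_size); // Forward input 1 is used.
  bitmasks[0] |= (uint64_t)1 << (0 + back_input_size + forw_input_size); // Forward output 0 is used.
  printf("gradient input 1 live? %d\n", slot_is_live(bitmasks, 1)); // 0: gets NO_TENSOR_SYMBOL.
  printf("forward input 1 live? %d\n", slot_is_live(bitmasks, 1 + back_input_size)); // 1: gets a real symbol.
  return 0;
}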
1951
1952
void ccv_nnc_symbolic_graph_backward(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t* const f_symbols, const int f_symbol_size, const ccv_nnc_tensor_symbol_t* const wrt_symbols, const int wrt_symbol_size, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size)
1953
6.72k
{
1954
6.72k
  int i;
1955
  // f symbols cannot be aliases.
1956
13.4k
  for (i = 0; i < f_symbol_size; i++)
1957
6.73k
  {
1958
6.73k
    assert(f_symbols[i].graph == graph); // f symbol has to be in the current graph.
1959
6.73k
    assert(!((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, f_symbols[i].d))->alias_ref);
1960
6.73k
  }
1961
16.2k
  for (i = 0; i < wrt_symbol_size; i++)
1962
9.47k
  {
1963
9.47k
    assert(wrt_symbols[i].graph == graph);
1964
    // Either this is not an alias, or what it refers to is not itself an alias.
1965
9.47k
    assert(!((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, wrt_symbols[i].d))->alias_ref || !((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, ((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, wrt_symbols[i].d))->alias_ref - 1))->alias_ref);
1966
9.47k
  }
1967
6.72k
  const int exec_symbol_info_size = graph->exec_symbol_info->rnum;
1968
6.72k
  const int tensor_symbol_info_size = graph->tensor_symbol_info->rnum;
1969
6.72k
  assert(exec_symbol_info_size > 0);
1970
6.72k
  assert(tensor_symbol_info_size > 0);
1971
6.72k
  ccv_nnc_symbolic_graph_backward_prep_t backward_prep = _ccv_nnc_symbolic_graph_backward_prep(graph, sources, source_size, destinations, destination_size);
1972
6.72k
  _ccv_nnc_symbolic_graph_backward_prep_prune_ops(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, sources, source_size, destinations, destination_size);
1973
6.72k
  _ccv_nnc_symbolic_graph_backward_prep_gen(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, 0, sources, source_size, destinations, destination_size);
1974
6.72k
  _ccv_nnc_symbolic_graph_backward_gen(&backward_prep, f_symbols, f_symbol_size, wrt_symbols, wrt_symbol_size, graph, graph);
1975
6.72k
  _ccv_nnc_symbolic_graph_backward_prep_free(backward_prep);
1976
6.72k
}
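For orientation, a minimal usage sketch of the public entry point above, assuming the convenience macros from ccv_nnc_easy.h (CPU_TENSOR_NHWC, TENSOR_SYMBOL_LIST, SYMBOLIC_GRAPH_SOURCES/DESTINATIONS); the graph contents here are illustrative, not a definitive recipe.

#include "ccv_nnc.h"
#include "ccv_nnc_easy.h"

int main(void)
{
  ccv_nnc_symbolic_graph_t* const graph = ccv_nnc_symbolic_graph_new();
  const ccv_nnc_tensor_symbol_t x = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "x");
  const ccv_nnc_tensor_symbol_t y = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "y");
  const ccv_nnc_tensor_symbol_t f = ccv_nnc_tensor_symbol_new(graph, CPU_TENSOR_NHWC(32F, 1), "f");
  ccv_nnc_graph_exec_symbol_new(graph, CMD_EWPROD_FORWARD(), TENSOR_SYMBOL_LIST(x, y), TENSOR_SYMBOL_LIST(f), "prod");
  ccv_nnc_graph_exec_symbol_autogen(graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
  // Append the backward pass for df/dx to the same graph.
  ccv_nnc_symbolic_graph_backward(graph, TENSOR_SYMBOL_LIST(f), TENSOR_SYMBOL_LIST(x), SYMBOLIC_GRAPH_SOURCES(graph), SYMBOLIC_GRAPH_DESTINATIONS(graph));
  // The metadata set at the end of _ccv_nnc_symbolic_graph_backward_gen makes these lookups work.
  const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(graph, x);
  const ccv_nnc_tensor_symbol_t df = ccv_nnc_tensor_symbol_for_backward(graph, f); // Seed gradient of f.
  (void)dx; (void)df;
  ccv_nnc_symbolic_graph_free(graph);
  return 0;
}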
1977
1978
ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_for_backward(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol)
1979
27.5k
{
1980
27.5k
  assert(symbol.d >= 0);
1981
27.5k
  assert(symbol.d < graph->backward.tensor_symbol_size);
1982
27.5k
  if (graph->backward.tensor_symbol_idx[symbol.d] < 0)
1983
1
    return NO_TENSOR_SYMBOL;
1984
27.5k
  ccv_nnc_tensor_symbol_t tensor = {
1985
27.5k
    .d = graph->backward.tensor_symbol_idx[symbol.d],
1986
27.5k
    .graph = graph,
1987
27.5k
  };
1988
27.5k
  return tensor;
1989
27.5k
}
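A caller-side note: because tensor_symbol_idx entries stay -1 for symbols no gradient flows to, this lookup can return NO_TENSOR_SYMBOL, and callers should check for it. A small hedged sketch (w is a hypothetical forward symbol; the sentinel compares as CCV_NNC_NO_TENSOR_SYMBOL from ccv_nnc.h):

const ccv_nnc_tensor_symbol_t dw = ccv_nnc_tensor_symbol_for_backward(graph, w);
if (dw.d == CCV_NNC_NO_TENSOR_SYMBOL)
  printf("no gradient flows to w\n"); // Don't wire dw into any exec in this case.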
1990
1991
ccv_nnc_graph_exec_symbol_t ccv_nnc_graph_exec_symbol_for_backward(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol)
1992
17.7k
{
1993
17.7k
  assert(symbol.d >= 0);
1994
17.7k
  assert(symbol.d < graph->tensor_symbol_info->rnum);
1995
17.7k
  int dd = symbol.d;
1996
  // Check if this is an alias. Use the original if it is.
1997
17.7k
  ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, dd);
1998
17.7k
  if (symbol_info->alias_ref)
1999
2
    dd = symbol_info->alias_ref - 1;
2000
17.7k
  assert(dd >= 0);
2001
17.7k
  assert(dd < graph->backward.exec_symbol_size);
2002
17.7k
  assert(graph->backward.exec_symbol_idx[dd] >= 0);
2003
17.7k
  ccv_nnc_graph_exec_symbol_t exec = {
2004
17.7k
    .d = graph->backward.exec_symbol_idx[dd],
2005
17.7k
    .graph = graph
2006
17.7k
  };
2007
17.7k
  return exec;
2008
17.7k
}
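Note that this exec lookup takes the gradient tensor symbol (the one returned by ccv_nnc_tensor_symbol_for_backward), resolving aliases to their original symbol first. A short sketch of the typical pairing, with illustrative names:

// For a wrt symbol x, find the gradient tensor and the exec symbol that
// finalizes it, e.g. to use that exec as a destination when running the graph.
const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(graph, x);
const ccv_nnc_graph_exec_symbol_t dx_exec = ccv_nnc_graph_exec_symbol_for_backward(graph, dx);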