Coverage Report

Created: 2025-02-24 17:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Line
Count
Source
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_cnnp_model.h"
6
// This can be removed once we organized ccv_cnnp_apply_gradient_checkpoints better.
7
#include "_ccv_nnc_symbolic_graph.h"
8
9
void ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
10
2.29k
{
11
2.29k
  ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
12
2.29k
  if (!gradient_checkpoints || 
gradient_checkpoints->rnum == 02
) // No saved gradient checkpoints, this is an easy way out.
13
2.29k
    return;
14
2
  int i, j;
15
2
  const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (const ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0);
16
  // Go through to check whether any tensors that are supposed to be in this map have been removed.
17
4
  for (i = 0; i < gradient_checkpoints->rnum; 
i++2
)
18
2
  {
19
2
    ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
20
32
    for (j = 0; j < checkpoint->tensor_symbols->rnum; 
j++30
)
21
30
    {
22
30
      ccv_nnc_tensor_symbol_t* const symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j));
23
30
      if (symbol->d >= 0 && symbol->d < graph->tensor_symbol_info->rnum)
24
        // If it is dead, we need to remove this symbol.
25
30
        if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(tensor_symbol_info[symbol->d].flags))
26
0
        {
27
0
          symbol->d = -1;
28
0
          symbol->graph = 0;
29
0
        }
30
30
    }
31
2
  }
32
2
}
33
34
typedef struct {
35
  ccv_array_t* outgoings;
36
} ccv_nnc_graph_exec_symbol_reverse_t;
37
38
typedef struct {
39
  ccv_cnnp_model_gradient_checkpoint_build_context_t tensor_context;
40
  ccv_array_t* graph_exec_symbols;
41
  ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook;
42
  void* old_graph_exec_symbol_new_hook_context;
43
  ccv_array_t* all_tensor_symbols;
44
} ccv_cnnp_gradient_checkpoint_build_t;
45
46
static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name)
47
18
{
48
18
  ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
49
18
  if (build_context->tensor_context.record)
50
18
    ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol);
51
18
  ccv_array_push(build_context->all_tensor_symbols, &symbol);
52
18
  if (build_context->tensor_context.old_tensor_symbol_new_hook)
53
18
    build_context->tensor_context.old_tensor_symbol_new_hook(build_context->tensor_context.old_tensor_symbol_new_hook_context, symbol, info, name);
54
18
}
55
56
static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name)
57
0
{
58
0
  ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
59
0
  if (build_context->tensor_context.record)
60
0
    ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol);
61
0
  ccv_array_push(build_context->all_tensor_symbols, &symbol);
62
0
  if (build_context->tensor_context.old_tensor_symbol_alias_new_hook)
63
0
    build_context->tensor_context.old_tensor_symbol_alias_new_hook(build_context->tensor_context.old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name);
64
0
}
65
66
static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name)
67
18
{
68
18
  ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context;
69
18
  ccv_array_push(build->graph_exec_symbols, &symbol);
70
18
  if (build->old_graph_exec_symbol_new_hook)
71
18
    build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name);
72
18
}
73
74
KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)
75
KHASH_SET_INIT_INT(ccv_cnnp_tensor_symbol_set)
76
77
ccv_nnc_exec_dep_t _ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_visit_t* const visit, const int exec_rnum, uint32_t* const maskbit)
78
2
{
79
2
  int* chain_ids = ccmalloc(sizeof(int) * exec_rnum * 2);
80
2
  int* chain_pos = chain_ids + exec_rnum;
81
2
  int* buf = (int*)ccmalloc(sizeof(int) * exec_rnum);
82
2
  int* reversed_depth = buf;
83
2
  const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
84
2
  int i, j;
85
  // Go reverse order to generate the distance from sink.
86
16
  ccv_nnc_graph_visit_for_reversed(visit, exec_symbol_info, node, idx, term) {
87
16
    if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
88
0
      continue;
89
16
    chain_ids[idx] = -1;
90
16
    if (!node->outgoings || node->outgoings->rnum == 0)
91
0
    {
92
0
      reversed_depth[idx] = 0;
93
0
      continue;
94
0
    }
95
16
    int depth = -1;
96
39
    for (i = 0; i < node->outgoings->rnum; 
i++23
)
97
23
    {
98
23
      const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
99
23
      if (outgoing >= exec_rnum)
100
0
        continue;
101
23
      depth = ccv_max(depth, reversed_depth[outgoing]);
102
23
    }
103
16
    reversed_depth[idx] = depth + 1;
104
16
  } ccv_nnc_graph_visit_endfor
105
  // Go in order to generate chain ids (if there are multiple exits, we use the reverse depth to break the tie).
106
  // Note that we cannot use depth so-far because then multiple exit nodes are equally good to "inherit" the chain selection.
107
2
  int chain_count = 0;
108
16
  ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
109
16
    if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
110
0
      continue;
111
16
    int chain_id = chain_ids[idx];
112
16
    if (chain_ids[idx] < 0)
113
4
    {
114
4
      chain_id = chain_count;
115
4
      chain_ids[idx] = chain_id;
116
4
      chain_pos[idx] = 1; // The first one in this chain. 1-based index because in sparse matrix, 0 is the default value.
117
4
      chain_count += 1;
118
4
    }
119
16
    if (!node->outgoings || node->outgoings->rnum == 0)
120
0
      continue;
121
16
    int depth = -1;
122
16
    int next_idx = -1;
123
39
    for (i = 0; i < node->outgoings->rnum; 
i++23
)
124
23
    {
125
23
      const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
126
23
      if (outgoing >= exec_rnum)
127
0
        continue;
128
23
      if (chain_ids[outgoing] < 0 && 
reversed_depth[outgoing] > depth14
)
129
12
        depth = reversed_depth[outgoing], next_idx = outgoing;
130
23
    }
131
16
    if (next_idx >= 0)
132
12
    {
133
12
      chain_ids[next_idx] = chain_id;
134
12
      assert(reversed_depth[idx] - depth >= 1);
135
12
      chain_pos[next_idx] = chain_pos[idx] + (reversed_depth[idx] - depth);
136
12
    }
137
16
  } ccv_nnc_graph_visit_endfor
138
2
  if (exec_rnum < chain_count * 2) // Be more conservative on RAM usage.
139
0
    buf = ccrealloc(buf, sizeof(int) * chain_count * 2);
140
2
  ccv_sparse_matrix_t* deps = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, chain_count, CCV_32S | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
141
  // It logs which pos on that chain we depend on. We can simply compare that with the chain_pos for a node to know if they are ancestors.
142
2
#define for_block(x, val) \
143
13
  do { \
144
13
    if (((int32_t*)val)[0] > 0) \
145
13
    { \
146
13
      buf[buf_size * 2] = x; \
147
13
      buf[buf_size * 2 + 1] = ((int32_t*)val)[0]; \
148
13
      ++buf_size; \
149
13
    } \
150
13
  } while (0)
151
2
  int buf_size;
152
16
  ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) {
153
16
    if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f))))
154
0
      continue;
155
16
    buf_size = 0; /* save all its parent deps to this buffer */
156
16
    ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx);
157
16
    if (vector)
158
13
      
CCV_SPARSE_VECTOR_FOREACH10
(deps, vector, for_block);
159
16
    if (!node->outgoings)
160
0
      continue;
161
16
    const int chain_id = chain_ids[idx];
162
16
    const int pos = chain_pos[idx];
163
39
    for (i = 0; i < node->outgoings->rnum; 
i++23
)
164
23
    {
165
23
      const int outgoing = *(int*)ccv_array_get(node->outgoings, i);
166
23
      if (outgoing >= exec_rnum)
167
0
        continue;
168
23
      const int outgoing_chain_id = chain_ids[outgoing];
169
23
      if (outgoing_chain_id != chain_id)
170
10
      {
171
10
        ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(deps, outgoing, chain_id);
172
        /* If not found, set it. If the current node is the destination node, there is no need
173
         * to set itself as a parent of subsequent nodes, because of its terminal nature. */
174
10
        if (!cell.i32 || 
cell.i32[0] == 01
||
cell.i32[0] < pos1
)
175
10
          ccv_set_sparse_matrix_cell(deps, outgoing, chain_id, &pos);
176
10
      }
177
23
      if (buf_size > 0)
178
13
      {
179
13
        ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, outgoing);
180
30
        for (j = 0; j < buf_size; 
j++17
) /* set with all idx's dependencies as well */
181
17
        {
182
17
          if (outgoing_chain_id == buf[j * 2]) // We don't need to add as dependency for the same chain.
183
1
            continue;
184
16
          if (!vector)
185
8
          {
186
8
            ccv_set_sparse_matrix_cell(deps, outgoing, buf[j * 2], &buf[j * 2 + 1]);
187
8
            vector = ccv_get_sparse_matrix_vector(deps, outgoing);
188
8
            continue;
189
8
          }
190
8
          ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2]);
191
          /* If not found, set. Otherwise, set to the latest one only if it is later. */
192
8
          if (!cell.i32 || 
cell.i32[0] == 00
||
cell.i32[0] <= buf[j * 2 + 1]0
)
193
8
            ccv_set_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2], &buf[j * 2 + 1]);
194
8
        }
195
13
      }
196
23
    }
197
16
  } ccv_nnc_graph_visit_endfor
198
2
#undef for_block
199
2
  ccfree(buf);
200
2
  ccv_nnc_exec_dep_t exec_dep = {
201
2
    .chain_ids = chain_ids,
202
2
    .chain_pos = chain_pos,
203
2
    .deps = deps
204
2
  };
205
2
  return exec_dep;
206
2
}
207
208
void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
209
2.24k
{
210
2.24k
  ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
211
2.24k
  if (!gradient_checkpoints || 
gradient_checkpoints->rnum == 02
) // No saved gradient checkpoints, this is an easy way out.
212
2.23k
    return;
213
  // Otherwise, for each gradient checkpoint, there are several steps:
214
  // 1. Find which execs currently exist from inputs to outputs.
215
  // 2. Find execs that generates the outputs, and their corresponding backward execs.
216
  // 3. Find all backward execs flow from outputs back to inputs.
217
  // 4. Generate new ops by calling build again with old inputs, record all new tensors / execs.
218
  // 5. Replace inputs in backward execs with the new tensors.
219
  // 6. Hook the execs takes inputs with edge from parents of backward execs in step 2.
220
  // 7. Delete newly generated execs that has no use (i.e. its outputs are not used by backward pass).
221
  // 8. Mark all new execs with DISABLE_OPT to avoid common sub-expression elimination pass.
222
2
  int i, j, k, l;
223
2
  ccv_array_t* input_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
224
2
  ccv_array_t* output_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
225
2
  ccv_array_t* input_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
226
2
  ccv_array_t* output_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
227
2
  ccv_array_t* visited_backward_execs = ccv_array_new(sizeof(int), 0, 0);
228
2
  ccv_array_t* replaced_backward_execs = ccv_array_new(sizeof(int), 0, 0);
229
2
  const int exec_rnum = graph->exec_symbol_info->rnum;
230
2
  ccv_nnc_graph_exec_symbol_reverse_t* const reversed_nodes = cccalloc(exec_rnum, sizeof(ccv_nnc_graph_exec_symbol_reverse_t));
231
48
  for (i = 0; i < exec_rnum; 
i++46
)
232
46
  {
233
46
    const int* tos = 0;
234
46
    int to_size = 0;
235
46
    ccv_nnc_graph_exec_symbol_to(graph, (ccv_nnc_graph_exec_symbol_t){
236
46
      .graph = graph,
237
46
      .d = i
238
46
    }, &tos, &to_size);
239
46
    if (tos)
240
99
      
for (j = 0; 38
j < to_size;
j++61
)
241
61
      {
242
61
        if (!reversed_nodes[tos[j]].outgoings)
243
40
          reversed_nodes[tos[j]].outgoings = ccv_array_new(sizeof(int), 1, 0);
244
61
        ccv_array_add_unique_int(reversed_nodes[tos[j]].outgoings, i);
245
61
      }
246
46
  }
247
2
  uint32_t* const maskbit = cccalloc((exec_rnum + 31) >> 5, sizeof(uint32_t));
248
  // Temporary for build_data.
249
2
  ccv_array_t* const parameters = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
250
2
  ccv_array_t* const parameter_ids = ccv_array_new(sizeof(char*), 0, 0);
251
2
  ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0);
252
2
  ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
253
2
  ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0);
254
2
  int max_output_size = 0;
255
4
  for (i = 0; i < gradient_checkpoints->rnum; 
i++2
)
256
2
  {
257
2
    ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
258
2
    max_output_size = ccv_max(checkpoint->output_size, max_output_size);
259
2
  }
260
2
  ccv_nnc_tensor_symbol_t* max_outputs = ccmalloc(sizeof(ccv_nnc_tensor_symbol_t) * max_output_size);
261
2
  ccv_array_t* newly_used_outputs = ccv_array_new(sizeof(int), 0, 0);
262
2
  khash_t(ccv_cnnp_tensor_symbol_set)* const parameters_or_internals = kh_init(ccv_cnnp_tensor_symbol_set);
263
14
  for (i = 0; i < compiled_data->parameters->rnum; 
i++12
)
264
12
  {
265
12
    const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->parameters, i);
266
12
    int ret;
267
12
    kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret);
268
12
  }
269
2
  for (i = 0; i < compiled_data->internals->rnum; 
i++0
)
270
0
  {
271
0
    const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->parameters, i);
272
0
    int ret;
273
0
    kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret);
274
0
  }
275
2
  khash_t(ccv_cnnp_tensor_symbol_set)* const newly_created_tensor_symbols = kh_init(ccv_cnnp_tensor_symbol_set);
276
2
  khash_t(ccv_cnnp_tensor_symbol_map)* symbol_map = kh_init(ccv_cnnp_tensor_symbol_map);
277
4
  for (i = 0; i < gradient_checkpoints->rnum; 
i++2
)
278
2
  {
279
2
    ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
280
2
    kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
281
32
    for (j = 0; j < checkpoint->tensor_symbols->rnum; 
j++30
)
282
30
    {
283
30
      const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j))->d;
284
30
      if (idx < 0)
285
0
        continue;
286
      // Skip parameters or internals.
287
30
      if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals))
288
12
        continue;
289
18
      int ret;
290
18
      kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret);
291
18
    }
292
2
    ccv_array_clear(input_execs);
293
2
    ccv_array_clear(output_execs);
294
2
    ccv_nnc_graph_exec_symbol_info_t* exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
295
48
    for (j = 0; j < exec_rnum; 
j++46
)
296
46
    {
297
46
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[j].flags))
298
0
        continue;
299
46
      const int* inputs = exec_info[j].inputs;
300
46
      int input_size = exec_info[j].input_size;
301
46
      const int* outputs = exec_info[j].outputs;
302
46
      int output_size = exec_info[j].output_size;
303
46
      if (input_size == 0 && 
output_size == 00
)
304
0
        continue;
305
      // Only go through forward pass.
306
46
      if (ccv_nnc_cmd_is_backward(exec_info[j].cmd))
307
17
        continue;
308
29
      const ccv_nnc_graph_exec_symbol_t symbol = {
309
29
        .graph = graph,
310
29
        .d = j
311
29
      };
312
29
      int flag = 0;
313
88
      for (k = 0; inputs && k < input_size && 
!flag65
;
k++59
)
314
59
        if (inputs[k] >= 0)
315
118
          
for (l = 0; 59
l < checkpoint->input_size &&
!flag59
;
l++59
)
316
59
            if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
317
6
              flag = 1;
318
      // Input logic is different from output logic. We need to filter out those execs that contain inputs from within the graph.
319
41
      for (k = 0; inputs && k < input_size && 
flag35
;
k++12
)
320
12
        if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
321
0
          flag = 0;
322
29
      if (flag)
323
6
        ccv_array_push(input_execs, &symbol);
324
29
      flag = 0;
325
66
      for (k = 0; outputs && k < output_size && 
!flag37
;
k++37
)
326
37
        if (outputs[k] >= 0)
327
74
          
for (l = 0; 37
l < checkpoint->output_size &&
!flag37
;
l++37
)
328
37
            if (checkpoint->outputs[l].d >= 0 && outputs[k] == checkpoint->outputs[l].d)
329
2
              flag = 1;
330
29
      if (flag)
331
2
        ccv_array_push(output_execs, &symbol);
332
29
    }
333
2
    if (input_execs->rnum <= 0 || output_execs->rnum <= 0)
334
0
      continue;
335
    // Fill in the blanks (i.e. the backward ops that are not shown above, but should be included to avoid excluding necessary ones). This is done by flowing gradients from outputs back all the way to inputs.
336
2
    ccv_array_clear(input_gradient_execs);
337
2
    ccv_array_clear(output_gradient_execs);
338
8
    for (j = 0; j < input_execs->rnum; 
j++6
)
339
6
    {
340
6
      const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_execs, j))->d;
341
18
      for (k = 0; k < exec_info[d].input_size; 
k++12
)
342
12
        if (exec_info[d].inputs[k] >= 0)
343
12
        {
344
12
          const ccv_nnc_tensor_symbol_t gradient_symbol = ccv_nnc_tensor_symbol_for_backward(graph, (ccv_nnc_tensor_symbol_t){
345
12
            .graph = graph,
346
12
            .d = exec_info[d].inputs[k]
347
12
          });
348
12
          if (gradient_symbol.d < 0)
349
9
            continue;
350
3
          const ccv_nnc_graph_exec_symbol_t backward = ccv_nnc_graph_exec_symbol_for_backward(graph, gradient_symbol);
351
3
          if (backward.d < 0)
352
0
            continue;
353
3
          if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[backward.d].flags))
354
0
            continue;
355
3
          int flag = 0;
356
4
          for (l = 0; !flag && l < output_gradient_execs->rnum; 
l++1
)
357
1
            if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == backward.d)
358
0
              flag = 1;
359
3
          if (!flag)
360
3
            ccv_array_push(output_gradient_execs, &backward);
361
3
        }
362
6
      if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0)
363
15
        
for (k = 0; 6
k < exec_info[d].outgoings->rnum;
k++9
)
364
9
        {
365
9
          const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k);
366
9
          if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd))
367
6
            continue;
368
3
          int flag = 0;
369
7
          for (l = 0; !flag && 
l < output_gradient_execs->rnum4
;
l++4
)
370
4
            if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == to_d)
371
3
              flag = 1;
372
3
          if (!flag)
373
0
          {
374
0
            const ccv_nnc_graph_exec_symbol_t backward = {
375
0
              .graph = graph,
376
0
              .d = to_d
377
0
            };
378
0
            ccv_array_push(output_gradient_execs, &backward);
379
0
          }
380
3
        }
381
6
    }
382
    // For output_gradient_execs, we can be opportunistic and use the wrt symbols (if they exist) to find relevant bits.
383
    // For input_gradient_execs, there is no other way but to loop over all outgoings, find the ones are direct link as backward execs.
384
4
    for (j = 0; j < output_execs->rnum; 
j++2
)
385
2
    {
386
2
      const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_execs, j))->d;
387
2
      if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0)
388
6
        
for (k = 0; 2
k < exec_info[d].outgoings->rnum;
k++4
)
389
4
        {
390
4
          const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k);
391
4
          if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd))
392
2
            continue;
393
2
          int flag = 0;
394
2
          for (l = 0; !flag && l < input_gradient_execs->rnum; 
l++0
)
395
0
            if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, l))->d == to_d)
396
0
              flag = 1;
397
2
          if (!flag)
398
2
          {
399
2
            const ccv_nnc_graph_exec_symbol_t backward = {
400
2
              .graph = graph,
401
2
              .d = to_d
402
2
            };
403
2
            ccv_array_push(input_gradient_execs, &backward);
404
2
          }
405
2
        }
406
2
    }
407
    // Note that we have to use the up-to-date ones because the exec_info might have outgoings that are up-to-date.
408
4
    ccv_nnc_graph_visit_t* const visit = 
ccv_nnc_graph_visit_new2
(graph, exec_info, graph->exec_symbol_info->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, 1);
409
16
    ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
410
16
      if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags))
411
16
        maskbit[idx >> 5] |= (1u << (idx & 0x1f));
412
16
    } ccv_nnc_graph_visit_endfor
413
4
    ccv_array_clear(visited_backward_execs);
414
    // Add more backward passes to the list. Note that we don't add everything; in particular, new nodes created through gradient checkpointing are ignored.
415
4
#define visitor(node, idx, _) \
416
16
    if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f))) \
417
16
      ccv_array_add_unique_int(visited_backward_execs, idx);
418
16
    
CCV_NNC_GRAPH_VISIT2
(graph, reversed_nodes, exec_rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, 0, visitor);
419
4
    for (j = 0; j < input_gradient_execs->rnum; 
j++2
)
420
2
      ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d);
421
2
#undef visitor
422
2
    ccv_cnnp_gradient_checkpoint_build_t build = {
423
2
      .tensor_context = {
424
2
        .record = 1,
425
2
        .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
426
2
      },
427
2
      .graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0),
428
2
      .all_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
429
2
    };
430
2
    build.tensor_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.tensor_context.old_tensor_symbol_new_hook);
431
2
    build.tensor_context.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.tensor_context.old_tensor_symbol_alias_new_hook);
432
2
    build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook);
433
2
    ccv_array_clear(parameters);
434
2
    ccv_array_clear(parameter_ids);
435
2
    ccv_array_clear(parameter_trainables);
436
2
    ccv_array_clear(internals);
437
2
    ccv_array_clear(internal_ids);
438
2
    ccv_cnnp_model_sequence_t model_sequence = {
439
2
      .bank = kh_init(ccv_cnnp_model_name_bank)
440
2
    };
441
2
    ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = {
442
2
      .add_parameter_indices = 0,
443
2
      .prefix = 't',
444
2
      .sequence = &model_sequence,
445
2
      .symbols = parameters,
446
2
      .ids = parameter_ids,
447
2
      .trainables = parameter_trainables,
448
2
    };
449
2
    ccv_cnnp_model_add_to_array_context_t add_to_output_context = {
450
2
      .add_parameter_indices = 0,
451
2
      .prefix = 'r',
452
2
      .sequence = &model_sequence,
453
2
      .symbols = internals,
454
2
      .ids = internal_ids,
455
2
      .trainables = 0,
456
2
    };
457
2
    ccv_cnnp_model_build_data_t build_data = {
458
2
      .is_trainable = checkpoint->is_trainable,
459
2
      .model_sequence = &model_sequence,
460
2
      .add_to_array = ccv_cnnp_model_add_to_array,
461
2
      .parameters = parameters,
462
2
      .context = {
463
2
        .add_to_parameter = &add_to_parameter_context,
464
2
        .add_to_output = &add_to_output_context,
465
2
      },
466
2
      .is_gradient_checkpointing = 1, // Mark this as true so we don't allocate gradient_checkpoints array or override the hooks.
467
2
      .gradient_checkpoints = 0,
468
2
    };
469
2
    checkpoint->model->data = &build_data;
470
2
    checkpoint->build(checkpoint->model, graph, checkpoint->inputs, checkpoint->input_size, max_outputs, checkpoint->output_size);
471
2
    checkpoint->model->data = 0;
472
2
    kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank);
473
2
    if (model_sequence.sequences)
474
2
      ccv_array_free(model_sequence.sequences);
475
2
    ccv_nnc_tensor_symbol_new_hook(graph, build.tensor_context.old_tensor_symbol_new_hook, build.tensor_context.old_tensor_symbol_new_hook_context, 0);
476
2
    ccv_nnc_tensor_symbol_alias_new_hook(graph, build.tensor_context.old_tensor_symbol_alias_new_hook, build.tensor_context.old_tensor_symbol_alias_new_hook_context, 0);
477
2
    ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
478
14
    for (j = 0; j < parameter_ids->rnum; 
j++12
)
479
12
      ccfree(*(char**)ccv_array_get(parameter_ids, j));
480
2
    for (j = 0; j < internal_ids->rnum; 
j++0
)
481
0
      ccfree(*(char**)ccv_array_get(internal_ids, j));
482
    // Note that there is no graph optimization applied here.
483
2
    exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
484
    // Reuse existing one.
485
2
    kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
486
20
    for (j = 0; j < build.tensor_context.tensor_symbols->rnum; 
j++18
)
487
18
    {
488
18
      const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d;
489
18
      if (idx < 0)
490
0
        continue;
491
18
      if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals))
492
0
        continue;
493
18
      int ret;
494
18
      kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret);
495
18
    }
496
2
    ccv_array_t* const newly_input_execs = input_execs;
497
2
    ccv_array_t* const newly_output_execs = output_execs;
498
2
    ccv_array_clear(newly_input_execs);
499
2
    ccv_array_clear(newly_output_execs);
500
20
    for (j = 0; j < build.graph_exec_symbols->rnum; 
j++18
)
501
18
    {
502
18
      const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d;
503
18
      if (idx < 0)
504
0
        continue;
505
18
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
506
0
        continue;
507
18
      const ccv_nnc_graph_exec_symbol_t symbol = {
508
18
        .graph = graph,
509
18
        .d = idx
510
18
      };
511
18
      const int* inputs = exec_info[idx].inputs;
512
18
      int input_size = exec_info[idx].input_size;
513
      // Only go through forward pass.
514
18
      assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
515
18
      int flag = 0;
516
47
      for (k = 0; inputs && k < input_size && 
!flag35
;
k++29
)
517
29
        if (inputs[k] >= 0)
518
58
          
for (l = 0; 29
l < checkpoint->input_size &&
!flag29
;
l++29
)
519
29
            if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
520
6
              flag = 1;
521
      // Input logic is different from output logic. We need to filter out those execs that contain inputs from within the graph.
522
30
      for (k = 0; inputs && k < input_size && 
flag24
;
k++12
)
523
12
        if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
524
0
          flag = 0;
525
18
      if (flag)
526
6
        ccv_array_push(newly_input_execs, &symbol);
527
18
      flag = 0;
528
18
      const int* outputs = exec_info[idx].outputs;
529
18
      int output_size = exec_info[idx].output_size;
530
36
      for (k = 0; inputs && k < output_size && 
!flag18
;
k++18
)
531
18
        if (outputs[k] >= 0)
532
36
          
for (l = 0; 18
l < checkpoint->output_size &&
!flag18
;
l++18
)
533
18
            if (max_outputs[l].d >= 0 && outputs[k] == max_outputs[l].d)
534
2
              flag = 1;
535
18
      if (flag)
536
2
        ccv_array_push(newly_output_execs, &symbol);
537
18
    }
538
4
    
for (j = 0; 2
j < checkpoint->input_size;
j++2
)
539
2
      if (checkpoint->inputs[j].d >= 0)
540
2
        ccv_array_push(parameters, checkpoint->inputs + j);
541
2
    ccv_nnc_symbolic_graph_simplify(graph,
542
2
      SYMBOLIC_GRAPH_PASSES(CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION,
543
2
        CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT,
544
2
        CCV_NNC_SIMPLIFY_OPS_FUSION),
545
2
      ccv_array_get(parameters, 0), parameters->rnum,
546
2
      max_outputs, checkpoint->output_size,
547
2
      ccv_array_get(newly_input_execs, 0), newly_input_execs->rnum, ccv_array_get(newly_output_execs, 0), newly_output_execs->rnum);
548
2
    ccv_nnc_graph_exec_symbol_new_hook(graph, build.old_graph_exec_symbol_new_hook, build.old_graph_exec_symbol_new_hook_context, 0);
549
    // Need to autogen and redo source / destination.
550
2
    ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
551
2
    ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0);
552
2
    exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
553
2
    ccv_array_clear(newly_input_execs);
554
20
    for (j = 0; j < build.graph_exec_symbols->rnum; 
j++18
)
555
18
    {
556
18
      const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d;
557
18
      if (idx < 0)
558
0
        continue;
559
18
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
560
0
        continue;
561
18
      const ccv_nnc_graph_exec_symbol_t symbol = {
562
18
        .graph = graph,
563
18
        .d = idx
564
18
      };
565
18
      const int* inputs = exec_info[idx].inputs;
566
18
      int input_size = exec_info[idx].input_size;
567
      // Only go through forward pass.
568
18
      assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
569
18
      int flag = 0;
570
47
      for (k = 0; inputs && k < input_size && 
!flag35
;
k++29
)
571
29
        if (inputs[k] >= 0)
572
58
          
for (l = 0; 29
l < checkpoint->input_size &&
!flag29
;
l++29
)
573
29
            if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
574
6
              flag = 1;
575
30
      for (k = 0; inputs && k < input_size && 
flag24
;
k++12
)
576
12
        if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
577
0
          flag = 0;
578
18
      if (flag)
579
6
        ccv_array_push(newly_input_execs, &symbol);
580
18
    }
581
    // Build a map between old tensor symbols and new tensor symbols.
582
2
    assert(build.tensor_context.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum);
583
    // Build a map to potentially map from old input to new input. 
584
2
    kh_clear(ccv_cnnp_tensor_symbol_map, symbol_map);
585
32
    for (j = 0, k = 0; j < build.tensor_context.tensor_symbols->rnum && 
k < checkpoint->tensor_symbols->rnum30
;)
586
30
    {
587
30
      const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d;
588
30
      if (from_d < 0) // This is removed, move to the next one.
589
0
      {
590
0
        ++j;
591
0
        ++k;
592
0
        continue;
593
0
      }
594
30
      const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d;
595
30
      assert(to_d >= 0);
596
30
      int from_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, from_d) != kh_end(parameters_or_internals);
597
30
      int to_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, to_d) != kh_end(parameters_or_internals);
598
30
      if (from_flag)
599
12
        ++k;
600
30
      if (to_flag)
601
0
        ++j;
602
30
      if (from_flag || 
to_flag18
)
603
12
        continue;
604
18
      ++k;
605
18
      ++j;
606
      // Skip if from_d is outputs.
607
36
      for (l = 0; l < !from_flag && 
checkpoint->output_size18
;
l++18
)
608
18
        if (checkpoint->outputs[l].d == from_d)
609
2
          from_flag = 1;
610
18
      if (from_flag)
611
2
        continue;
612
      // Skip if to_d is outputs.
613
32
      
for (l = 0; 16
l < !to_flag &&
checkpoint->output_size16
;
l++16
)
614
16
        if (checkpoint->outputs[l].d == to_d)
615
0
          to_flag = 1;
616
16
      if (to_flag)
617
0
        continue;
618
16
      int ret = 0;
619
16
      khiter_t h = kh_put(ccv_cnnp_tensor_symbol_map, symbol_map, from_d, &ret);
620
16
      kh_val(symbol_map, h) = to_d;
621
16
    }
622
    // Now go over all backward passes to replace inputs with the ones from symbol map. Record these that are used.
623
2
    ccv_array_clear(newly_used_outputs);
624
2
    ccv_array_clear(replaced_backward_execs);
625
18
    for (j = 0; j < visited_backward_execs->rnum; 
j++16
)
626
16
    {
627
16
      const int idx = *(int*)ccv_array_get(visited_backward_execs, j);
628
16
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
629
0
        continue;
630
16
      assert(idx >= 0);
631
16
      assert(idx < exec_rnum);
632
16
      if (!ccv_nnc_cmd_is_backward(exec_info[idx].cmd))
633
1
        continue;
634
74
      
for (k = 0; 15
k < exec_info[idx].input_size;
k++59
)
635
59
        if (exec_info[idx].inputs[k] >= 0)
636
32
        {
637
32
          const khiter_t h = kh_get(ccv_cnnp_tensor_symbol_map, symbol_map, exec_info[idx].inputs[k]);
638
32
          if (h != kh_end(symbol_map)) // Replacing it.
639
8
          {
640
8
            int newly_created_output = kh_val(symbol_map, h);
641
8
            exec_info[idx].inputs[k] = newly_created_output;
642
8
            ccv_array_add_unique_int(newly_used_outputs, newly_created_output);
643
8
            if (tensor_symbol_info[newly_created_output].alias_ref > 0)
644
0
            {
645
0
              newly_created_output = tensor_symbol_info[newly_created_output].alias_ref - 1;
646
0
              ccv_array_add_unique_int(newly_used_outputs, newly_created_output);
647
0
            }
648
8
            ccv_array_add_unique_int(replaced_backward_execs, idx);
649
8
          }
650
32
        }
651
15
    }
652
20
    
for (j = 0; 2
j < build.graph_exec_symbols->rnum;
j++18
)
653
18
    {
654
18
      ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
655
18
      if (symbol->d < 0)
656
0
        continue;
657
18
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
658
0
        continue;
659
18
      int x, y;
660
106
      for (k = 0; k < replaced_backward_execs->rnum; 
k++88
)
661
88
      {
662
88
        const int idx = *(int*)ccv_array_get(replaced_backward_execs, k);
663
88
        assert(idx >= 0);
664
88
        assert(idx < exec_rnum);
665
88
        assert(ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
666
88
        int flag = 0;
667
412
        for (x = 0; !flag && 
x < exec_info[idx].input_size404
;
x++324
)
668
324
        {
669
324
          int x_d = exec_info[idx].inputs[x];
670
324
          if (x_d < 0)
671
80
            continue;
672
244
          if (tensor_symbol_info[x_d].alias_ref > 0)
673
0
            x_d = tensor_symbol_info[x_d].alias_ref - 1;
674
488
          for (y = 0; !flag && 
y < exec_info[symbol->d].output_size480
;
y++244
)
675
244
          {
676
244
            int y_d = exec_info[symbol->d].outputs[y];
677
244
            if (y_d < 0)
678
0
              continue;
679
244
            if (tensor_symbol_info[y_d].alias_ref > 0)
680
0
              y_d = tensor_symbol_info[y_d].alias_ref - 1;
681
244
            if (x_d == y_d)
682
8
              flag = 1;
683
244
          }
684
244
        }
685
88
        if (flag)
686
8
          ccv_nnc_graph_exec_symbol_concat(graph, *symbol, (ccv_nnc_graph_exec_symbol_t){
687
8
            .graph = graph,
688
8
            .d = idx
689
8
          });
690
88
      }
691
18
    }
692
    // Find parents to visited_backward_execs, and use that as the starting point of all newly added graph_exec_symbols. Use the visited backward execs as the source, use all its parents as destination, go through with graph visit.
693
2
    ccv_nnc_exec_dep_t exec_dep = _ccv_nnc_exec_dep_new(graph, visit, exec_rnum, maskbit);
694
    // Now go from outputs to inputs, unmark visited ones.
695
16
    ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
696
16
      if (idx < exec_rnum)
697
16
        maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));
698
16
    } ccv_nnc_graph_visit_endfor
699
2
    ccv_nnc_graph_visit_free(visit);
700
    // Go through visited backward execs, remove the ones that has no dependency on any replaced backward execs.
701
18
    for (j = 0; j < visited_backward_execs->rnum;)
702
16
    {
703
16
      const int idx = *(int*)ccv_array_get(visited_backward_execs, j);
704
16
      if (ccv_array_contain_int(replaced_backward_execs, idx))
705
7
      {
706
7
        ++j;
707
7
        continue;
708
7
      }
709
9
      ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
710
9
      if (!vector)
711
3
        vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find.
712
9
      int flag = 0;
713
38
      for (k = 0; !flag && 
k < replaced_backward_execs->rnum32
;
k++29
)
714
29
      {
715
29
        const int d = *(int*)ccv_array_get(replaced_backward_execs, k);
716
29
        flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
717
29
      }
718
9
      if (!flag)
719
3
      {
720
3
        if (j < visited_backward_execs->rnum - 1)
721
2
          *(int*)ccv_array_get(visited_backward_execs, j) = *(int*)ccv_array_get(visited_backward_execs, visited_backward_execs->rnum - 1);
722
3
        --visited_backward_execs->rnum;
723
3
        continue;
724
3
      }
725
6
      ++j;
726
6
    }
727
    // Now go through all replaced_backward_execs to find the ones has no dependencies in visited_backward_execs.
728
9
    for (j = 0; j < replaced_backward_execs->rnum; 
j++7
)
729
7
    {
730
7
      const int idx = *(int*)ccv_array_get(replaced_backward_execs, j);
731
7
      ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx);
732
7
      if (!vector)
733
3
        vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find.
734
7
      int flag = 0;
735
59
      for (k = 0; !flag && 
k < visited_backward_execs->rnum54
;
k++52
)
736
52
      {
737
52
        const int d = *(int*)ccv_array_get(visited_backward_execs, k);
738
52
        flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d);
739
52
      }
740
      // If this one has no parents that is within the visited_backward_execs, it is a good place for us to add all its parents as dependency for input_execs.
741
7
      if (!flag)
742
2
      {
743
2
        assert(idx < exec_rnum);
744
2
        ccv_array_t* const outgoings = reversed_nodes[idx].outgoings;
745
2
        assert(outgoings);
746
6
        
for (k = 0; 2
k < outgoings->rnum;
k++4
)
747
4
        {
748
4
          const int d = *(int*)ccv_array_get(outgoings, k);
749
16
          for (l = 0; l < newly_input_execs->rnum; 
l++12
)
750
12
          {
751
12
            ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){
752
12
              .graph = graph,
753
12
              .d = d
754
12
            }, *(ccv_nnc_graph_exec_symbol_t*)ccv_array_get(newly_input_execs, l));
755
12
          }
756
4
        }
757
2
      }
758
7
    }
759
2
    ccv_nnc_exec_dep_free(exec_dep);
760
    // Go through all exec, free ones that doesn't have output used.
761
    // Reuse this array because it is not useful any more.
762
2
    ccv_array_t* forward_pass_inputs = visited_backward_execs;
763
2
    int any_deleted;
764
6
    do {
765
      // Build a map of still active inputs.
766
6
      ccv_array_clear(forward_pass_inputs);
767
60
      for (j = 0; j < build.graph_exec_symbols->rnum; 
j++54
)
768
54
      {
769
54
        ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
770
54
        if (symbol->d < 0)
771
8
          continue;
772
46
        if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
773
0
          continue;
774
46
        int* const inputs = exec_info[symbol->d].inputs;
775
46
        const int input_size = exec_info[symbol->d].input_size;
776
135
        for (k = 0; k < input_size; 
k++89
)
777
89
        {
778
89
          int d = inputs[k];
779
89
          if (d < 0)
780
0
            continue;
781
89
          ccv_array_add_unique_int(forward_pass_inputs, d);
782
89
          if (tensor_symbol_info[d].alias_ref > 0)
783
0
          {
784
0
            d = tensor_symbol_info[d].alias_ref - 1;
785
0
            ccv_array_add_unique_int(forward_pass_inputs, d);
786
0
          }
787
89
        }
788
46
      }
789
6
      any_deleted = 0;
790
60
      for (j = 0; j < build.graph_exec_symbols->rnum; 
j++54
)
791
54
      {
792
54
        ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
793
54
        if (symbol->d < 0)
794
8
          continue;
795
46
        if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
796
0
          continue;
797
46
        int* const outputs = exec_info[symbol->d].outputs;
798
46
        const int output_size = exec_info[symbol->d].output_size;
799
46
        int flag = 0;
800
92
        for (k = 0; !flag && 
k < output_size52
;
k++46
)
801
46
        {
802
46
          int d = outputs[k];
803
46
          if (d < 0)
804
0
            continue;
805
46
          flag = ccv_array_contain_int(newly_used_outputs, d) || 
ccv_array_contain_int(forward_pass_inputs, d)22
;
806
46
          if (!flag && 
tensor_symbol_info[d].alias_ref > 06
)
807
0
          {
808
0
            d = tensor_symbol_info[d].alias_ref - 1;
809
0
            flag = ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d);
810
0
          }
811
46
        }
812
46
        if (flag)
813
40
          continue;
814
6
        ccv_nnc_graph_exec_symbol_free(graph, *symbol);
815
6
        symbol->d = -1;
816
6
        symbol->graph = 0;
817
6
        any_deleted = 1;
818
6
      }
819
6
    } while (any_deleted);
820
2
    ccv_array_clear(forward_pass_inputs);
821
20
    for (j = 0; j < build.graph_exec_symbols->rnum; 
j++18
)
822
18
    {
823
18
      ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
824
18
      if (symbol->d < 0)
825
6
        continue;
826
12
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
827
0
        continue;
828
12
      int* const inputs = exec_info[symbol->d].inputs;
829
12
      const int input_size = exec_info[symbol->d].input_size;
830
35
      for (k = 0; k < input_size; 
k++23
)
831
23
      {
832
23
        if (inputs[k] < 0)
833
0
          continue;
834
23
        ccv_array_add_unique_int(forward_pass_inputs, inputs[k]);
835
23
        if (tensor_symbol_info[inputs[k]].alias_ref > 0)
836
0
          ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[inputs[k]].alias_ref - 1);
837
23
      }
838
12
      int* const outputs = exec_info[symbol->d].outputs;
839
12
      const int output_size = exec_info[symbol->d].output_size;
840
24
      for (k = 0; k < output_size; 
k++12
)
841
12
      {
842
12
        if (outputs[k] < 0)
843
0
          continue;
844
12
        ccv_array_add_unique_int(forward_pass_inputs, outputs[k]);
845
12
        if (tensor_symbol_info[outputs[k]].alias_ref > 0)
846
0
          ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[outputs[k]].alias_ref - 1);
847
12
      }
848
12
    }
849
    // Free unused tensor symbols.
850
20
    for (j = 0; j < build.all_tensor_symbols->rnum; 
j++18
)
851
18
    {
852
18
      const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.all_tensor_symbols, j));
853
18
      if (ccv_array_contain_int(newly_used_outputs, symbol->d) || 
ccv_array_contain_int(forward_pass_inputs, symbol->d)10
)
854
12
        continue;
855
6
      if (tensor_symbol_info[symbol->d].alias_ref > 0)
856
0
      {
857
0
        const int d = tensor_symbol_info[symbol->d].alias_ref - 1;
858
0
        if (ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d))
859
0
          continue;
860
0
      }
861
6
      ccv_nnc_tensor_symbol_free(graph, *symbol);
862
6
    }
863
20
    for (j = 0; j < build.graph_exec_symbols->rnum; 
j++18
)
864
18
    {
865
18
      ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
866
18
      if (symbol->d < 0)
867
6
        continue;
868
12
      if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
869
0
        continue;
870
12
      ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT);
871
12
    }
872
    // Free these newly created execs and tensor symbols.
873
2
    ccv_array_free(build.tensor_context.tensor_symbols);
874
2
    ccv_array_free(build.graph_exec_symbols);
875
2
    ccv_array_free(build.all_tensor_symbols);
876
2
  }
877
2
  kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map);
878
2
  kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
879
2
  kh_destroy(ccv_cnnp_tensor_symbol_set, parameters_or_internals);
880
2
  ccfree(max_outputs);
881
2
  ccv_array_free(newly_used_outputs);
882
2
  ccv_array_free(parameters);
883
2
  ccv_array_free(parameter_ids);
884
2
  ccv_array_free(parameter_trainables);
885
2
  ccv_array_free(internals);
886
2
  ccv_array_free(internal_ids);
887
2
  ccfree(maskbit);
888
2
  ccv_array_free(input_gradient_execs);
889
2
  ccv_array_free(output_gradient_execs);
890
2
  ccv_array_free(input_execs);
891
2
  ccv_array_free(output_execs);
892
2
  ccv_array_free(replaced_backward_execs);
893
2
  ccv_array_free(visited_backward_execs);
894
48
  for (i = 0; i < exec_rnum; 
i++46
)
895
46
    if (reversed_nodes[i].outgoings)
896
40
      ccv_array_free(reversed_nodes[i].outgoings);
897
2
  ccfree(reversed_nodes);
898
2
}