Coverage Report

Created: 2024-08-18 16:21

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c

Line   Count   Source
   1           #include "ccv_nnc.h"
   2           #include "ccv_nnc_easy.h"
   3           #include "ccv_nnc_internal.h"
   4           #include "ccv_internal.h"
   5           #include "_ccv_cnnp_model.h"
   6           // This can be removed once we organize ccv_cnnp_apply_gradient_checkpoints better.
   7           #include "_ccv_nnc_symbolic_graph.h"
   8
   9           typedef struct {
  10             ccv_array_t* outgoings;
  11           } ccv_nnc_graph_exec_symbol_reverse_t;
  12
  13           typedef struct {
  14             ccv_array_t* tensor_symbols;
  15             void* old_tensor_symbol_new_hook_context;
  16             ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook;
  17             void* old_tensor_symbol_alias_new_hook_context;
  18             ccv_nnc_tensor_symbol_alias_new_hook_f old_tensor_symbol_alias_new_hook;
  19             ccv_array_t* graph_exec_symbols;
  20             ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook;
  21             void* old_graph_exec_symbol_new_hook_context;
  22           } ccv_cnnp_gradient_checkpoint_build_t;
  23
  24           static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name)
  25      18   {
  26      18     ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
  27      18     ccv_array_push(build_context->tensor_symbols, &symbol);
  28      18     if (build_context->old_tensor_symbol_new_hook)
  29      18       build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name);
  30      18   }
  31
  32           static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name)
  33       0   {
  34       0     ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context;
  35       0     ccv_array_push(build_context->tensor_symbols, &symbol);
  36       0     if (build_context->old_tensor_symbol_alias_new_hook)
  37       0       build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name);
  38       0   }
  39
  40           static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name)
  41      18   {
  42      18     ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context;
  43      18     ccv_array_push(build->graph_exec_symbols, &symbol);
  44      18     if (build->old_graph_exec_symbol_new_hook)
  45      18       build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name);
  46      18   }
  47
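Note: the three hooks above share one pattern: record the newly created symbol, then chain to whatever hook was installed before, so the interception stays transparent to the previous listener. A minimal, self-contained sketch of that chaining idiom follows; all names in it are hypothetical, not ccv API:

    #include <stdio.h>

    typedef void (*on_new_f)(void* context, int symbol);

    static on_new_f current_hook;
    static void* current_hook_context;

    /* Install a new hook, returning the old one so the caller can chain to it. */
    static on_new_f install_hook(on_new_f hook, void* context, void** old_context)
    {
      on_new_f old = current_hook;
      *old_context = current_hook_context;
      current_hook = hook;
      current_hook_context = context;
      return old;
    }

    typedef struct {
      int recorded;
      on_new_f old_hook;
      void* old_hook_context;
    } recorder_t;

    static void recording_hook(void* context, int symbol)
    {
      recorder_t* const recorder = (recorder_t*)context;
      recorder->recorded++; /* record the new symbol */
      if (recorder->old_hook) /* chain to whatever was installed before */
        recorder->old_hook(recorder->old_hook_context, symbol);
    }

    int main(void)
    {
      recorder_t recorder = { 0 };
      recorder.old_hook = install_hook(recording_hook, &recorder, &recorder.old_hook_context);
      if (current_hook)
        current_hook(current_hook_context, 42); /* a symbol is "created" */
      printf("recorded %d symbol(s)\n", recorder.recorded);
      return 0;
    }
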
  48           KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)
  49
  50           void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
  51   2.24k   {
  52   2.24k     ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
  53   2.24k     if (!gradient_checkpoints || gradient_checkpoints->rnum == 0) // No saved gradient checkpoints, this is an easy way out.
  54   2.23k       return;
  55             // Otherwise, for each gradient checkpoint, there are 8 steps:
  56             // 1. Find which execs currently exist from inputs to outputs.
  57             // 2. Find execs that generate the outputs, and their corresponding backward execs.
  58             // 3. Find all backward execs that flow from outputs back to inputs.
  59             // 4. Generate new ops by calling build again with the old inputs, recording all new tensors / execs.
  60             // 5. Replace inputs in backward execs with the new tensors.
  61             // 6. Hook the execs that take inputs up with edges from the parents of the backward execs found in step 2.
  62             // 7. Delete newly generated execs that have no use (i.e. their outputs are not used by the backward pass).
  63             // 8. Mark all new execs with DISABLE_OPT to avoid the common sub-expression elimination pass.
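Note: steps 4 through 8 implement the usual gradient-checkpointing trade-off: rather than keeping every forward activation alive for the backward pass, the checkpointed sub-model is rebuilt (step 4) and the backward execs are rewired to the recomputed tensors (steps 5 and 6). A minimal numeric sketch of that trade-off on a scalar chain, in plain C with no ccv types; the chain of squarings is an assumed stand-in for the checkpointed layers:

    #include <stdio.h>

    #define N 4

    static double f(double x) { return x * x; }    /* one "layer" */
    static double df(double x) { return 2.0 * x; } /* its local derivative */

    int main(void)
    {
      const double x0 = 1.5; /* the checkpointed input */
      double y = x0;
      int i;
      /* Forward pass: only the checkpoint input and the final output survive;
         the intermediate activations are dropped to save memory. */
      for (i = 0; i < N; i++)
        y = f(y);
      /* Backward pass: recompute the activations from the checkpoint first
         (this is the extra compute that buys the memory back)... */
      double act[N];
      act[0] = x0;
      for (i = 1; i < N; i++)
        act[i] = f(act[i - 1]);
      /* ...then accumulate the chain rule in reverse, as usual. */
      double grad = 1.0;
      for (i = N - 1; i >= 0; i--)
        grad *= df(act[i]);
      printf("y = %g, dy/dx0 = %g\n", y, grad); /* dy/dx0 == 16 * x0^15 */
      return 0;
    }
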
  64       2     int i, j, k, l;
  65       2     ccv_array_t* input_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
  66       2     ccv_array_t* output_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
  67       2     ccv_array_t* input_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
  68       2     ccv_array_t* output_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0);
  69       2     ccv_array_t* visited_backward_execs = ccv_array_new(sizeof(int), 0, 0);
  70       2     ccv_array_t* replaced_backward_execs = ccv_array_new(sizeof(int), 0, 0);
  71       2     const int exec_rnum = graph->exec_symbol_info->rnum;
  72       2     ccv_nnc_graph_exec_symbol_reverse_t* const reversed_nodes = cccalloc(exec_rnum, sizeof(ccv_nnc_graph_exec_symbol_reverse_t));
  73      48     for (i = 0; i < exec_rnum; i++)
  74      46     {
  75      46       const int* tos = 0;
  76      46       int to_size = 0;
  77      46       ccv_nnc_graph_exec_symbol_to(graph, (ccv_nnc_graph_exec_symbol_t){
  78      46         .graph = graph,
  79      46         .d = i
  80      46       }, &tos, &to_size);
  81      46       if (tos)
  82      99         for (j = 0; j < to_size; j++)
  83      61         {
  84      61           if (!reversed_nodes[tos[j]].outgoings)
  85      40             reversed_nodes[tos[j]].outgoings = ccv_array_new(sizeof(int), 1, 0);
  86      61           ccv_array_add_unique_int(reversed_nodes[tos[j]].outgoings, i);
  87      61         }
  88      46     }
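Note: the loop over lines 73-88 materializes the reverse edges of the graph: ccv_nnc_graph_exec_symbol_to lists the successors of node i, and i is then registered as an incoming node of each successor (deduplicated via ccv_array_add_unique_int). A free-standing sketch of the same inversion, with plain arrays standing in for ccv_array_t and the deduplication elided:

    #include <stdio.h>

    #define NODES 4
    #define MAX_DEG 4

    /* Forward adjacency: outgoing[i] lists the successors of node i (-1 ends a row). */
    static const int outgoing[NODES][MAX_DEG] = {
      { 1, 2, -1 }, { 3, -1 }, { 3, -1 }, { -1 },
    };

    int main(void)
    {
      int incoming[NODES][MAX_DEG];
      int incoming_size[NODES] = { 0 };
      int i, j;
      for (i = 0; i < NODES; i++)
        for (j = 0; j < MAX_DEG && outgoing[i][j] >= 0; j++) {
          const int to = outgoing[i][j];
          /* Mirror the edge i -> to as a reverse edge to -> i. */
          incoming[to][incoming_size[to]++] = i;
        }
      for (i = 0; i < NODES; i++) {
        printf("node %d <-", i);
        for (j = 0; j < incoming_size[i]; j++)
          printf(" %d", incoming[i][j]);
        printf("\n");
      }
      return 0;
    }
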
  89       2     uint32_t* const maskbit = cccalloc((exec_rnum + 31) >> 5, sizeof(uint32_t));
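Note: maskbit is a packed bitset with one bit per exec: (exec_rnum + 31) >> 5 rounds up to whole 32-bit words, and the word/bit split idx >> 5 / idx & 0x1f recurs at lines 223, 228, 476 and 502. The same idiom in isolation (sizes and indices are assumed example values):

    #include <stdio.h>
    #include <stdint.h>
    #include <stdlib.h>

    int main(void)
    {
      const int exec_rnum = 70; /* assumed size */
      /* One bit per exec, rounded up to whole 32-bit words. */
      uint32_t* const maskbit = calloc((exec_rnum + 31) >> 5, sizeof(uint32_t));
      const int idx = 37;
      maskbit[idx >> 5] |= (1u << (idx & 0x1f));    /* set */
      if (maskbit[idx >> 5] & (1u << (idx & 0x1f))) /* test */
        printf("bit %d is set\n", idx);
      maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));   /* clear */
      free(maskbit);
      return 0;
    }
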
  90             // Temporary for build_data.
  91       2     ccv_array_t* const parameters = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
  92       2     ccv_array_t* const parameter_ids = ccv_array_new(sizeof(char*), 0, 0);
  93       2     ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0);
  94       2     ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
  95       2     ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0);
  96       2     ccv_array_t* const buf = ccv_array_new(sizeof(int), 0, 0);
  97       2     int max_output_size = 0;
  98       4     for (i = 0; i < gradient_checkpoints->rnum; i++)
  99       2     {
 100       2       ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
 101       2       max_output_size = ccv_max(checkpoint->output_size, max_output_size);
 102       2     }
 103       2     ccv_nnc_tensor_symbol_t* max_outputs = ccmalloc(sizeof(ccv_nnc_tensor_symbol_t) * max_output_size);
 104       2     ccv_array_t* newly_used_outputs = ccv_array_new(sizeof(int), 0, 0);
 105       4     for (i = 0; i < gradient_checkpoints->rnum; i++)
 106       2     {
 107       2       ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
 108       2       ccv_array_clear(input_execs);
 109       2       ccv_array_clear(output_execs);
 110       2       ccv_nnc_graph_exec_symbol_info_t* exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
 111      48       for (j = 0; j < exec_rnum; j++)
 112      46       {
 113      46         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[j].flags))
 114       0           continue;
 115      46         const int* inputs = exec_info[j].inputs;
 116      46         int input_size = exec_info[j].input_size;
 117      46         const int* outputs = exec_info[j].outputs;
 118      46         int output_size = exec_info[j].output_size;
 119      46         if (input_size == 0 && output_size == 0)
 120       0           continue;
 121                 // Only go through the forward pass.
 122      46         if (ccv_nnc_cmd_is_backward(exec_info[j].cmd))
 123      17           continue;
 124      29         const ccv_nnc_graph_exec_symbol_t symbol = {
 125      29           .graph = graph,
 126      29           .d = j
 127      29         };
 128      29         int flag = 0;
 129      88         for (k = 0; inputs && k < input_size && !flag; k++)
 130      59           if (inputs[k] >= 0)
 131     118             for (l = 0; l < checkpoint->input_size && !flag; l++)
 132      59               if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
 133       6                 flag = 1;
 134      29         if (flag)
 135       6           ccv_array_push(input_execs, &symbol);
 136      29         flag = 0;
 137      66         for (k = 0; outputs && k < output_size && !flag; k++)
 138      37           if (outputs[k] >= 0)
 139      74             for (l = 0; l < checkpoint->output_size && !flag; l++)
 140      37               if (checkpoint->outputs[l].d >= 0 && outputs[k] == checkpoint->outputs[l].d)
 141       2                 flag = 1;
 142      29         if (flag)
 143       2           ccv_array_push(output_execs, &symbol);
 144      29       }
 145       2       if (input_execs->rnum <= 0 || output_execs->rnum <= 0)
 146       0         continue;
 147               // Fill in the blanks (i.e. the backward ops that are not shown above, but should be included to avoid excluding necessary ones). This is done by flowing gradients from outputs back all the way to inputs.
 148       2       ccv_array_clear(input_gradient_execs);
 149       2       ccv_array_clear(output_gradient_execs);
 150       8       for (j = 0; j < input_execs->rnum; j++)
 151       6       {
 152       6         const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_execs, j))->d;
 153      18         for (k = 0; k < exec_info[d].input_size; k++)
 154      12           if (exec_info[d].inputs[k] >= 0)
 155      12           {
 156      12             const ccv_nnc_tensor_symbol_t gradient_symbol = ccv_nnc_tensor_symbol_for_backward(graph, (ccv_nnc_tensor_symbol_t){
 157      12               .graph = graph,
 158      12               .d = exec_info[d].inputs[k]
 159      12             });
 160      12             if (gradient_symbol.d < 0)
 161       9               continue;
 162       3             const ccv_nnc_graph_exec_symbol_t backward = ccv_nnc_graph_exec_symbol_for_backward(graph, gradient_symbol);
 163       3             if (backward.d < 0)
 164       0               continue;
 165       3             if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[backward.d].flags))
 166       0               continue;
 167       3             int flag = 0;
 168       4             for (l = 0; !flag && l < output_gradient_execs->rnum; l++)
 169       1               if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == backward.d)
 170       0                 flag = 1;
 171       3             if (!flag)
 172       3               ccv_array_push(output_gradient_execs, &backward);
 173       3           }
 174       6         if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0)
 175      15           for (k = 0; k < exec_info[d].outgoings->rnum; k++)
 176       9           {
 177       9             const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k);
 178       9             if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd))
 179       6               continue;
 180       3             int flag = 0;
 181       7             for (l = 0; !flag && l < output_gradient_execs->rnum; l++)
 182       4               if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == to_d)
 183       3                 flag = 1;
 184       3             if (!flag)
 185       0             {
 186       0               const ccv_nnc_graph_exec_symbol_t backward = {
 187       0                 .graph = graph,
 188       0                 .d = to_d
 189       0               };
 190       0               ccv_array_push(output_gradient_execs, &backward);
 191       0             }
 192       3           }
 193       6       }
 194               // For output_gradient_execs, we can be opportunistic and use the wrt symbols (if they exist) to find the relevant bits.
 195               // For input_gradient_execs, there is no other way but to loop over all outgoings and find the ones that directly link as backward execs.
 196       4       for (j = 0; j < output_execs->rnum; j++)
 197       2       {
 198       2         const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_execs, j))->d;
 199       2         if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0)
 200       6           for (k = 0; k < exec_info[d].outgoings->rnum; k++)
 201       4           {
 202       4             const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k);
 203       4             if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd))
 204       2               continue;
 205       2             int flag = 0;
 206       2             for (l = 0; !flag && l < input_gradient_execs->rnum; l++)
 207       0               if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, l))->d == to_d)
 208       0                 flag = 1;
 209       2             if (!flag)
 210       2             {
 211       2               const ccv_nnc_graph_exec_symbol_t backward = {
 212       2                 .graph = graph,
 213       2                 .d = to_d
 214       2               };
 215       2               ccv_array_push(input_gradient_execs, &backward);
 216       2             }
 217       2           }
 218       2       }
 219               // Note that we have to use the up-to-date ones because the exec_info might have outgoings that are up-to-date.
 220       4       ccv_nnc_graph_visit_t* const visit = ccv_nnc_graph_visit_new(graph, exec_info, graph->exec_symbol_info->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, 1);
 221      16       ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
 222      16         if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags))
 223      16           maskbit[idx >> 5] |= (1u << (idx & 0x1f));
 224      16       } ccv_nnc_graph_visit_endfor
 225       4       ccv_array_clear(visited_backward_execs);
 226               // Add more backward passes to the list. Note that we don't add everything; in particular, new nodes created through gradient checkpointing are ignored.
 227       4   #define visitor(node, idx, _) \
 228      16       if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f))) \
 229      16         ccv_array_add_unique_int(visited_backward_execs, idx);
 230      16       CCV_NNC_GRAPH_VISIT(graph, reversed_nodes, exec_rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, 0, visitor);
 231       4       for (j = 0; j < input_gradient_execs->rnum; j++)
 232       2         ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d);
 233       2   #undef visitor
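Note: the forward visit at lines 221-224 marks every exec lying between the gradient sources and destinations in maskbit; the reverse visit at line 230 then walks the mirrored edges from the output gradient execs and keeps only nodes whose mask bit is set, i.e. the intersection of the two reachable sets. A small stack-based sketch of that second sweep, with plain arrays standing in for ccv's visit machinery and an assumed five-node graph:

    #include <stdio.h>
    #include <stdint.h>

    #define NODES 5

    /* Reverse adjacency: incoming[i] lists predecessors of node i (-1 ends a row). */
    static const int incoming[NODES][NODES] = {
      { -1 }, { 0, -1 }, { 0, -1 }, { 1, 2, -1 }, { 3, -1 },
    };

    int main(void)
    {
      const uint32_t maskbit = 0x17; /* bits for nodes 0, 1, 2 and 4 set (assumed) */
      int stack[NODES * NODES], top = 0; /* oversized so repeated pushes fit */
      int visited[NODES] = { 0 }, collected[NODES], n = 0, i;
      stack[top++] = 4; /* seed: an output gradient exec */
      while (top > 0) {
        const int idx = stack[--top];
        if (visited[idx])
          continue;
        visited[idx] = 1;
        if (maskbit & (1u << idx)) /* keep only nodes the forward visit marked */
          collected[n++] = idx;
        for (i = 0; incoming[idx][i] >= 0; i++)
          stack[top++] = incoming[idx][i];
      }
      for (i = 0; i < n; i++)
        printf("collected %d\n", collected[i]);
      return 0;
    }
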
 234       2       ccv_cnnp_gradient_checkpoint_build_t build = {
 235       2         .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0),
 236       2         .graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0),
 237       2       };
 238       2       build.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.old_tensor_symbol_new_hook);
 239       2       build.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.old_tensor_symbol_alias_new_hook);
 240       2       build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook);
 241       2       ccv_array_clear(parameters);
 242       2       ccv_array_clear(parameter_ids);
 243       2       ccv_array_clear(parameter_trainables);
 244       2       ccv_array_clear(internals);
 245       2       ccv_array_clear(internal_ids);
 246       2       ccv_cnnp_model_sequence_t model_sequence = {
 247       2         .bank = kh_init(ccv_cnnp_model_name_bank)
 248       2       };
 249       2       ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = {
 250       2         .sequence = &model_sequence,
 251       2         .prefix = 't',
 252       2         .symbols = parameters,
 253       2         .ids = parameter_ids,
 254       2         .trainables = parameter_trainables,
 255       2       };
 256       2       ccv_cnnp_model_add_to_array_context_t add_to_output_context = {
 257       2         .sequence = &model_sequence,
 258       2         .prefix = 'r',
 259       2         .symbols = internals,
 260       2         .ids = internal_ids,
 261       2         .trainables = 0,
 262       2       };
 263       2       ccv_cnnp_model_build_data_t build_data = {
 264       2         .is_trainable = checkpoint->is_trainable,
 265       2         .model_sequence = &model_sequence,
 266       2         .add_to_array = ccv_cnnp_model_add_to_array,
 267       2         .parameters = parameters,
 268       2         .context = {
 269       2           .add_to_parameter = &add_to_parameter_context,
 270       2           .add_to_output = &add_to_output_context,
 271       2         },
 272       2         .is_gradient_checkpointing = 1, // Mark this as true so we don't allocate the gradient_checkpoints array or override the hooks.
 273       2         .gradient_checkpoints = 0,
 274       2       };
 275       2       checkpoint->model->data = &build_data;
 276       2       checkpoint->build(checkpoint->model, graph, checkpoint->inputs, checkpoint->input_size, max_outputs, checkpoint->output_size);
 277       2       checkpoint->model->data = 0;
 278       2       kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank);
 279       2       if (model_sequence.sequences)
 280       2         ccv_array_free(model_sequence.sequences);
 281       2       ccv_nnc_tensor_symbol_new_hook(graph, build.old_tensor_symbol_new_hook, build.old_tensor_symbol_new_hook_context, 0);
 282       2       ccv_nnc_tensor_symbol_alias_new_hook(graph, build.old_tensor_symbol_alias_new_hook, build.old_tensor_symbol_alias_new_hook_context, 0);
 283       2       ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
 284      14       for (j = 0; j < parameter_ids->rnum; j++)
 285      12         ccfree(*(char**)ccv_array_get(parameter_ids, j));
 286       2       for (j = 0; j < internal_ids->rnum; j++)
 287       0         ccfree(*(char**)ccv_array_get(internal_ids, j));
 288               // Note that there is no graph optimization applied here.
 289       2       exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
 290               // Reuse the existing ones.
 291       2       ccv_array_t* const newly_input_execs = input_execs;
 292       2       ccv_array_t* const newly_output_execs = output_execs;
 293       2       ccv_array_clear(newly_input_execs);
 294       2       ccv_array_clear(newly_output_execs);
 295      20       for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 296      18       {
 297      18         const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d;
 298      18         if (idx < 0)
 299       0           continue;
 300      18         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
 301       0           continue;
 302      18         const ccv_nnc_graph_exec_symbol_t symbol = {
 303      18           .graph = graph,
 304      18           .d = idx
 305      18         };
 306      18         const int* inputs = exec_info[idx].inputs;
 307      18         int input_size = exec_info[idx].input_size;
 308                 // Only go through the forward pass.
 309      18         assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
 310      18         int flag = 0;
 311      47         for (k = 0; inputs && k < input_size && !flag; k++)
 312      29           if (inputs[k] >= 0)
 313      58             for (l = 0; l < checkpoint->input_size && !flag; l++)
 314      29               if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
 315       6                 flag = 1;
 316      18         if (flag)
 317       6           ccv_array_push(newly_input_execs, &symbol);
 318      18         flag = 0;
 319      18         const int* outputs = exec_info[idx].outputs;
 320      18         int output_size = exec_info[idx].output_size;
 321      36         for (k = 0; inputs && k < output_size && !flag; k++)
 322      18           if (outputs[k] >= 0)
 323      36             for (l = 0; l < checkpoint->output_size && !flag; l++)
 324      18               if (max_outputs[l].d >= 0 && outputs[k] == max_outputs[l].d)
 325       2                 flag = 1;
 326      18         if (flag)
 327       2           ccv_array_push(newly_output_execs, &symbol);
 328      18       }
 329       4       for (j = 0; j < checkpoint->input_size; j++)
 330       2         if (checkpoint->inputs[j].d >= 0)
 331       2           ccv_array_push(parameters, checkpoint->inputs + j);
 332       2       ccv_nnc_symbolic_graph_simplify(graph,
 333       2         SYMBOLIC_GRAPH_PASSES(CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION,
 334       2           CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT,
 335       2           CCV_NNC_SIMPLIFY_OPS_FUSION),
 336       2         ccv_array_get(parameters, 0), parameters->rnum,
 337       2         max_outputs, checkpoint->output_size,
 338       2         ccv_array_get(newly_input_execs, 0), newly_input_execs->rnum, ccv_array_get(newly_output_execs, 0), newly_output_execs->rnum);
 339       2       ccv_nnc_graph_exec_symbol_new_hook(graph, build.old_graph_exec_symbol_new_hook, build.old_graph_exec_symbol_new_hook_context, 0);
 340               // Need to autogen and redo source / destination.
 341       2       ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
 342       2       exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
 343       2       ccv_array_clear(newly_input_execs);
 344      20       for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 345      18       {
 346      18         const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d;
 347      18         if (idx < 0)
 348       0           continue;
 349      18         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
 350       0           continue;
 351      18         const ccv_nnc_graph_exec_symbol_t symbol = {
 352      18           .graph = graph,
 353      18           .d = idx
 354      18         };
 355      18         const int* inputs = exec_info[idx].inputs;
 356      18         int input_size = exec_info[idx].input_size;
 357                 // Only go through the forward pass.
 358      18         assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
 359      18         int flag = 0;
 360      47         for (k = 0; inputs && k < input_size && !flag; k++)
 361      29           if (inputs[k] >= 0)
 362      58             for (l = 0; l < checkpoint->input_size && !flag; l++)
 363      29               if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
 364       6                 flag = 1;
 365      18         if (flag)
 366       6           ccv_array_push(newly_input_execs, &symbol);
 367      18       }
 368               // Build a map between old tensor symbols and new tensor symbols.
 369       2       khash_t(ccv_cnnp_tensor_symbol_map)* symbol_map = kh_init(ccv_cnnp_tensor_symbol_map);
 370       2       assert(build.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum);
 371               // Build a map to potentially map from an old input to a new input.
 372      32       for (j = 0, k = 0; j < build.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;)
 373      30       {
 374      30         const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d;
 375      30         assert(from_d >= 0);
 376      30         const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d;
 377      30         assert(to_d >= 0);
 378      30         int from_flag = 0;
 379      30         int to_flag = 0;
 380     288         for (l = 0; (!from_flag || !to_flag) && l < parameters->rnum; l++)
 381     258         {
 382     258           const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(parameters, l))->d;
 383     258           if (d == from_d)
 384      12             from_flag = 1;
 385     258           if (d == to_d)
 386       0             to_flag = 1;
 387     258         }
 388      30         if (!from_flag || !to_flag)
 389      30           for (l = 0; (!from_flag || !to_flag) && l < internals->rnum; l++)
 390       0           {
 391       0             const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(internals, l))->d;
 392       0             if (d == from_d)
 393       0               from_flag = 1;
 394       0             if (d == to_d)
 395       0               to_flag = 1;
 396       0           }
 397      30         if (from_flag)
 398      12           ++k;
 399      30         if (to_flag)
 400       0           ++j;
 401      30         if (from_flag || to_flag)
 402      12           continue;
 403      18         ++k;
 404      18         ++j;
 405                 // Skip if from_d is one of the outputs.
 406      36         for (l = 0; l < !from_flag && checkpoint->output_size; l++)
 407      18           if (checkpoint->outputs[l].d == from_d)
 408       2             from_flag = 1;
 409      18         if (from_flag)
 410       2           continue;
 411      16         int ret = 0;
 412      16         khiter_t h = kh_put(ccv_cnnp_tensor_symbol_map, symbol_map, from_d, &ret);
 413      16         kh_val(symbol_map, h) = to_d;
 414      16       }
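Note: symbol_map is a khash int-to-int map (instantiated by the KHASH_MAP_INIT_INT macro at line 48) from an old tensor symbol index to its recomputed replacement. A self-contained sketch of the same kh_put / kh_get / kh_val usage, assuming khash.h from the klib project is on the include path:

    #include <stdio.h>
    #include "khash.h"

    KHASH_MAP_INIT_INT(i32map, int) /* int keys -> int values */

    int main(void)
    {
      khash_t(i32map)* const map = kh_init(i32map);
      int ret = 0;
      /* Insert: kh_put returns an iterator; kh_val names the value slot. */
      khiter_t h = kh_put(i32map, map, 7 /* old tensor d */, &ret);
      kh_val(map, h) = 42; /* new tensor d */
      /* Lookup: kh_get returns kh_end(map) when the key is absent. */
      h = kh_get(i32map, map, 7);
      if (h != kh_end(map))
        printf("7 -> %d\n", kh_val(map, h));
      kh_destroy(i32map, map);
      return 0;
    }
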
 415               // Now go over all backward passes to replace inputs with the ones from the symbol map. Record the ones that are used.
 416       2       ccv_array_clear(newly_used_outputs);
 417       2       ccv_array_clear(replaced_backward_execs);
 418      18       for (j = 0; j < visited_backward_execs->rnum; j++)
 419      16       {
 420      16         const int idx = *(int*)ccv_array_get(visited_backward_execs, j);
 421      16         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags))
 422       0           continue;
 423      16         assert(idx >= 0);
 424      16         assert(idx < exec_rnum);
 425      16         if (!ccv_nnc_cmd_is_backward(exec_info[idx].cmd))
 426       1           continue;
 427      74         for (k = 0; k < exec_info[idx].input_size; k++)
 428      59           if (exec_info[idx].inputs[k] >= 0)
 429      32           {
 430      32             const khiter_t h = kh_get(ccv_cnnp_tensor_symbol_map, symbol_map, exec_info[idx].inputs[k]);
 431      32             if (h != kh_end(symbol_map)) // Replacing it.
 432       8             {
 433       8               const int newly_created_output = kh_val(symbol_map, h);
 434       8               exec_info[idx].inputs[k] = newly_created_output;
 435       8               ccv_array_add_unique_int(newly_used_outputs, newly_created_output);
 436       8               ccv_array_add_unique_int(replaced_backward_execs, idx);
 437       8             }
 438      32           }
 439      15       }
 440      20       for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 441      18       {
 442      18         ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
 443      18         if (symbol->d < 0)
 444       0           continue;
 445      18         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
 446       0           continue;
 447      18         int x, y;
 448     106         for (k = 0; k < replaced_backward_execs->rnum; k++)
 449      88         {
 450      88           const int idx = *(int*)ccv_array_get(replaced_backward_execs, k);
 451      88           assert(idx >= 0);
 452      88           assert(idx < exec_rnum);
 453      88           assert(ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
 454      88           int flag = 0;
 455     412           for (x = 0; !flag && x < exec_info[idx].input_size; x++)
 456     648             for (y = 0; !flag && y < exec_info[symbol->d].output_size; y++)
 457     324               if (exec_info[idx].inputs[x] == exec_info[symbol->d].outputs[y])
 458       8                 flag = 1;
 459      88           if (flag)
 460       8             ccv_nnc_graph_exec_symbol_concat(graph, *symbol, (ccv_nnc_graph_exec_symbol_t){
 461       8               .graph = graph,
 462       8               .d = idx
 463       8             });
 464      88         }
 465      18       }
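Note: lines 440-465 restore ordering after the rewrite at lines 418-439: every new forward exec is concatenated in front of each backward exec that now reads one of its outputs, and the trigger is simply a non-empty intersection between the forward exec's outputs and the backward exec's inputs. The core test, sketched standalone with assumed tensor ids:

    #include <stdio.h>

    int main(void)
    {
      /* Assumed tensor ids for one recompute exec and one backward exec. */
      const int fwd_outputs[] = { 5, 9 };
      const int bwd_inputs[] = { 2, 9, 7 };
      const int fwd_size = 2, bwd_size = 3;
      int x, y, flag = 0;
      for (x = 0; !flag && x < bwd_size; x++)
        for (y = 0; !flag && y < fwd_size; y++)
          if (bwd_inputs[x] == fwd_outputs[y])
            flag = 1; /* the backward exec consumes a recomputed tensor */
      if (flag)
        printf("add edge: forward exec -> backward exec\n");
      return 0;
    }
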
 466               // Find the parents of visited_backward_execs, and use those as the starting points of all newly added graph_exec_symbols. Use the visited backward execs as the sources, use all their parents as destinations, and go through with a graph visit.
 467       2       ccv_sparse_matrix_t* const exec_dep = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, graph->exec_symbol_info->rnum, CCV_8U | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0);
 468       2   #define for_block(x, val) \
 469      67       do { \
 470      67         if (((uint8_t*)val)[0] != 0) \
 471      67           ccv_array_push(buf, &x); \
 472      67       } while (0)
 473       2       const uint8_t one = 1;
 474               // Now go from inputs to outputs, propagating for each node the set of execs it depends on.
 475      16       ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
 476      16         if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f)))
 477      16         {
 478      16           ccv_array_clear(buf);
 479      16           ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
 480      16           if (vector)
 481      67             CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
 482      16           if (node->outgoings && node->outgoings->rnum > 0)
 483      16           {
 484      16             ccv_array_t* const outgoings = node->outgoings;
 485      39             for (k = 0; k < outgoings->rnum; k++)
 486      23             {
 487      23               const int outgoing_d = *(int*)ccv_array_get(outgoings, k);
 488      23               if (outgoing_d >= exec_rnum)
 489       0                 continue;
 490      23               int l;
 491                       // We cannot skip the ones already visited, because these may not contain all the deps.
 492      23               ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, idx, &one);
 493     114               for (l = 0; l < buf->rnum; l++)
 494      91                 ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, *(int*)ccv_array_get(buf, l), &one);
 495      23             }
 496      16           }
 497      16         }
 498      16       } ccv_nnc_graph_visit_endfor
 499               // Now go from outputs to inputs, unmark visited ones.
 500      16       ccv_nnc_graph_visit_for(visit, exec_info, node, idx) {
 501      16         if (idx < exec_rnum)
 502      16           maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));
 503      16       } ccv_nnc_graph_visit_endfor
 504       2       ccv_nnc_graph_visit_free(visit);
 505       2   #undef for_block
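Note: lines 467-498 build a transitive dependency table: visiting in topological order, each node first snapshots its already-recorded ancestors (buf), then stamps itself plus that whole set onto every successor, so a non-zero exec_dep(row, col) means col executes before row. A dense-matrix sketch of the same propagation over an assumed topologically ordered node list:

    #include <stdio.h>

    #define NODES 4

    int main(void)
    {
      /* Successors per node (-1 terminated); node indices are already in
         topological order for this assumed example. */
      static const int outgoing[NODES][NODES] = {
        { 1, 2, -1 }, { 3, -1 }, { 3, -1 }, { -1 },
      };
      unsigned char dep[NODES][NODES] = { { 0 } }; /* dep[a][b]: b precedes a */
      int idx, k, l;
      for (idx = 0; idx < NODES; idx++)
        for (k = 0; outgoing[idx][k] >= 0; k++) {
          const int to = outgoing[idx][k];
          dep[to][idx] = 1; /* the direct edge */
          for (l = 0; l < NODES; l++)
            if (dep[idx][l])
              dep[to][l] = 1; /* inherit all of idx's ancestors */
        }
      for (idx = 0; idx < NODES; idx++)
        for (l = 0; l < NODES; l++)
          if (dep[idx][l])
            printf("%d depends on %d\n", idx, l);
      return 0;
    }
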
 506               // Go through the visited backward execs, and remove the ones that have no dependency on any replaced backward execs.
 507      18       for (j = 0; j < visited_backward_execs->rnum;)
 508      16       {
 509      16         const int idx = *(int*)ccv_array_get(visited_backward_execs, j);
 510      16         if (ccv_array_contain_int(replaced_backward_execs, idx))
 511       7         {
 512       7           ++j;
 513       7           continue;
 514       7         }
 515       9         ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
 516       9         int flag = 0;
 517       9   #define for_block(x, val) \
 518      37         do { \
 519      37           if (((uint8_t*)val)[0] != 0) \
 520      37             if (ccv_array_contain_int(replaced_backward_execs, x)) \
 521      37               flag = 1; \
 522      37         } while (0)
 523       9         if (vector)
 524      37           CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
 525       9   #undef for_block
 526       9         if (!flag)
 527       3         {
 528       3           if (j < visited_backward_execs->rnum - 1)
 529       2             *(int*)ccv_array_get(visited_backward_execs, j) = *(int*)ccv_array_get(visited_backward_execs, visited_backward_execs->rnum - 1);
 530       3           --visited_backward_execs->rnum;
 531       3           continue;
 532       3         }
 533       6         ++j;
 534       6       }
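Note: the removal at lines 528-531 is the unordered swap-remove idiom: slot j is overwritten with the last element and the count shrinks, so each removal is O(1), and j deliberately does not advance because the swapped-in element still has to be checked. In isolation, with an assumed keep-even predicate:

    #include <stdio.h>

    int main(void)
    {
      int a[] = { 10, 11, 12, 13, 14 };
      int rnum = 5, j = 0;
      while (j < rnum) {
        if (a[j] % 2 == 0) { /* keep even entries (assumed predicate) */
          ++j;
          continue;
        }
        /* Remove a[j] in O(1): replace it with the last element. */
        if (j < rnum - 1)
          a[j] = a[rnum - 1];
        --rnum; /* do not advance j: the swapped-in element is unchecked */
      }
      for (j = 0; j < rnum; j++)
        printf("%d ", a[j]); /* prints: 10 14 12 */
      printf("\n");
      return 0;
    }
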
 535               // Now go through all replaced_backward_execs to find the ones that have no dependencies within visited_backward_execs.
 536       9       for (j = 0; j < replaced_backward_execs->rnum; j++)
 537       7       {
 538       7         const int idx = *(int*)ccv_array_get(replaced_backward_execs, j);
 539       7         ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx);
 540       7         int flag = 0;
 541       7   #define for_block(x, val) \
 542      30         do { \
 543      30           if (((uint8_t*)val)[0] != 0) \
 544      30             if (ccv_array_contain_int(visited_backward_execs, x)) \
 545      30               flag = 1; \
 546      30         } while (0)
 547       7         if (vector)
 548      30           CCV_SPARSE_VECTOR_FOREACH(exec_dep, vector, for_block);
 549       7   #undef for_block
 550               // If this one has no parents that are within visited_backward_execs, it is a good place to add all its parents as dependencies for the input_execs.
 551       7         if (!flag)
 552       2         {
 553       2           assert(idx < exec_rnum);
 554       2           ccv_array_t* const outgoings = reversed_nodes[idx].outgoings;
 555       2           assert(outgoings);
 556       6           for (k = 0; k < outgoings->rnum; k++)
 557       4           {
 558       4             const int d = *(int*)ccv_array_get(outgoings, k);
 559      16             for (l = 0; l < newly_input_execs->rnum; l++)
 560      12             {
 561      12               ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){
 562      12                 .graph = graph,
 563      12                 .d = d
 564      12               }, *(ccv_nnc_graph_exec_symbol_t*)ccv_array_get(newly_input_execs, l));
 565      12             }
 566       4           }
 567       2         }
 568       7       }
 569       2       ccv_matrix_free(exec_dep);
 570               // Go through all execs, and free the ones whose outputs are not used.
 571               // Reuse this array because it is not useful any more.
 572       2       ccv_array_t* forward_pass_inputs = visited_backward_execs;
 573       2       int any_deleted;
 574       6       do {
 575               // Build a map of still-active inputs.
 576       6         ccv_array_clear(forward_pass_inputs);
 577      60         for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 578      54         {
 579      54           ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
 580      54           if (symbol->d < 0)
 581       8             continue;
 582      46           if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
 583       0             continue;
 584      46           int* const inputs = exec_info[symbol->d].inputs;
 585      46           const int input_size = exec_info[symbol->d].input_size;
 586     135           for (k = 0; k < input_size; k++)
 587      89             ccv_array_add_unique_int(forward_pass_inputs, inputs[k]);
 588      46         }
 589       6         any_deleted = 0;
 590      60         for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 591      54         {
 592      54           ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
 593      54           if (symbol->d < 0)
 594       8             continue;
 595      46           if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
 596       0             continue;
 597      46           int* const outputs = exec_info[symbol->d].outputs;
 598      46           const int output_size = exec_info[symbol->d].output_size;
 599      46           int flag = 0;
 600      92           for (k = 0; !flag && k < output_size; k++)
 601      46             flag = ccv_array_contain_int(newly_used_outputs, outputs[k]) || ccv_array_contain_int(forward_pass_inputs, outputs[k]);
 602      46           if (flag)
 603      40             continue;
 604       6           ccv_nnc_graph_exec_symbol_free(graph, *symbol);
 605       6           symbol->d = -1;
 606       6           symbol->graph = 0;
 607       6           any_deleted = 1;
 608       6         }
 609       6       } while (any_deleted);
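Note: the do/while loop above deletes any newly built exec whose outputs are consumed neither by the rewired backward passes (newly_used_outputs) nor by a surviving recompute exec, and it repeats until a full pass deletes nothing, because freeing one exec can orphan its producers. The same fixpoint shape on a toy producer/consumer table (all values assumed):

    #include <stdio.h>

    #define NODES 4

    int main(void)
    {
      /* consumer[i] is the node that reads node i's output, or -1 for none.
         Only node 3's output is used externally. */
      static const int consumer[NODES] = { 1, 2, -1, -1 };
      static const int externally_used[NODES] = { 0, 0, 0, 1 };
      int alive[NODES] = { 1, 1, 1, 1 };
      int any_deleted, i;
      do {
        any_deleted = 0;
        for (i = 0; i < NODES; i++) {
          if (!alive[i] || externally_used[i])
            continue;
          /* Delete i once its consumer is gone (or never existed);
             each deletion can orphan a producer, hence the outer loop. */
          if (consumer[i] < 0 || !alive[consumer[i]]) {
            alive[i] = 0;
            any_deleted = 1;
          }
        }
      } while (any_deleted);
      for (i = 0; i < NODES; i++)
        printf("node %d: %s\n", i, alive[i] ? "kept" : "deleted");
      return 0;
    }
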
 610       2       ccv_array_clear(forward_pass_inputs);
 611      20       for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 612      18       {
 613      18         ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
 614      18         if (symbol->d < 0)
 615       6           continue;
 616      12         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
 617       0           continue;
 618      12         int* const inputs = exec_info[symbol->d].inputs;
 619      12         const int input_size = exec_info[symbol->d].input_size;
 620      35         for (k = 0; k < input_size; k++)
 621      23           ccv_array_add_unique_int(forward_pass_inputs, inputs[k]);
 622      12         int* const outputs = exec_info[symbol->d].outputs;
 623      12         const int output_size = exec_info[symbol->d].output_size;
 624      24         for (k = 0; k < output_size; k++)
 625      12           ccv_array_add_unique_int(forward_pass_inputs, outputs[k]);
 626      12       }
 627               // Free unused tensor symbols.
 628      20       for (j = 0; j < build.tensor_symbols->rnum; j++)
 629      18       {
 630      18         const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j));
 631      18         if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d))
 632      12           continue;
 633       6         ccv_nnc_tensor_symbol_free(graph, *symbol);
 634       6       }
 635      20       for (j = 0; j < build.graph_exec_symbols->rnum; j++)
 636      18       {
 637      18         ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j);
 638      18         if (symbol->d < 0)
 639       6           continue;
 640      12         if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags))
 641       0           continue;
 642      12         ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT);
 643      12       }
 644               // Free these newly created execs and tensor symbols.
 645       2       ccv_array_free(build.tensor_symbols);
 646       2       ccv_array_free(build.graph_exec_symbols);
 647       2       kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map);
 648       2     }
 649       2     ccfree(max_outputs);
 650       2     ccv_array_free(buf);
 651       2     ccv_array_free(newly_used_outputs);
 652       2     ccv_array_free(parameters);
 653       2     ccv_array_free(parameter_ids);
 654       2     ccv_array_free(parameter_trainables);
 655       2     ccv_array_free(internals);
 656       2     ccv_array_free(internal_ids);
 657       2     ccfree(maskbit);
 658       2     ccv_array_free(input_gradient_execs);
 659       2     ccv_array_free(output_gradient_execs);
 660       2     ccv_array_free(input_execs);
 661       2     ccv_array_free(output_execs);
 662       2     ccv_array_free(replaced_backward_execs);
 663       2     ccv_array_free(visited_backward_execs);
 664      48     for (i = 0; i < exec_rnum; i++)
 665      46       if (reversed_nodes[i].outgoings)
 666      40         ccv_array_free(reversed_nodes[i].outgoings);
 667       2     ccfree(reversed_nodes);
 668       2   }