/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Line | Count | Source (jump to first uncovered line) |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_cnnp_model.h" |
6 | | // This can be removed once we organized ccv_cnnp_apply_gradient_checkpoints better. |
7 | | #include "_ccv_nnc_symbolic_graph.h" |
8 | | |
9 | | typedef struct { /* Reverse-adjacency record kept per exec symbol; one entry per exec is allocated in ccv_cnnp_model_apply_gradient_checkpoints. */ |
10 | | ccv_array_t* outgoings; /* Lazily created array of int exec indices: the execs that list this exec among their destinations (i.e. its parents). */ |
11 | | } ccv_nnc_graph_exec_symbol_reverse_t; |
12 | | |
13 | | typedef struct { /* Recording context handed to the three symbol-creation hooks below while a checkpointed model is rebuilt. */ |
14 | | ccv_array_t* tensor_symbols; /* Every tensor symbol (plain or alias) seen by the tensor hooks, in creation order. */ |
15 | | void* old_tensor_symbol_new_hook_context; /* Context value returned when our tensor-symbol hook was registered; passed back when chaining or restoring. */ |
16 | | ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; /* Hook reported by the registration call (presumably the previously installed one); may be null. */ |
17 | | void* old_tensor_symbol_alias_new_hook_context; /* Same as above, for the alias-creation hook. */ |
18 | | ccv_nnc_tensor_symbol_alias_new_hook_f old_tensor_symbol_alias_new_hook; /* Alias hook to chain to; may be null. */ |
19 | | ccv_array_t* graph_exec_symbols; /* Every graph exec symbol seen by the exec hook, in creation order. */ |
20 | | ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook; /* Exec-symbol hook to chain to; may be null. */ |
21 | | void* old_graph_exec_symbol_new_hook_context; /* Context passed to old_graph_exec_symbol_new_hook. */ |
22 | | } ccv_cnnp_gradient_checkpoint_build_t; |
23 | | |
24 | | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) /* Tensor-symbol-creation hook: records the new symbol in the build context, then forwards the call to the saved hook. */ |
25 | 18 | { |
26 | 18 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; /* context is the build record that was registered together with this hook */ |
27 | 18 | ccv_array_push(build_context->tensor_symbols, &symbol); /* remember the symbol so the caller can later pair it with the original one */ |
28 | 18 | if (build_context->old_tensor_symbol_new_hook) /* chain only when a saved hook exists */ |
29 | 18 | build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name); |
30 | 18 | } |
31 | | |
32 | | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name) /* Alias-creation hook: records the new alias symbol in the same tensor_symbols array as plain symbols, then forwards the call to the saved alias hook. Execution count is 0 in this coverage run. */ |
33 | 0 | { |
34 | 0 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; /* context is the build record that was registered together with this hook */ |
35 | 0 | ccv_array_push(build_context->tensor_symbols, &symbol); /* only the alias itself is recorded; from_symbol is not */ |
36 | 0 | if (build_context->old_tensor_symbol_alias_new_hook) /* chain only when a saved hook exists */ |
37 | 0 | build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); |
38 | 0 | } |
39 | | |
40 | | static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) /* Exec-symbol-creation hook: records the new exec symbol in the build context, then forwards the call to the saved hook. */ |
41 | 18 | { |
42 | 18 | ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context; /* context is the build record that was registered together with this hook */ |
43 | 18 | ccv_array_push(build->graph_exec_symbols, &symbol); /* remember the exec so the caller can post-process the newly built ops */ |
44 | 18 | if (build->old_graph_exec_symbol_new_hook) /* chain only when a saved hook exists */ |
45 | 18 | build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name); |
46 | 18 | } |
47 | | |
48 | | KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int) /* Instantiates the khash map type ccv_cnnp_tensor_symbol_map (int key, int value); used below to map an original tensor symbol index to its re-built counterpart. */ |
49 | | |
50 | | void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph) |
51 | 2.24k | { |
52 | 2.24k | ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints; |
53 | 2.24k | if (!gradient_checkpoints || gradient_checkpoints->rnum == 02 ) // No saved gradient checkpoints, this is an easy way out. |
54 | 2.23k | return; |
55 | | // Otherwise, for each gradient checkpoint, there are 8 steps: |
56 | | // 1. Find which execs currently exist from inputs to outputs. |
57 | | // 2. Find the execs that generate the outputs, and their corresponding backward execs. |
58 | | // 3. Find all backward execs that flow from outputs back to inputs. |
59 | | // 4. Generate new ops by calling build again with the old inputs, recording all new tensors / execs. |
60 | | // 5. Replace inputs in backward execs with the new tensors. |
61 | | // 6. Hook the execs that take the inputs with edges from the parents of the backward execs in step 2. |
62 | | // 7. Delete newly generated execs that have no use (i.e. their outputs are not used by the backward pass). |
63 | | // 8. Mark all new execs with DISABLE_OPT to avoid the common sub-expression elimination pass. |
64 | 2 | int i, j, k, l; |
65 | 2 | ccv_array_t* input_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
66 | 2 | ccv_array_t* output_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
67 | 2 | ccv_array_t* input_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
68 | 2 | ccv_array_t* output_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
69 | 2 | ccv_array_t* visited_backward_execs = ccv_array_new(sizeof(int), 0, 0); |
70 | 2 | ccv_array_t* replaced_backward_execs = ccv_array_new(sizeof(int), 0, 0); |
71 | 2 | const int exec_rnum = graph->exec_symbol_info->rnum; |
72 | 2 | ccv_nnc_graph_exec_symbol_reverse_t* const reversed_nodes = cccalloc(exec_rnum, sizeof(ccv_nnc_graph_exec_symbol_reverse_t)); |
73 | 48 | for (i = 0; i < exec_rnum; i++46 ) |
74 | 46 | { |
75 | 46 | const int* tos = 0; |
76 | 46 | int to_size = 0; |
77 | 46 | ccv_nnc_graph_exec_symbol_to(graph, (ccv_nnc_graph_exec_symbol_t){ |
78 | 46 | .graph = graph, |
79 | 46 | .d = i |
80 | 46 | }, &tos, &to_size); |
81 | 46 | if (tos) |
82 | 99 | for (j = 0; 38 j < to_size; j++61 ) |
83 | 61 | { |
84 | 61 | if (!reversed_nodes[tos[j]].outgoings) |
85 | 40 | reversed_nodes[tos[j]].outgoings = ccv_array_new(sizeof(int), 1, 0); |
86 | 61 | ccv_array_add_unique_int(reversed_nodes[tos[j]].outgoings, i); |
87 | 61 | } |
88 | 46 | } |
89 | 2 | uint32_t* const maskbit = cccalloc((exec_rnum + 31) >> 5, sizeof(uint32_t)); |
90 | | // Temporary for build_data. |
91 | 2 | ccv_array_t* const parameters = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0); |
92 | 2 | ccv_array_t* const parameter_ids = ccv_array_new(sizeof(char*), 0, 0); |
93 | 2 | ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0); |
94 | 2 | ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0); |
95 | 2 | ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0); |
96 | 2 | ccv_array_t* const buf = ccv_array_new(sizeof(int), 0, 0); |
97 | 2 | int max_output_size = 0; |
98 | 4 | for (i = 0; i < gradient_checkpoints->rnum; i++2 ) |
99 | 2 | { |
100 | 2 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i); |
101 | 2 | max_output_size = ccv_max(checkpoint->output_size, max_output_size); |
102 | 2 | } |
103 | 2 | ccv_nnc_tensor_symbol_t* max_outputs = ccmalloc(sizeof(ccv_nnc_tensor_symbol_t) * max_output_size); |
104 | 2 | ccv_array_t* newly_used_outputs = ccv_array_new(sizeof(int), 0, 0); |
105 | 4 | for (i = 0; i < gradient_checkpoints->rnum; i++2 ) |
106 | 2 | { |
107 | 2 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i); |
108 | 2 | ccv_array_clear(input_execs); |
109 | 2 | ccv_array_clear(output_execs); |
110 | 2 | ccv_nnc_graph_exec_symbol_info_t* exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
111 | 48 | for (j = 0; j < exec_rnum; j++46 ) |
112 | 46 | { |
113 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[j].flags)) |
114 | 0 | continue; |
115 | 46 | const int* inputs = exec_info[j].inputs; |
116 | 46 | int input_size = exec_info[j].input_size; |
117 | 46 | const int* outputs = exec_info[j].outputs; |
118 | 46 | int output_size = exec_info[j].output_size; |
119 | 46 | if (input_size == 0 && output_size == 00 ) |
120 | 0 | continue; |
121 | | // Only go through forward pass. |
122 | 46 | if (ccv_nnc_cmd_is_backward(exec_info[j].cmd)) |
123 | 17 | continue; |
124 | 29 | const ccv_nnc_graph_exec_symbol_t symbol = { |
125 | 29 | .graph = graph, |
126 | 29 | .d = j |
127 | 29 | }; |
128 | 29 | int flag = 0; |
129 | 88 | for (k = 0; inputs && k < input_size && !flag65 ; k++59 ) |
130 | 59 | if (inputs[k] >= 0) |
131 | 118 | for (l = 0; 59 l < checkpoint->input_size && !flag59 ; l++59 ) |
132 | 59 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
133 | 6 | flag = 1; |
134 | 29 | if (flag) |
135 | 6 | ccv_array_push(input_execs, &symbol); |
136 | 29 | flag = 0; |
137 | 66 | for (k = 0; outputs && k < output_size && !flag37 ; k++37 ) |
138 | 37 | if (outputs[k] >= 0) |
139 | 74 | for (l = 0; 37 l < checkpoint->output_size && !flag37 ; l++37 ) |
140 | 37 | if (checkpoint->outputs[l].d >= 0 && outputs[k] == checkpoint->outputs[l].d) |
141 | 2 | flag = 1; |
142 | 29 | if (flag) |
143 | 2 | ccv_array_push(output_execs, &symbol); |
144 | 29 | } |
145 | 2 | if (input_execs->rnum <= 0 || output_execs->rnum <= 0) |
146 | 0 | continue; |
147 | | // Fill in blanks (i.e. the backward ops that are not showing in above, but should be included to avoid excluding necessary ones). This is done by flowing gradients from outputs back all the way to inputs. |
148 | 2 | ccv_array_clear(input_gradient_execs); |
149 | 2 | ccv_array_clear(output_gradient_execs); |
150 | 8 | for (j = 0; j < input_execs->rnum; j++6 ) |
151 | 6 | { |
152 | 6 | const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_execs, j))->d; |
153 | 18 | for (k = 0; k < exec_info[d].input_size; k++12 ) |
154 | 12 | if (exec_info[d].inputs[k] >= 0) |
155 | 12 | { |
156 | 12 | const ccv_nnc_tensor_symbol_t gradient_symbol = ccv_nnc_tensor_symbol_for_backward(graph, (ccv_nnc_tensor_symbol_t){ |
157 | 12 | .graph = graph, |
158 | 12 | .d = exec_info[d].inputs[k] |
159 | 12 | }); |
160 | 12 | if (gradient_symbol.d < 0) |
161 | 9 | continue; |
162 | 3 | const ccv_nnc_graph_exec_symbol_t backward = ccv_nnc_graph_exec_symbol_for_backward(graph, gradient_symbol); |
163 | 3 | if (backward.d < 0) |
164 | 0 | continue; |
165 | 3 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[backward.d].flags)) |
166 | 0 | continue; |
167 | 3 | int flag = 0; |
168 | 4 | for (l = 0; !flag && l < output_gradient_execs->rnum; l++1 ) |
169 | 1 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == backward.d) |
170 | 0 | flag = 1; |
171 | 3 | if (!flag) |
172 | 3 | ccv_array_push(output_gradient_execs, &backward); |
173 | 3 | } |
174 | 6 | if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0) |
175 | 15 | for (k = 0; 6 k < exec_info[d].outgoings->rnum; k++9 ) |
176 | 9 | { |
177 | 9 | const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k); |
178 | 9 | if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd)) |
179 | 6 | continue; |
180 | 3 | int flag = 0; |
181 | 7 | for (l = 0; !flag && l < output_gradient_execs->rnum4 ; l++4 ) |
182 | 4 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == to_d) |
183 | 3 | flag = 1; |
184 | 3 | if (!flag) |
185 | 0 | { |
186 | 0 | const ccv_nnc_graph_exec_symbol_t backward = { |
187 | 0 | .graph = graph, |
188 | 0 | .d = to_d |
189 | 0 | }; |
190 | 0 | ccv_array_push(output_gradient_execs, &backward); |
191 | 0 | } |
192 | 3 | } |
193 | 6 | } |
194 | | // For output_gradient_execs, we can be opportunistic and use the wrt symbols (if exists) to find relevant bits. |
195 | | // For input_gradient_execs, there is no other way but to loop over all outgoings, find the ones are direct link as backward execs. |
196 | 4 | for (j = 0; j < output_execs->rnum; j++2 ) |
197 | 2 | { |
198 | 2 | const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_execs, j))->d; |
199 | 2 | if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0) |
200 | 6 | for (k = 0; 2 k < exec_info[d].outgoings->rnum; k++4 ) |
201 | 4 | { |
202 | 4 | const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k); |
203 | 4 | if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd)) |
204 | 2 | continue; |
205 | 2 | int flag = 0; |
206 | 2 | for (l = 0; !flag && l < input_gradient_execs->rnum; l++0 ) |
207 | 0 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, l))->d == to_d) |
208 | 0 | flag = 1; |
209 | 2 | if (!flag) |
210 | 2 | { |
211 | 2 | const ccv_nnc_graph_exec_symbol_t backward = { |
212 | 2 | .graph = graph, |
213 | 2 | .d = to_d |
214 | 2 | }; |
215 | 2 | ccv_array_push(input_gradient_execs, &backward); |
216 | 2 | } |
217 | 2 | } |
218 | 2 | } |
219 | | // Note that we have to use up-to-date ones because the exec_info might have outgoings that is up-to-date. |
220 | 4 | ccv_nnc_graph_visit_t* const visit = ccv_nnc_graph_visit_new2 (graph, exec_info, graph->exec_symbol_info->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, 1); |
221 | 16 | ccv_nnc_graph_visit_for(visit, exec_info, node, idx) { |
222 | 16 | if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
223 | 16 | maskbit[idx >> 5] |= (1u << (idx & 0x1f)); |
224 | 16 | } ccv_nnc_graph_visit_endfor |
225 | 4 | ccv_array_clear(visited_backward_execs); |
226 | | // Add more backward passes to the list. Note that we don't add everything; in particular, new nodes created through gradient checkpointing are ignored. |
227 | 4 | #define visitor(node, idx, _) \ |
228 | 16 | if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f))) \ |
229 | 16 | ccv_array_add_unique_int(visited_backward_execs, idx); |
230 | 16 | CCV_NNC_GRAPH_VISIT2 (graph, reversed_nodes, exec_rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, 0, visitor); |
231 | 4 | for (j = 0; j < input_gradient_execs->rnum; j++2 ) |
232 | 2 | ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d); |
233 | 2 | #undef visitor |
234 | 2 | ccv_cnnp_gradient_checkpoint_build_t build = { |
235 | 2 | .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), |
236 | 2 | .graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0), |
237 | 2 | }; |
238 | 2 | build.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.old_tensor_symbol_new_hook); |
239 | 2 | build.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.old_tensor_symbol_alias_new_hook); |
240 | 2 | build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook); |
241 | 2 | ccv_array_clear(parameters); |
242 | 2 | ccv_array_clear(parameter_ids); |
243 | 2 | ccv_array_clear(parameter_trainables); |
244 | 2 | ccv_array_clear(internals); |
245 | 2 | ccv_array_clear(internal_ids); |
246 | 2 | ccv_cnnp_model_sequence_t model_sequence = { |
247 | 2 | .bank = kh_init(ccv_cnnp_model_name_bank) |
248 | 2 | }; |
249 | 2 | ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = { |
250 | 2 | .sequence = &model_sequence, |
251 | 2 | .prefix = 't', |
252 | 2 | .symbols = parameters, |
253 | 2 | .ids = parameter_ids, |
254 | 2 | .trainables = parameter_trainables, |
255 | 2 | }; |
256 | 2 | ccv_cnnp_model_add_to_array_context_t add_to_output_context = { |
257 | 2 | .sequence = &model_sequence, |
258 | 2 | .prefix = 'r', |
259 | 2 | .symbols = internals, |
260 | 2 | .ids = internal_ids, |
261 | 2 | .trainables = 0, |
262 | 2 | }; |
263 | 2 | ccv_cnnp_model_build_data_t build_data = { |
264 | 2 | .is_trainable = checkpoint->is_trainable, |
265 | 2 | .model_sequence = &model_sequence, |
266 | 2 | .add_to_array = ccv_cnnp_model_add_to_array, |
267 | 2 | .parameters = parameters, |
268 | 2 | .context = { |
269 | 2 | .add_to_parameter = &add_to_parameter_context, |
270 | 2 | .add_to_output = &add_to_output_context, |
271 | 2 | }, |
272 | 2 | .is_gradient_checkpointing = 1, // Mark this as true so we don't allocate gradient_checkpoints array or override the hooks. |
273 | 2 | .gradient_checkpoints = 0, |
274 | 2 | }; |
275 | 2 | checkpoint->model->data = &build_data; |
276 | 2 | checkpoint->build(checkpoint->model, graph, checkpoint->inputs, checkpoint->input_size, max_outputs, checkpoint->output_size); |
277 | 2 | checkpoint->model->data = 0; |
278 | 2 | kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank); |
279 | 2 | if (model_sequence.sequences) |
280 | 2 | ccv_array_free(model_sequence.sequences); |
281 | 2 | ccv_nnc_tensor_symbol_new_hook(graph, build.old_tensor_symbol_new_hook, build.old_tensor_symbol_new_hook_context, 0); |
282 | 2 | ccv_nnc_tensor_symbol_alias_new_hook(graph, build.old_tensor_symbol_alias_new_hook, build.old_tensor_symbol_alias_new_hook_context, 0); |
283 | 2 | ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0); |
284 | 14 | for (j = 0; j < parameter_ids->rnum; j++12 ) |
285 | 12 | ccfree(*(char**)ccv_array_get(parameter_ids, j)); |
286 | 2 | for (j = 0; j < internal_ids->rnum; j++0 ) |
287 | 0 | ccfree(*(char**)ccv_array_get(internal_ids, j)); |
288 | | // Note that there is no graph optimization applied here. |
289 | 2 | exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
290 | | // Reuse existing one. |
291 | 2 | ccv_array_t* const newly_input_execs = input_execs; |
292 | 2 | ccv_array_t* const newly_output_execs = output_execs; |
293 | 2 | ccv_array_clear(newly_input_execs); |
294 | 2 | ccv_array_clear(newly_output_execs); |
295 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
296 | 18 | { |
297 | 18 | const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d; |
298 | 18 | if (idx < 0) |
299 | 0 | continue; |
300 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
301 | 0 | continue; |
302 | 18 | const ccv_nnc_graph_exec_symbol_t symbol = { |
303 | 18 | .graph = graph, |
304 | 18 | .d = idx |
305 | 18 | }; |
306 | 18 | const int* inputs = exec_info[idx].inputs; |
307 | 18 | int input_size = exec_info[idx].input_size; |
308 | | // Only go through forward pass. |
309 | 18 | assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
310 | 18 | int flag = 0; |
311 | 47 | for (k = 0; inputs && k < input_size && !flag35 ; k++29 ) |
312 | 29 | if (inputs[k] >= 0) |
313 | 58 | for (l = 0; 29 l < checkpoint->input_size && !flag29 ; l++29 ) |
314 | 29 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
315 | 6 | flag = 1; |
316 | 18 | if (flag) |
317 | 6 | ccv_array_push(newly_input_execs, &symbol); |
318 | 18 | flag = 0; |
319 | 18 | const int* outputs = exec_info[idx].outputs; |
320 | 18 | int output_size = exec_info[idx].output_size; |
321 | 36 | for (k = 0; inputs && k < output_size && !flag18 ; k++18 ) |
322 | 18 | if (outputs[k] >= 0) |
323 | 36 | for (l = 0; 18 l < checkpoint->output_size && !flag18 ; l++18 ) |
324 | 18 | if (max_outputs[l].d >= 0 && outputs[k] == max_outputs[l].d) |
325 | 2 | flag = 1; |
326 | 18 | if (flag) |
327 | 2 | ccv_array_push(newly_output_execs, &symbol); |
328 | 18 | } |
329 | 4 | for (j = 0; 2 j < checkpoint->input_size; j++2 ) |
330 | 2 | if (checkpoint->inputs[j].d >= 0) |
331 | 2 | ccv_array_push(parameters, checkpoint->inputs + j); |
332 | 2 | ccv_nnc_symbolic_graph_simplify(graph, |
333 | 2 | SYMBOLIC_GRAPH_PASSES(CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION, |
334 | 2 | CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT, |
335 | 2 | CCV_NNC_SIMPLIFY_OPS_FUSION), |
336 | 2 | ccv_array_get(parameters, 0), parameters->rnum, |
337 | 2 | max_outputs, checkpoint->output_size, |
338 | 2 | ccv_array_get(newly_input_execs, 0), newly_input_execs->rnum, ccv_array_get(newly_output_execs, 0), newly_output_execs->rnum); |
339 | 2 | ccv_nnc_graph_exec_symbol_new_hook(graph, build.old_graph_exec_symbol_new_hook, build.old_graph_exec_symbol_new_hook_context, 0); |
340 | | // Need to autogen and redo source / destination. |
341 | 2 | ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0); |
342 | 2 | exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
343 | 2 | ccv_array_clear(newly_input_execs); |
344 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
345 | 18 | { |
346 | 18 | const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d; |
347 | 18 | if (idx < 0) |
348 | 0 | continue; |
349 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
350 | 0 | continue; |
351 | 18 | const ccv_nnc_graph_exec_symbol_t symbol = { |
352 | 18 | .graph = graph, |
353 | 18 | .d = idx |
354 | 18 | }; |
355 | 18 | const int* inputs = exec_info[idx].inputs; |
356 | 18 | int input_size = exec_info[idx].input_size; |
357 | | // Only go through forward pass. |
358 | 18 | assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
359 | 18 | int flag = 0; |
360 | 47 | for (k = 0; inputs && k < input_size && !flag35 ; k++29 ) |
361 | 29 | if (inputs[k] >= 0) |
362 | 58 | for (l = 0; 29 l < checkpoint->input_size && !flag29 ; l++29 ) |
363 | 29 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
364 | 6 | flag = 1; |
365 | 18 | if (flag) |
366 | 6 | ccv_array_push(newly_input_execs, &symbol); |
367 | 18 | } |
368 | | // Build a map between old tensor symbols and new tensor symbols. |
369 | 2 | khash_t(ccv_cnnp_tensor_symbol_map)* symbol_map = kh_init(ccv_cnnp_tensor_symbol_map); |
370 | 2 | assert(build.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum); |
371 | | // Build a map to potentially map from old input to new input. |
372 | 32 | for (j = 0, k = 0; 2 j < build.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum30 ;) |
373 | 30 | { |
374 | 30 | const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d; |
375 | 30 | assert(from_d >= 0); |
376 | 30 | const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d; |
377 | 30 | assert(to_d >= 0); |
378 | 30 | int from_flag = 0; |
379 | 30 | int to_flag = 0; |
380 | 288 | for (l = 0; (!from_flag || !to_flag63 ) && l < parameters->rnum; l++258 ) |
381 | 258 | { |
382 | 258 | const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(parameters, l))->d; |
383 | 258 | if (d == from_d) |
384 | 12 | from_flag = 1; |
385 | 258 | if (d == to_d) |
386 | 0 | to_flag = 1; |
387 | 258 | } |
388 | 30 | if (!from_flag || !to_flag12 ) |
389 | 30 | for (l = 0; (!from_flag || !to_flag12 ) && l < internals->rnum; l++0 ) |
390 | 0 | { |
391 | 0 | const int d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(internals, l))->d; |
392 | 0 | if (d == from_d) |
393 | 0 | from_flag = 1; |
394 | 0 | if (d == to_d) |
395 | 0 | to_flag = 1; |
396 | 0 | } |
397 | 30 | if (from_flag) |
398 | 12 | ++k; |
399 | 30 | if (to_flag) |
400 | 0 | ++j; |
401 | 30 | if (from_flag || to_flag18 ) |
402 | 12 | continue; |
403 | 18 | ++k; |
404 | 18 | ++j; |
405 | | // Skip if from_d is outputs. |
406 | 36 | for (l = 0; l < !from_flag && checkpoint->output_size18 ; l++18 ) |
407 | 18 | if (checkpoint->outputs[l].d == from_d) |
408 | 2 | from_flag = 1; |
409 | 18 | if (from_flag) |
410 | 2 | continue; |
411 | 16 | int ret = 0; |
412 | 16 | khiter_t h = kh_put(ccv_cnnp_tensor_symbol_map, symbol_map, from_d, &ret); |
413 | 16 | kh_val(symbol_map, h) = to_d; |
414 | 16 | } |
415 | | // Now go over all backward passes to replace inputs with the ones from symbol map. Record these that are used. |
416 | 2 | ccv_array_clear(newly_used_outputs); |
417 | 2 | ccv_array_clear(replaced_backward_execs); |
418 | 18 | for (j = 0; j < visited_backward_execs->rnum; j++16 ) |
419 | 16 | { |
420 | 16 | const int idx = *(int*)ccv_array_get(visited_backward_execs, j); |
421 | 16 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
422 | 0 | continue; |
423 | 16 | assert(idx >= 0); |
424 | 16 | assert(idx < exec_rnum); |
425 | 16 | if (!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)) |
426 | 1 | continue; |
427 | 74 | for (k = 0; 15 k < exec_info[idx].input_size; k++59 ) |
428 | 59 | if (exec_info[idx].inputs[k] >= 0) |
429 | 32 | { |
430 | 32 | const khiter_t h = kh_get(ccv_cnnp_tensor_symbol_map, symbol_map, exec_info[idx].inputs[k]); |
431 | 32 | if (h != kh_end(symbol_map)) // Replacing it. |
432 | 8 | { |
433 | 8 | const int newly_created_output = kh_val(symbol_map, h); |
434 | 8 | exec_info[idx].inputs[k] = newly_created_output; |
435 | 8 | ccv_array_add_unique_int(newly_used_outputs, newly_created_output); |
436 | 8 | ccv_array_add_unique_int(replaced_backward_execs, idx); |
437 | 8 | } |
438 | 32 | } |
439 | 15 | } |
440 | 20 | for (j = 0; 2 j < build.graph_exec_symbols->rnum; j++18 ) |
441 | 18 | { |
442 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
443 | 18 | if (symbol->d < 0) |
444 | 0 | continue; |
445 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
446 | 0 | continue; |
447 | 18 | int x, y; |
448 | 106 | for (k = 0; k < replaced_backward_execs->rnum; k++88 ) |
449 | 88 | { |
450 | 88 | const int idx = *(int*)ccv_array_get(replaced_backward_execs, k); |
451 | 88 | assert(idx >= 0); |
452 | 88 | assert(idx < exec_rnum); |
453 | 88 | assert(ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
454 | 88 | int flag = 0; |
455 | 412 | for (x = 0; !flag && x < exec_info[idx].input_size404 ; x++324 ) |
456 | 648 | for (y = 0; 324 !flag && y < exec_info[symbol->d].output_size640 ; y++324 ) |
457 | 324 | if (exec_info[idx].inputs[x] == exec_info[symbol->d].outputs[y]) |
458 | 8 | flag = 1; |
459 | 88 | if (flag) |
460 | 8 | ccv_nnc_graph_exec_symbol_concat(graph, *symbol, (ccv_nnc_graph_exec_symbol_t){ |
461 | 8 | .graph = graph, |
462 | 8 | .d = idx |
463 | 8 | }); |
464 | 88 | } |
465 | 18 | } |
466 | | // Find parents to visited_backward_execs, and use that as the starting point of all newly added graph_exec_symbols. Use the visited backward execs as the source, use all its parents as destination, go through with graph visit. |
467 | 2 | ccv_sparse_matrix_t* const exec_dep = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, graph->exec_symbol_info->rnum, CCV_8U | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0); |
468 | 2 | #define for_block(x, val) \ |
469 | 67 | do { \ |
470 | 67 | if (((uint8_t*)val)[0] != 0) \ |
471 | 67 | ccv_array_push(buf, &x); \ |
472 | 67 | } while (0) |
473 | 2 | const uint8_t one = 1; |
474 | | // Walk the visited nodes and accumulate dependencies: each outgoing node's row in exec_dep gets this node plus everything this node already depends on. |
475 | 16 | ccv_nnc_graph_visit_for(visit, exec_info, node, idx) { |
476 | 16 | if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f))) |
477 | 16 | { |
478 | 16 | ccv_array_clear(buf); |
479 | 16 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx); |
480 | 16 | if (vector) |
481 | 67 | CCV_SPARSE_VECTOR_FOREACH14 (exec_dep, vector, for_block); |
482 | 16 | if (node->outgoings && node->outgoings->rnum > 0) |
483 | 16 | { |
484 | 16 | ccv_array_t* const outgoings = node->outgoings; |
485 | 39 | for (k = 0; k < outgoings->rnum; k++23 ) |
486 | 23 | { |
487 | 23 | const int outgoing_d = *(int*)ccv_array_get(outgoings, k); |
488 | 23 | if (outgoing_d >= exec_rnum) |
489 | 0 | continue; |
490 | 23 | int l; |
491 | | // We cannot avoid the ones that visited, because these may not contain all the deps. |
492 | 23 | ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, idx, &one); |
493 | 114 | for (l = 0; l < buf->rnum; l++91 ) |
494 | 91 | ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, *(int*)ccv_array_get(buf, l), &one); |
495 | 23 | } |
496 | 16 | } |
497 | 16 | } |
498 | 16 | } ccv_nnc_graph_visit_endfor |
499 | | // Now go from outputs to inputs, unmark visited ones. |
500 | 16 | ccv_nnc_graph_visit_for(visit, exec_info, node, idx) { |
501 | 16 | if (idx < exec_rnum) |
502 | 16 | maskbit[idx >> 5] &= ~(1u << (idx & 0x1f)); |
503 | 16 | } ccv_nnc_graph_visit_endfor |
504 | 2 | ccv_nnc_graph_visit_free(visit); |
505 | 2 | #undef for_block |
506 | | // Go through visited backward execs, remove the ones that has no dependency on any replaced backward execs. |
507 | 18 | for (j = 0; j < visited_backward_execs->rnum;) |
508 | 16 | { |
509 | 16 | const int idx = *(int*)ccv_array_get(visited_backward_execs, j); |
510 | 16 | if (ccv_array_contain_int(replaced_backward_execs, idx)) |
511 | 7 | { |
512 | 7 | ++j; |
513 | 7 | continue; |
514 | 7 | } |
515 | 9 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx); |
516 | 9 | int flag = 0; |
517 | 9 | #define for_block(x, val) \ |
518 | 37 | do { \ |
519 | 37 | if (((uint8_t*)val)[0] != 0) \ |
520 | 37 | if (ccv_array_contain_int(replaced_backward_execs, x)) \ |
521 | 37 | flag = 119 ; \ |
522 | 37 | } while (0) |
523 | 9 | if (vector) |
524 | 37 | CCV_SPARSE_VECTOR_FOREACH7 (exec_dep, vector, for_block); |
525 | 9 | #undef for_block |
526 | 9 | if (!flag) |
527 | 3 | { |
528 | 3 | if (j < visited_backward_execs->rnum - 1) |
529 | 2 | *(int*)ccv_array_get(visited_backward_execs, j) = *(int*)ccv_array_get(visited_backward_execs, visited_backward_execs->rnum - 1); |
530 | 3 | --visited_backward_execs->rnum; |
531 | 3 | continue; |
532 | 3 | } |
533 | 6 | ++j; |
534 | 6 | } |
535 | | // Now go through all replaced_backward_execs to find the ones has no dependencies in visited_backward_execs. |
536 | 9 | for (j = 0; 2 j < replaced_backward_execs->rnum; j++7 ) |
537 | 7 | { |
538 | 7 | const int idx = *(int*)ccv_array_get(replaced_backward_execs, j); |
539 | 7 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, idx); |
540 | 7 | int flag = 0; |
541 | 7 | #define for_block(x, val) \ |
542 | 30 | do { \ |
543 | 30 | if (((uint8_t*)val)[0] != 0) \ |
544 | 30 | if (ccv_array_contain_int(visited_backward_execs, x)) \ |
545 | 30 | flag = 119 ; \ |
546 | 30 | } while (0) |
547 | 7 | if (vector) |
548 | 30 | CCV_SPARSE_VECTOR_FOREACH7 (exec_dep, vector, for_block); |
549 | 7 | #undef for_block |
550 | | // If this one has no parents that is within the visited_backward_execs, it is a good place for us to add all its parents as dependency for input_execs. |
551 | 7 | if (!flag) |
552 | 2 | { |
553 | 2 | assert(idx < exec_rnum); |
554 | 2 | ccv_array_t* const outgoings = reversed_nodes[idx].outgoings; |
555 | 2 | assert(outgoings); |
556 | 6 | for (k = 0; 2 k < outgoings->rnum; k++4 ) |
557 | 4 | { |
558 | 4 | const int d = *(int*)ccv_array_get(outgoings, k); |
559 | 16 | for (l = 0; l < newly_input_execs->rnum; l++12 ) |
560 | 12 | { |
561 | 12 | ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){ |
562 | 12 | .graph = graph, |
563 | 12 | .d = d |
564 | 12 | }, *(ccv_nnc_graph_exec_symbol_t*)ccv_array_get(newly_input_execs, l)); |
565 | 12 | } |
566 | 4 | } |
567 | 2 | } |
568 | 7 | } |
569 | 2 | ccv_matrix_free(exec_dep); |
570 | | // Go through all exec, free ones that doesn't have output used. |
571 | | // Reuse this array because it is not useful any more. |
572 | 2 | ccv_array_t* forward_pass_inputs = visited_backward_execs; |
573 | 2 | int any_deleted; |
574 | 6 | do { |
575 | | // Build a map of still active inputs. |
576 | 6 | ccv_array_clear(forward_pass_inputs); |
577 | 60 | for (j = 0; j < build.graph_exec_symbols->rnum; j++54 ) |
578 | 54 | { |
579 | 54 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
580 | 54 | if (symbol->d < 0) |
581 | 8 | continue; |
582 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
583 | 0 | continue; |
584 | 46 | int* const inputs = exec_info[symbol->d].inputs; |
585 | 46 | const int input_size = exec_info[symbol->d].input_size; |
586 | 135 | for (k = 0; k < input_size; k++89 ) |
587 | 89 | ccv_array_add_unique_int(forward_pass_inputs, inputs[k]); |
588 | 46 | } |
589 | 6 | any_deleted = 0; |
590 | 60 | for (j = 0; j < build.graph_exec_symbols->rnum; j++54 ) |
591 | 54 | { |
592 | 54 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
593 | 54 | if (symbol->d < 0) |
594 | 8 | continue; |
595 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
596 | 0 | continue; |
597 | 46 | int* const outputs = exec_info[symbol->d].outputs; |
598 | 46 | const int output_size = exec_info[symbol->d].output_size; |
599 | 46 | int flag = 0; |
600 | 92 | for (k = 0; !flag && k < output_size52 ; k++46 ) |
601 | 46 | flag = ccv_array_contain_int(newly_used_outputs, outputs[k]) || ccv_array_contain_int(forward_pass_inputs, outputs[k])22 ; |
602 | 46 | if (flag) |
603 | 40 | continue; |
604 | 6 | ccv_nnc_graph_exec_symbol_free(graph, *symbol); |
605 | 6 | symbol->d = -1; |
606 | 6 | symbol->graph = 0; |
607 | 6 | any_deleted = 1; |
608 | 6 | } |
609 | 6 | } while (any_deleted); |
610 | 2 | ccv_array_clear(forward_pass_inputs); |
611 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
612 | 18 | { |
613 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
614 | 18 | if (symbol->d < 0) |
615 | 6 | continue; |
616 | 12 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
617 | 0 | continue; |
618 | 12 | int* const inputs = exec_info[symbol->d].inputs; |
619 | 12 | const int input_size = exec_info[symbol->d].input_size; |
620 | 35 | for (k = 0; k < input_size; k++23 ) |
621 | 23 | ccv_array_add_unique_int(forward_pass_inputs, inputs[k]); |
622 | 12 | int* const outputs = exec_info[symbol->d].outputs; |
623 | 12 | const int output_size = exec_info[symbol->d].output_size; |
624 | 24 | for (k = 0; k < output_size; k++12 ) |
625 | 12 | ccv_array_add_unique_int(forward_pass_inputs, outputs[k]); |
626 | 12 | } |
627 | | // Free unused tensor symbols. |
628 | 20 | for (j = 0; j < build.tensor_symbols->rnum; j++18 ) |
629 | 18 | { |
630 | 18 | const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j)); |
631 | 18 | if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d)10 ) |
632 | 12 | continue; |
633 | 6 | ccv_nnc_tensor_symbol_free(graph, *symbol); |
634 | 6 | } |
635 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
636 | 18 | { |
637 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
638 | 18 | if (symbol->d < 0) |
639 | 6 | continue; |
640 | 12 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
641 | 0 | continue; |
642 | 12 | ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT); |
643 | 12 | } |
644 | | // Free these newly created execs and tensor symbols. |
645 | 2 | ccv_array_free(build.tensor_symbols); |
646 | 2 | ccv_array_free(build.graph_exec_symbols); |
647 | 2 | kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map); |
648 | 2 | } |
649 | 2 | ccfree(max_outputs); |
650 | 2 | ccv_array_free(buf); |
651 | 2 | ccv_array_free(newly_used_outputs); |
652 | 2 | ccv_array_free(parameters); |
653 | 2 | ccv_array_free(parameter_ids); |
654 | 2 | ccv_array_free(parameter_trainables); |
655 | 2 | ccv_array_free(internals); |
656 | 2 | ccv_array_free(internal_ids); |
657 | 2 | ccfree(maskbit); |
658 | 2 | ccv_array_free(input_gradient_execs); |
659 | 2 | ccv_array_free(output_gradient_execs); |
660 | 2 | ccv_array_free(input_execs); |
661 | 2 | ccv_array_free(output_execs); |
662 | 2 | ccv_array_free(replaced_backward_execs); |
663 | 2 | ccv_array_free(visited_backward_execs); |
664 | 48 | for (i = 0; i < exec_rnum; i++46 ) |
665 | 46 | if (reversed_nodes[i].outgoings) |
666 | 40 | ccv_array_free(reversed_nodes[i].outgoings); |
667 | 2 | ccfree(reversed_nodes); |
668 | 2 | } |