/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_cnnp_model.h" |
6 | | // This can be removed once we organized ccv_cnnp_apply_gradient_checkpoints better. |
7 | | #include "_ccv_nnc_symbolic_graph.h" |
8 | | |
9 | | void ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph) |
10 | 2.29k | { |
11 | 2.29k | ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints; |
12 | 2.29k | if (!gradient_checkpoints || gradient_checkpoints->rnum == 02 ) // No saved gradient checkpoints, this is an easy way out. |
13 | 2.29k | return; |
14 | 2 | int i, j; |
15 | 2 | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (const ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0); |
16 | | // Go through to check if any tensors that supposes in this map is removed. |
17 | 4 | for (i = 0; i < gradient_checkpoints->rnum; i++2 ) |
18 | 2 | { |
19 | 2 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i); |
20 | 32 | for (j = 0; j < checkpoint->tensor_symbols->rnum; j++30 ) |
21 | 30 | { |
22 | 30 | ccv_nnc_tensor_symbol_t* const symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j)); |
23 | 30 | if (symbol->d >= 0 && symbol->d < graph->tensor_symbol_info->rnum) |
24 | | // If it is dead, we need to remove this symbol. |
25 | 30 | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(tensor_symbol_info[symbol->d].flags)) |
26 | 0 | { |
27 | 0 | symbol->d = -1; |
28 | 0 | symbol->graph = 0; |
29 | 0 | } |
30 | 30 | } |
31 | 2 | } |
32 | 2 | } |
33 | | |
34 | | typedef struct { |
35 | | ccv_array_t* outgoings; |
36 | | } ccv_nnc_graph_exec_symbol_reverse_t; |
37 | | |
38 | | typedef struct { |
39 | | ccv_cnnp_model_gradient_checkpoint_build_context_t tensor_context; |
40 | | ccv_array_t* graph_exec_symbols; |
41 | | ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook; |
42 | | void* old_graph_exec_symbol_new_hook_context; |
43 | | ccv_array_t* all_tensor_symbols; |
44 | | } ccv_cnnp_gradient_checkpoint_build_t; |
45 | | |
46 | | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) |
47 | 18 | { |
48 | 18 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
49 | 18 | if (build_context->tensor_context.record) |
50 | 18 | ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); |
51 | 18 | ccv_array_push(build_context->all_tensor_symbols, &symbol); |
52 | 18 | if (build_context->tensor_context.old_tensor_symbol_new_hook) |
53 | 18 | build_context->tensor_context.old_tensor_symbol_new_hook(build_context->tensor_context.old_tensor_symbol_new_hook_context, symbol, info, name); |
54 | 18 | } |
55 | | |
56 | | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int inc[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name) |
57 | 0 | { |
58 | 0 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
59 | 0 | if (build_context->tensor_context.record) |
60 | 0 | ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); |
61 | 0 | ccv_array_push(build_context->all_tensor_symbols, &symbol); |
62 | 0 | if (build_context->tensor_context.old_tensor_symbol_alias_new_hook) |
63 | 0 | build_context->tensor_context.old_tensor_symbol_alias_new_hook(build_context->tensor_context.old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); |
64 | 0 | } |
65 | | |
66 | | static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) |
67 | 18 | { |
68 | 18 | ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
69 | 18 | ccv_array_push(build->graph_exec_symbols, &symbol); |
70 | 18 | if (build->old_graph_exec_symbol_new_hook) |
71 | 18 | build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name); |
72 | 18 | } |
73 | | |
74 | | KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int) |
75 | | KHASH_SET_INIT_INT(ccv_cnnp_tensor_symbol_set) |
76 | | |
77 | | ccv_nnc_exec_dep_t _ccv_nnc_exec_dep_new(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_visit_t* const visit, const int exec_rnum, uint32_t* const maskbit) |
78 | 2 | { |
79 | 2 | int* chain_ids = ccmalloc(sizeof(int) * exec_rnum * 2); |
80 | 2 | int* chain_pos = chain_ids + exec_rnum; |
81 | 2 | int* buf = (int*)ccmalloc(sizeof(int) * exec_rnum); |
82 | 2 | int* reversed_depth = buf; |
83 | 2 | const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
84 | 2 | int i, j; |
85 | | // Go reverse order to generate the distance from sink. |
86 | 16 | ccv_nnc_graph_visit_for_reversed(visit, exec_symbol_info, node, idx, term) { |
87 | 16 | if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f)))) |
88 | 0 | continue; |
89 | 16 | chain_ids[idx] = -1; |
90 | 16 | if (!node->outgoings || node->outgoings->rnum == 0) |
91 | 0 | { |
92 | 0 | reversed_depth[idx] = 0; |
93 | 0 | continue; |
94 | 0 | } |
95 | 16 | int depth = -1; |
96 | 39 | for (i = 0; i < node->outgoings->rnum; i++23 ) |
97 | 23 | { |
98 | 23 | const int outgoing = *(int*)ccv_array_get(node->outgoings, i); |
99 | 23 | if (outgoing >= exec_rnum) |
100 | 0 | continue; |
101 | 23 | depth = ccv_max(depth, reversed_depth[outgoing]); |
102 | 23 | } |
103 | 16 | reversed_depth[idx] = depth + 1; |
104 | 16 | } ccv_nnc_graph_visit_endfor |
105 | | // Go in order to generate chain ids (if there are multiple exits, we use the reverse depth to break the tie). |
106 | | // Note that we cannot use depth so-far because then multiple exit nodes are equally good to "inherit" the chain selection. |
107 | 2 | int chain_count = 0; |
108 | 16 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) { |
109 | 16 | if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f)))) |
110 | 0 | continue; |
111 | 16 | int chain_id = chain_ids[idx]; |
112 | 16 | if (chain_ids[idx] < 0) |
113 | 4 | { |
114 | 4 | chain_id = chain_count; |
115 | 4 | chain_ids[idx] = chain_id; |
116 | 4 | chain_pos[idx] = 1; // The first one in this chain. 1-based index because in sparse matrix, 0 is the default value. |
117 | 4 | chain_count += 1; |
118 | 4 | } |
119 | 16 | if (!node->outgoings || node->outgoings->rnum == 0) |
120 | 0 | continue; |
121 | 16 | int depth = -1; |
122 | 16 | int next_idx = -1; |
123 | 39 | for (i = 0; i < node->outgoings->rnum; i++23 ) |
124 | 23 | { |
125 | 23 | const int outgoing = *(int*)ccv_array_get(node->outgoings, i); |
126 | 23 | if (outgoing >= exec_rnum) |
127 | 0 | continue; |
128 | 23 | if (chain_ids[outgoing] < 0 && reversed_depth[outgoing] > depth14 ) |
129 | 12 | depth = reversed_depth[outgoing], next_idx = outgoing; |
130 | 23 | } |
131 | 16 | if (next_idx >= 0) |
132 | 12 | { |
133 | 12 | chain_ids[next_idx] = chain_id; |
134 | 12 | assert(reversed_depth[idx] - depth >= 1); |
135 | 12 | chain_pos[next_idx] = chain_pos[idx] + (reversed_depth[idx] - depth); |
136 | 12 | } |
137 | 16 | } ccv_nnc_graph_visit_endfor |
138 | 2 | if (exec_rnum < chain_count * 2) // Be more conservative on RAM usage. |
139 | 0 | buf = ccrealloc(buf, sizeof(int) * chain_count * 2); |
140 | 2 | ccv_sparse_matrix_t* deps = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, chain_count, CCV_32S | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0); |
141 | | // It logs which pos on that chain we depend on. We can simply compare that with the chain_pos for a node to know if they are ancestors. |
142 | 2 | #define for_block(x, val) \ |
143 | 13 | do { \ |
144 | 13 | if (((int32_t*)val)[0] > 0) \ |
145 | 13 | { \ |
146 | 13 | buf[buf_size * 2] = x; \ |
147 | 13 | buf[buf_size * 2 + 1] = ((int32_t*)val)[0]; \ |
148 | 13 | ++buf_size; \ |
149 | 13 | } \ |
150 | 13 | } while (0) |
151 | 2 | int buf_size; |
152 | 16 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx, term) { |
153 | 16 | if (idx >= exec_rnum || CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags) || !(maskbit[idx >> 5] & (1u << (idx & 0x1f)))) |
154 | 0 | continue; |
155 | 16 | buf_size = 0; /* save all its parent deps to this buffer */ |
156 | 16 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, idx); |
157 | 16 | if (vector) |
158 | 13 | CCV_SPARSE_VECTOR_FOREACH10 (deps, vector, for_block); |
159 | 16 | if (!node->outgoings) |
160 | 0 | continue; |
161 | 16 | const int chain_id = chain_ids[idx]; |
162 | 16 | const int pos = chain_pos[idx]; |
163 | 39 | for (i = 0; i < node->outgoings->rnum; i++23 ) |
164 | 23 | { |
165 | 23 | const int outgoing = *(int*)ccv_array_get(node->outgoings, i); |
166 | 23 | if (outgoing >= exec_rnum) |
167 | 0 | continue; |
168 | 23 | const int outgoing_chain_id = chain_ids[outgoing]; |
169 | 23 | if (outgoing_chain_id != chain_id) |
170 | 10 | { |
171 | 10 | ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(deps, outgoing, chain_id); |
172 | | /* If not found, set, if the current node is the destination node, no need |
173 | | * set itself as parent of subsequent nodes because its terminal nature. */ |
174 | 10 | if (!cell.i32 || cell.i32[0] == 01 || cell.i32[0] < pos1 ) |
175 | 10 | ccv_set_sparse_matrix_cell(deps, outgoing, chain_id, &pos); |
176 | 10 | } |
177 | 23 | if (buf_size > 0) |
178 | 13 | { |
179 | 13 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(deps, outgoing); |
180 | 30 | for (j = 0; j < buf_size; j++17 ) /* set with all idx's dependencies as well */ |
181 | 17 | { |
182 | 17 | if (outgoing_chain_id == buf[j * 2]) // We don't need to add as dependency for the same chain. |
183 | 1 | continue; |
184 | 16 | if (!vector) |
185 | 8 | { |
186 | 8 | ccv_set_sparse_matrix_cell(deps, outgoing, buf[j * 2], &buf[j * 2 + 1]); |
187 | 8 | vector = ccv_get_sparse_matrix_vector(deps, outgoing); |
188 | 8 | continue; |
189 | 8 | } |
190 | 8 | ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2]); |
191 | | /* If not found, set. Otherwise, set to the latest one only if it is later. */ |
192 | 8 | if (!cell.i32 || cell.i32[0] == 00 || cell.i32[0] <= buf[j * 2 + 1]0 ) |
193 | 8 | ccv_set_sparse_matrix_cell_from_vector(deps, vector, buf[j * 2], &buf[j * 2 + 1]); |
194 | 8 | } |
195 | 13 | } |
196 | 23 | } |
197 | 16 | } ccv_nnc_graph_visit_endfor |
198 | 2 | #undef for_block |
199 | 2 | ccfree(buf); |
200 | 2 | ccv_nnc_exec_dep_t exec_dep = { |
201 | 2 | .chain_ids = chain_ids, |
202 | 2 | .chain_pos = chain_pos, |
203 | 2 | .deps = deps |
204 | 2 | }; |
205 | 2 | return exec_dep; |
206 | 2 | } |
207 | | |
208 | | void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph) |
209 | 2.24k | { |
210 | 2.24k | ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints; |
211 | 2.24k | if (!gradient_checkpoints || gradient_checkpoints->rnum == 02 ) // No saved gradient checkpoints, this is an easy way out. |
212 | 2.23k | return; |
213 | | // Otherwise, for each gradient checkpoint, there are 3 steps: |
214 | | // 1. Find currently, what execs exists from inputs to outputs. |
215 | | // 2. Find execs that generates the outputs, and their corresponding backward execs. |
216 | | // 3. Find all backward execs flow from outputs back to inputs. |
217 | | // 4. Generate new ops by calling build again with old inputs, record all new tensors / execs. |
218 | | // 5. Replace inputs in backward execs with the new tensors. |
219 | | // 6. Hook the execs takes inputs with edge from parents of backward execs in step 2. |
220 | | // 7. Delete newly generated execs that has no use (i.e. its outputs are not used by backward pass). |
221 | | // 8. Mark all new execs with DISABLE_OPT to avoid common sub-expression elimination pass. |
222 | 2 | int i, j, k, l; |
223 | 2 | ccv_array_t* input_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
224 | 2 | ccv_array_t* output_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
225 | 2 | ccv_array_t* input_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
226 | 2 | ccv_array_t* output_gradient_execs = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
227 | 2 | ccv_array_t* visited_backward_execs = ccv_array_new(sizeof(int), 0, 0); |
228 | 2 | ccv_array_t* replaced_backward_execs = ccv_array_new(sizeof(int), 0, 0); |
229 | 2 | const int exec_rnum = graph->exec_symbol_info->rnum; |
230 | 2 | ccv_nnc_graph_exec_symbol_reverse_t* const reversed_nodes = cccalloc(exec_rnum, sizeof(ccv_nnc_graph_exec_symbol_reverse_t)); |
231 | 48 | for (i = 0; i < exec_rnum; i++46 ) |
232 | 46 | { |
233 | 46 | const int* tos = 0; |
234 | 46 | int to_size = 0; |
235 | 46 | ccv_nnc_graph_exec_symbol_to(graph, (ccv_nnc_graph_exec_symbol_t){ |
236 | 46 | .graph = graph, |
237 | 46 | .d = i |
238 | 46 | }, &tos, &to_size); |
239 | 46 | if (tos) |
240 | 99 | for (j = 0; 38 j < to_size; j++61 ) |
241 | 61 | { |
242 | 61 | if (!reversed_nodes[tos[j]].outgoings) |
243 | 40 | reversed_nodes[tos[j]].outgoings = ccv_array_new(sizeof(int), 1, 0); |
244 | 61 | ccv_array_add_unique_int(reversed_nodes[tos[j]].outgoings, i); |
245 | 61 | } |
246 | 46 | } |
247 | 2 | uint32_t* const maskbit = cccalloc((exec_rnum + 31) >> 5, sizeof(uint32_t)); |
248 | | // Temporary for build_data. |
249 | 2 | ccv_array_t* const parameters = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0); |
250 | 2 | ccv_array_t* const parameter_ids = ccv_array_new(sizeof(char*), 0, 0); |
251 | 2 | ccv_array_t* const parameter_trainables = ccv_array_new(sizeof(int), 0, 0); |
252 | 2 | ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0); |
253 | 2 | ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0); |
254 | 2 | int max_output_size = 0; |
255 | 4 | for (i = 0; i < gradient_checkpoints->rnum; i++2 ) |
256 | 2 | { |
257 | 2 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i); |
258 | 2 | max_output_size = ccv_max(checkpoint->output_size, max_output_size); |
259 | 2 | } |
260 | 2 | ccv_nnc_tensor_symbol_t* max_outputs = ccmalloc(sizeof(ccv_nnc_tensor_symbol_t) * max_output_size); |
261 | 2 | ccv_array_t* newly_used_outputs = ccv_array_new(sizeof(int), 0, 0); |
262 | 2 | khash_t(ccv_cnnp_tensor_symbol_set)* const parameters_or_internals = kh_init(ccv_cnnp_tensor_symbol_set); |
263 | 14 | for (i = 0; i < compiled_data->parameters->rnum; i++12 ) |
264 | 12 | { |
265 | 12 | const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->parameters, i); |
266 | 12 | int ret; |
267 | 12 | kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret); |
268 | 12 | } |
269 | 2 | for (i = 0; i < compiled_data->internals->rnum; i++0 ) |
270 | 0 | { |
271 | 0 | const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->parameters, i); |
272 | 0 | int ret; |
273 | 0 | kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret); |
274 | 0 | } |
275 | 2 | khash_t(ccv_cnnp_tensor_symbol_set)* const newly_created_tensor_symbols = kh_init(ccv_cnnp_tensor_symbol_set); |
276 | 2 | khash_t(ccv_cnnp_tensor_symbol_map)* symbol_map = kh_init(ccv_cnnp_tensor_symbol_map); |
277 | 4 | for (i = 0; i < gradient_checkpoints->rnum; i++2 ) |
278 | 2 | { |
279 | 2 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i); |
280 | 2 | kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols); |
281 | 32 | for (j = 0; j < checkpoint->tensor_symbols->rnum; j++30 ) |
282 | 30 | { |
283 | 30 | const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j))->d; |
284 | 30 | if (idx < 0) |
285 | 0 | continue; |
286 | | // Skip parameters or internals. |
287 | 30 | if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals)) |
288 | 12 | continue; |
289 | 18 | int ret; |
290 | 18 | kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret); |
291 | 18 | } |
292 | 2 | ccv_array_clear(input_execs); |
293 | 2 | ccv_array_clear(output_execs); |
294 | 2 | ccv_nnc_graph_exec_symbol_info_t* exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
295 | 48 | for (j = 0; j < exec_rnum; j++46 ) |
296 | 46 | { |
297 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[j].flags)) |
298 | 0 | continue; |
299 | 46 | const int* inputs = exec_info[j].inputs; |
300 | 46 | int input_size = exec_info[j].input_size; |
301 | 46 | const int* outputs = exec_info[j].outputs; |
302 | 46 | int output_size = exec_info[j].output_size; |
303 | 46 | if (input_size == 0 && output_size == 00 ) |
304 | 0 | continue; |
305 | | // Only go through forward pass. |
306 | 46 | if (ccv_nnc_cmd_is_backward(exec_info[j].cmd)) |
307 | 17 | continue; |
308 | 29 | const ccv_nnc_graph_exec_symbol_t symbol = { |
309 | 29 | .graph = graph, |
310 | 29 | .d = j |
311 | 29 | }; |
312 | 29 | int flag = 0; |
313 | 88 | for (k = 0; inputs && k < input_size && !flag65 ; k++59 ) |
314 | 59 | if (inputs[k] >= 0) |
315 | 118 | for (l = 0; 59 l < checkpoint->input_size && !flag59 ; l++59 ) |
316 | 59 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
317 | 6 | flag = 1; |
318 | | // Input logic is different from output logic. We need to filter out these exec that contains inputs from within the graph. |
319 | 41 | for (k = 0; inputs && k < input_size && flag35 ; k++12 ) |
320 | 12 | if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols)) |
321 | 0 | flag = 0; |
322 | 29 | if (flag) |
323 | 6 | ccv_array_push(input_execs, &symbol); |
324 | 29 | flag = 0; |
325 | 66 | for (k = 0; outputs && k < output_size && !flag37 ; k++37 ) |
326 | 37 | if (outputs[k] >= 0) |
327 | 74 | for (l = 0; 37 l < checkpoint->output_size && !flag37 ; l++37 ) |
328 | 37 | if (checkpoint->outputs[l].d >= 0 && outputs[k] == checkpoint->outputs[l].d) |
329 | 2 | flag = 1; |
330 | 29 | if (flag) |
331 | 2 | ccv_array_push(output_execs, &symbol); |
332 | 29 | } |
333 | 2 | if (input_execs->rnum <= 0 || output_execs->rnum <= 0) |
334 | 0 | continue; |
335 | | // Fill in blanks (i.e. the backward ops that are not showing in above, but should be included to avoid excluding necessary ones). This is done by flowing gradients from outputs back all the way to inputs. |
336 | 2 | ccv_array_clear(input_gradient_execs); |
337 | 2 | ccv_array_clear(output_gradient_execs); |
338 | 8 | for (j = 0; j < input_execs->rnum; j++6 ) |
339 | 6 | { |
340 | 6 | const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_execs, j))->d; |
341 | 18 | for (k = 0; k < exec_info[d].input_size; k++12 ) |
342 | 12 | if (exec_info[d].inputs[k] >= 0) |
343 | 12 | { |
344 | 12 | const ccv_nnc_tensor_symbol_t gradient_symbol = ccv_nnc_tensor_symbol_for_backward(graph, (ccv_nnc_tensor_symbol_t){ |
345 | 12 | .graph = graph, |
346 | 12 | .d = exec_info[d].inputs[k] |
347 | 12 | }); |
348 | 12 | if (gradient_symbol.d < 0) |
349 | 9 | continue; |
350 | 3 | const ccv_nnc_graph_exec_symbol_t backward = ccv_nnc_graph_exec_symbol_for_backward(graph, gradient_symbol); |
351 | 3 | if (backward.d < 0) |
352 | 0 | continue; |
353 | 3 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[backward.d].flags)) |
354 | 0 | continue; |
355 | 3 | int flag = 0; |
356 | 4 | for (l = 0; !flag && l < output_gradient_execs->rnum; l++1 ) |
357 | 1 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == backward.d) |
358 | 0 | flag = 1; |
359 | 3 | if (!flag) |
360 | 3 | ccv_array_push(output_gradient_execs, &backward); |
361 | 3 | } |
362 | 6 | if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0) |
363 | 15 | for (k = 0; 6 k < exec_info[d].outgoings->rnum; k++9 ) |
364 | 9 | { |
365 | 9 | const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k); |
366 | 9 | if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd)) |
367 | 6 | continue; |
368 | 3 | int flag = 0; |
369 | 7 | for (l = 0; !flag && l < output_gradient_execs->rnum4 ; l++4 ) |
370 | 4 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, l))->d == to_d) |
371 | 3 | flag = 1; |
372 | 3 | if (!flag) |
373 | 0 | { |
374 | 0 | const ccv_nnc_graph_exec_symbol_t backward = { |
375 | 0 | .graph = graph, |
376 | 0 | .d = to_d |
377 | 0 | }; |
378 | 0 | ccv_array_push(output_gradient_execs, &backward); |
379 | 0 | } |
380 | 3 | } |
381 | 6 | } |
382 | | // For output_gradient_execs, we can be opportunistic and use the wrt symbols (if exists) to find relevant bits. |
383 | | // For input_gradient_execs, there is no other way but to loop over all outgoings, find the ones are direct link as backward execs. |
384 | 4 | for (j = 0; j < output_execs->rnum; j++2 ) |
385 | 2 | { |
386 | 2 | const int d = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_execs, j))->d; |
387 | 2 | if (exec_info[d].outgoings && exec_info[d].outgoings->rnum > 0) |
388 | 6 | for (k = 0; 2 k < exec_info[d].outgoings->rnum; k++4 ) |
389 | 4 | { |
390 | 4 | const int to_d = *(int*)ccv_array_get(exec_info[d].outgoings, k); |
391 | 4 | if (!ccv_nnc_cmd_is_backward(exec_info[to_d].cmd)) |
392 | 2 | continue; |
393 | 2 | int flag = 0; |
394 | 2 | for (l = 0; !flag && l < input_gradient_execs->rnum; l++0 ) |
395 | 0 | if (((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, l))->d == to_d) |
396 | 0 | flag = 1; |
397 | 2 | if (!flag) |
398 | 2 | { |
399 | 2 | const ccv_nnc_graph_exec_symbol_t backward = { |
400 | 2 | .graph = graph, |
401 | 2 | .d = to_d |
402 | 2 | }; |
403 | 2 | ccv_array_push(input_gradient_execs, &backward); |
404 | 2 | } |
405 | 2 | } |
406 | 2 | } |
407 | | // Note that we have to use up-to-date ones because the exec_info might have outgoings that is up-to-date. |
408 | 4 | ccv_nnc_graph_visit_t* const visit = ccv_nnc_graph_visit_new2 (graph, exec_info, graph->exec_symbol_info->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, 1); |
409 | 16 | ccv_nnc_graph_visit_for(visit, exec_info, node, idx) { |
410 | 16 | if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
411 | 16 | maskbit[idx >> 5] |= (1u << (idx & 0x1f)); |
412 | 16 | } ccv_nnc_graph_visit_endfor |
413 | 4 | ccv_array_clear(visited_backward_execs); |
414 | | // Add more backward pass to the list. Note that we don't add everything, particularly there are new nodes created through gradient checkpointing are ignored. |
415 | 4 | #define visitor(node, idx, _) \ |
416 | 16 | if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags) && maskbit[idx >> 5] & (1u << (idx & 0x1f))) \ |
417 | 16 | ccv_array_add_unique_int(visited_backward_execs, idx); |
418 | 16 | CCV_NNC_GRAPH_VISIT2 (graph, reversed_nodes, exec_rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_gradient_execs, 0), output_gradient_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, 0), input_gradient_execs->rnum, 0, visitor); |
419 | 4 | for (j = 0; j < input_gradient_execs->rnum; j++2 ) |
420 | 2 | ccv_array_add_unique_int(visited_backward_execs, ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_gradient_execs, j))->d); |
421 | 2 | #undef visitor |
422 | 2 | ccv_cnnp_gradient_checkpoint_build_t build = { |
423 | 2 | .tensor_context = { |
424 | 2 | .record = 1, |
425 | 2 | .tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), |
426 | 2 | }, |
427 | 2 | .graph_exec_symbols = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0), |
428 | 2 | .all_tensor_symbols = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0), |
429 | 2 | }; |
430 | 2 | build.tensor_context.old_tensor_symbol_new_hook_context = ccv_nnc_tensor_symbol_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook, &build, &build.tensor_context.old_tensor_symbol_new_hook); |
431 | 2 | build.tensor_context.old_tensor_symbol_alias_new_hook_context = ccv_nnc_tensor_symbol_alias_new_hook(graph, _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook, &build, &build.tensor_context.old_tensor_symbol_alias_new_hook); |
432 | 2 | build.old_graph_exec_symbol_new_hook_context = ccv_nnc_graph_exec_symbol_new_hook(graph, _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook, &build, &build.old_graph_exec_symbol_new_hook); |
433 | 2 | ccv_array_clear(parameters); |
434 | 2 | ccv_array_clear(parameter_ids); |
435 | 2 | ccv_array_clear(parameter_trainables); |
436 | 2 | ccv_array_clear(internals); |
437 | 2 | ccv_array_clear(internal_ids); |
438 | 2 | ccv_cnnp_model_sequence_t model_sequence = { |
439 | 2 | .bank = kh_init(ccv_cnnp_model_name_bank) |
440 | 2 | }; |
441 | 2 | ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = { |
442 | 2 | .add_parameter_indices = 0, |
443 | 2 | .prefix = 't', |
444 | 2 | .sequence = &model_sequence, |
445 | 2 | .symbols = parameters, |
446 | 2 | .ids = parameter_ids, |
447 | 2 | .trainables = parameter_trainables, |
448 | 2 | }; |
449 | 2 | ccv_cnnp_model_add_to_array_context_t add_to_output_context = { |
450 | 2 | .add_parameter_indices = 0, |
451 | 2 | .prefix = 'r', |
452 | 2 | .sequence = &model_sequence, |
453 | 2 | .symbols = internals, |
454 | 2 | .ids = internal_ids, |
455 | 2 | .trainables = 0, |
456 | 2 | }; |
457 | 2 | ccv_cnnp_model_build_data_t build_data = { |
458 | 2 | .is_trainable = checkpoint->is_trainable, |
459 | 2 | .model_sequence = &model_sequence, |
460 | 2 | .add_to_array = ccv_cnnp_model_add_to_array, |
461 | 2 | .parameters = parameters, |
462 | 2 | .context = { |
463 | 2 | .add_to_parameter = &add_to_parameter_context, |
464 | 2 | .add_to_output = &add_to_output_context, |
465 | 2 | }, |
466 | 2 | .is_gradient_checkpointing = 1, // Mark this as true so we don't allocate gradient_checkpoints array or override the hooks. |
467 | 2 | .gradient_checkpoints = 0, |
468 | 2 | }; |
469 | 2 | checkpoint->model->data = &build_data; |
470 | 2 | checkpoint->build(checkpoint->model, graph, checkpoint->inputs, checkpoint->input_size, max_outputs, checkpoint->output_size); |
471 | 2 | checkpoint->model->data = 0; |
472 | 2 | kh_destroy(ccv_cnnp_model_name_bank, model_sequence.bank); |
473 | 2 | if (model_sequence.sequences) |
474 | 2 | ccv_array_free(model_sequence.sequences); |
475 | 2 | ccv_nnc_tensor_symbol_new_hook(graph, build.tensor_context.old_tensor_symbol_new_hook, build.tensor_context.old_tensor_symbol_new_hook_context, 0); |
476 | 2 | ccv_nnc_tensor_symbol_alias_new_hook(graph, build.tensor_context.old_tensor_symbol_alias_new_hook, build.tensor_context.old_tensor_symbol_alias_new_hook_context, 0); |
477 | 2 | ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0); |
478 | 14 | for (j = 0; j < parameter_ids->rnum; j++12 ) |
479 | 12 | ccfree(*(char**)ccv_array_get(parameter_ids, j)); |
480 | 2 | for (j = 0; j < internal_ids->rnum; j++0 ) |
481 | 0 | ccfree(*(char**)ccv_array_get(internal_ids, j)); |
482 | | // Note that there is no graph optimization applied here. |
483 | 2 | exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
484 | | // Reuse existing one. |
485 | 2 | kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols); |
486 | 20 | for (j = 0; j < build.tensor_context.tensor_symbols->rnum; j++18 ) |
487 | 18 | { |
488 | 18 | const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d; |
489 | 18 | if (idx < 0) |
490 | 0 | continue; |
491 | 18 | if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals)) |
492 | 0 | continue; |
493 | 18 | int ret; |
494 | 18 | kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret); |
495 | 18 | } |
496 | 2 | ccv_array_t* const newly_input_execs = input_execs; |
497 | 2 | ccv_array_t* const newly_output_execs = output_execs; |
498 | 2 | ccv_array_clear(newly_input_execs); |
499 | 2 | ccv_array_clear(newly_output_execs); |
500 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
501 | 18 | { |
502 | 18 | const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d; |
503 | 18 | if (idx < 0) |
504 | 0 | continue; |
505 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
506 | 0 | continue; |
507 | 18 | const ccv_nnc_graph_exec_symbol_t symbol = { |
508 | 18 | .graph = graph, |
509 | 18 | .d = idx |
510 | 18 | }; |
511 | 18 | const int* inputs = exec_info[idx].inputs; |
512 | 18 | int input_size = exec_info[idx].input_size; |
513 | | // Only go through forward pass. |
514 | 18 | assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
515 | 18 | int flag = 0; |
516 | 47 | for (k = 0; inputs && k < input_size && !flag35 ; k++29 ) |
517 | 29 | if (inputs[k] >= 0) |
518 | 58 | for (l = 0; 29 l < checkpoint->input_size && !flag29 ; l++29 ) |
519 | 29 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
520 | 6 | flag = 1; |
521 | | // Input logic is different from output logic. We need to filter out these exec that contains inputs from within the graph. |
522 | 30 | for (k = 0; inputs && k < input_size && flag24 ; k++12 ) |
523 | 12 | if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols)) |
524 | 0 | flag = 0; |
525 | 18 | if (flag) |
526 | 6 | ccv_array_push(newly_input_execs, &symbol); |
527 | 18 | flag = 0; |
528 | 18 | const int* outputs = exec_info[idx].outputs; |
529 | 18 | int output_size = exec_info[idx].output_size; |
530 | 36 | for (k = 0; inputs && k < output_size && !flag18 ; k++18 ) |
531 | 18 | if (outputs[k] >= 0) |
532 | 36 | for (l = 0; 18 l < checkpoint->output_size && !flag18 ; l++18 ) |
533 | 18 | if (max_outputs[l].d >= 0 && outputs[k] == max_outputs[l].d) |
534 | 2 | flag = 1; |
535 | 18 | if (flag) |
536 | 2 | ccv_array_push(newly_output_execs, &symbol); |
537 | 18 | } |
538 | 4 | for (j = 0; 2 j < checkpoint->input_size; j++2 ) |
539 | 2 | if (checkpoint->inputs[j].d >= 0) |
540 | 2 | ccv_array_push(parameters, checkpoint->inputs + j); |
541 | 2 | ccv_nnc_symbolic_graph_simplify(graph, |
542 | 2 | SYMBOLIC_GRAPH_PASSES(CCV_NNC_SIMPLIFY_COMMON_SUBEXPRESSION_ELIMINATION, |
543 | 2 | CCV_NNC_SIMPLIFY_DATA_TRANSFER_OPT, |
544 | 2 | CCV_NNC_SIMPLIFY_OPS_FUSION), |
545 | 2 | ccv_array_get(parameters, 0), parameters->rnum, |
546 | 2 | max_outputs, checkpoint->output_size, |
547 | 2 | ccv_array_get(newly_input_execs, 0), newly_input_execs->rnum, ccv_array_get(newly_output_execs, 0), newly_output_execs->rnum); |
548 | 2 | ccv_nnc_graph_exec_symbol_new_hook(graph, build.old_graph_exec_symbol_new_hook, build.old_graph_exec_symbol_new_hook_context, 0); |
549 | | // Need to autogen and redo source / destination. |
550 | 2 | ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0); |
551 | 2 | ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0); |
552 | 2 | exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
553 | 2 | ccv_array_clear(newly_input_execs); |
554 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
555 | 18 | { |
556 | 18 | const int idx = ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j))->d; |
557 | 18 | if (idx < 0) |
558 | 0 | continue; |
559 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
560 | 0 | continue; |
561 | 18 | const ccv_nnc_graph_exec_symbol_t symbol = { |
562 | 18 | .graph = graph, |
563 | 18 | .d = idx |
564 | 18 | }; |
565 | 18 | const int* inputs = exec_info[idx].inputs; |
566 | 18 | int input_size = exec_info[idx].input_size; |
567 | | // Only go through forward pass. |
568 | 18 | assert(!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
569 | 18 | int flag = 0; |
570 | 47 | for (k = 0; inputs && k < input_size && !flag35 ; k++29 ) |
571 | 29 | if (inputs[k] >= 0) |
572 | 58 | for (l = 0; 29 l < checkpoint->input_size && !flag29 ; l++29 ) |
573 | 29 | if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d) |
574 | 6 | flag = 1; |
575 | 30 | for (k = 0; inputs && k < input_size && flag24 ; k++12 ) |
576 | 12 | if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols)) |
577 | 0 | flag = 0; |
578 | 18 | if (flag) |
579 | 6 | ccv_array_push(newly_input_execs, &symbol); |
580 | 18 | } |
581 | | // Build a map between old tensor symbols and new tensor symbols. |
582 | 2 | assert(build.tensor_context.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum); |
583 | | // Build a map to potentially map from old input to new input. |
584 | 2 | kh_clear(ccv_cnnp_tensor_symbol_map, symbol_map); |
585 | 32 | for (j = 0, k = 0; j < build.tensor_context.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum30 ;) |
586 | 30 | { |
587 | 30 | const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d; |
588 | 30 | if (from_d < 0) // This is removed, move to the next one. |
589 | 0 | { |
590 | 0 | ++j; |
591 | 0 | ++k; |
592 | 0 | continue; |
593 | 0 | } |
594 | 30 | const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_context.tensor_symbols, j))->d; |
595 | 30 | assert(to_d >= 0); |
596 | 30 | int from_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, from_d) != kh_end(parameters_or_internals); |
597 | 30 | int to_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, to_d) != kh_end(parameters_or_internals); |
598 | 30 | if (from_flag) |
599 | 12 | ++k; |
600 | 30 | if (to_flag) |
601 | 0 | ++j; |
602 | 30 | if (from_flag || to_flag18 ) |
603 | 12 | continue; |
604 | 18 | ++k; |
605 | 18 | ++j; |
606 | | // Skip if from_d is outputs. |
607 | 36 | for (l = 0; l < !from_flag && checkpoint->output_size18 ; l++18 ) |
608 | 18 | if (checkpoint->outputs[l].d == from_d) |
609 | 2 | from_flag = 1; |
610 | 18 | if (from_flag) |
611 | 2 | continue; |
612 | | // Skip if to_d is outputs. |
613 | 32 | for (l = 0; 16 l < !to_flag && checkpoint->output_size16 ; l++16 ) |
614 | 16 | if (checkpoint->outputs[l].d == to_d) |
615 | 0 | to_flag = 1; |
616 | 16 | if (to_flag) |
617 | 0 | continue; |
618 | 16 | int ret = 0; |
619 | 16 | khiter_t h = kh_put(ccv_cnnp_tensor_symbol_map, symbol_map, from_d, &ret); |
620 | 16 | kh_val(symbol_map, h) = to_d; |
621 | 16 | } |
622 | | // Now go over all backward passes to replace inputs with the ones from symbol map. Record these that are used. |
623 | 2 | ccv_array_clear(newly_used_outputs); |
624 | 2 | ccv_array_clear(replaced_backward_execs); |
625 | 18 | for (j = 0; j < visited_backward_execs->rnum; j++16 ) |
626 | 16 | { |
627 | 16 | const int idx = *(int*)ccv_array_get(visited_backward_execs, j); |
628 | 16 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[idx].flags)) |
629 | 0 | continue; |
630 | 16 | assert(idx >= 0); |
631 | 16 | assert(idx < exec_rnum); |
632 | 16 | if (!ccv_nnc_cmd_is_backward(exec_info[idx].cmd)) |
633 | 1 | continue; |
634 | 74 | for (k = 0; 15 k < exec_info[idx].input_size; k++59 ) |
635 | 59 | if (exec_info[idx].inputs[k] >= 0) |
636 | 32 | { |
637 | 32 | const khiter_t h = kh_get(ccv_cnnp_tensor_symbol_map, symbol_map, exec_info[idx].inputs[k]); |
638 | 32 | if (h != kh_end(symbol_map)) // Replacing it. |
639 | 8 | { |
640 | 8 | int newly_created_output = kh_val(symbol_map, h); |
641 | 8 | exec_info[idx].inputs[k] = newly_created_output; |
642 | 8 | ccv_array_add_unique_int(newly_used_outputs, newly_created_output); |
643 | 8 | if (tensor_symbol_info[newly_created_output].alias_ref > 0) |
644 | 0 | { |
645 | 0 | newly_created_output = tensor_symbol_info[newly_created_output].alias_ref - 1; |
646 | 0 | ccv_array_add_unique_int(newly_used_outputs, newly_created_output); |
647 | 0 | } |
648 | 8 | ccv_array_add_unique_int(replaced_backward_execs, idx); |
649 | 8 | } |
650 | 32 | } |
651 | 15 | } |
652 | 20 | for (j = 0; 2 j < build.graph_exec_symbols->rnum; j++18 ) |
653 | 18 | { |
654 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
655 | 18 | if (symbol->d < 0) |
656 | 0 | continue; |
657 | 18 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
658 | 0 | continue; |
659 | 18 | int x, y; |
660 | 106 | for (k = 0; k < replaced_backward_execs->rnum; k++88 ) |
661 | 88 | { |
662 | 88 | const int idx = *(int*)ccv_array_get(replaced_backward_execs, k); |
663 | 88 | assert(idx >= 0); |
664 | 88 | assert(idx < exec_rnum); |
665 | 88 | assert(ccv_nnc_cmd_is_backward(exec_info[idx].cmd)); |
666 | 88 | int flag = 0; |
667 | 412 | for (x = 0; !flag && x < exec_info[idx].input_size404 ; x++324 ) |
668 | 324 | { |
669 | 324 | int x_d = exec_info[idx].inputs[x]; |
670 | 324 | if (x_d < 0) |
671 | 80 | continue; |
672 | 244 | if (tensor_symbol_info[x_d].alias_ref > 0) |
673 | 0 | x_d = tensor_symbol_info[x_d].alias_ref - 1; |
674 | 488 | for (y = 0; !flag && y < exec_info[symbol->d].output_size480 ; y++244 ) |
675 | 244 | { |
676 | 244 | int y_d = exec_info[symbol->d].outputs[y]; |
677 | 244 | if (y_d < 0) |
678 | 0 | continue; |
679 | 244 | if (tensor_symbol_info[y_d].alias_ref > 0) |
680 | 0 | y_d = tensor_symbol_info[y_d].alias_ref - 1; |
681 | 244 | if (x_d == y_d) |
682 | 8 | flag = 1; |
683 | 244 | } |
684 | 244 | } |
685 | 88 | if (flag) |
686 | 8 | ccv_nnc_graph_exec_symbol_concat(graph, *symbol, (ccv_nnc_graph_exec_symbol_t){ |
687 | 8 | .graph = graph, |
688 | 8 | .d = idx |
689 | 8 | }); |
690 | 88 | } |
691 | 18 | } |
692 | | // Find parents to visited_backward_execs, and use that as the starting point of all newly added graph_exec_symbols. Use the visited backward execs as the source, use all its parents as destination, go through with graph visit. |
693 | 2 | ccv_nnc_exec_dep_t exec_dep = _ccv_nnc_exec_dep_new(graph, visit, exec_rnum, maskbit); |
694 | | // Now go from outputs to inputs, unmark visited ones. |
695 | 16 | ccv_nnc_graph_visit_for(visit, exec_info, node, idx) { |
696 | 16 | if (idx < exec_rnum) |
697 | 16 | maskbit[idx >> 5] &= ~(1u << (idx & 0x1f)); |
698 | 16 | } ccv_nnc_graph_visit_endfor |
699 | 2 | ccv_nnc_graph_visit_free(visit); |
700 | | // Go through visited backward execs, remove the ones that has no dependency on any replaced backward execs. |
701 | 18 | for (j = 0; j < visited_backward_execs->rnum;) |
702 | 16 | { |
703 | 16 | const int idx = *(int*)ccv_array_get(visited_backward_execs, j); |
704 | 16 | if (ccv_array_contain_int(replaced_backward_execs, idx)) |
705 | 7 | { |
706 | 7 | ++j; |
707 | 7 | continue; |
708 | 7 | } |
709 | 9 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx); |
710 | 9 | if (!vector) |
711 | 3 | vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find. |
712 | 9 | int flag = 0; |
713 | 38 | for (k = 0; !flag && k < replaced_backward_execs->rnum32 ; k++29 ) |
714 | 29 | { |
715 | 29 | const int d = *(int*)ccv_array_get(replaced_backward_execs, k); |
716 | 29 | flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d); |
717 | 29 | } |
718 | 9 | if (!flag) |
719 | 3 | { |
720 | 3 | if (j < visited_backward_execs->rnum - 1) |
721 | 2 | *(int*)ccv_array_get(visited_backward_execs, j) = *(int*)ccv_array_get(visited_backward_execs, visited_backward_execs->rnum - 1); |
722 | 3 | --visited_backward_execs->rnum; |
723 | 3 | continue; |
724 | 3 | } |
725 | 6 | ++j; |
726 | 6 | } |
727 | | // Now go through all replaced_backward_execs to find the ones has no dependencies in visited_backward_execs. |
728 | 9 | for (j = 0; j < replaced_backward_execs->rnum; j++7 ) |
729 | 7 | { |
730 | 7 | const int idx = *(int*)ccv_array_get(replaced_backward_execs, j); |
731 | 7 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep.deps, idx); |
732 | 7 | if (!vector) |
733 | 3 | vector = (ccv_sparse_matrix_vector_t*)1; // Mark it as we tried but cannot find. |
734 | 7 | int flag = 0; |
735 | 59 | for (k = 0; !flag && k < visited_backward_execs->rnum54 ; k++52 ) |
736 | 52 | { |
737 | 52 | const int d = *(int*)ccv_array_get(visited_backward_execs, k); |
738 | 52 | flag = ccv_nnc_exec_dep_check(exec_dep, idx, vector, d); |
739 | 52 | } |
740 | | // If this one has no parents that is within the visited_backward_execs, it is a good place for us to add all its parents as dependency for input_execs. |
741 | 7 | if (!flag) |
742 | 2 | { |
743 | 2 | assert(idx < exec_rnum); |
744 | 2 | ccv_array_t* const outgoings = reversed_nodes[idx].outgoings; |
745 | 2 | assert(outgoings); |
746 | 6 | for (k = 0; 2 k < outgoings->rnum; k++4 ) |
747 | 4 | { |
748 | 4 | const int d = *(int*)ccv_array_get(outgoings, k); |
749 | 16 | for (l = 0; l < newly_input_execs->rnum; l++12 ) |
750 | 12 | { |
751 | 12 | ccv_nnc_graph_exec_symbol_concat(graph, (ccv_nnc_graph_exec_symbol_t){ |
752 | 12 | .graph = graph, |
753 | 12 | .d = d |
754 | 12 | }, *(ccv_nnc_graph_exec_symbol_t*)ccv_array_get(newly_input_execs, l)); |
755 | 12 | } |
756 | 4 | } |
757 | 2 | } |
758 | 7 | } |
759 | 2 | ccv_nnc_exec_dep_free(exec_dep); |
760 | | // Go through all exec, free ones that doesn't have output used. |
761 | | // Reuse this array because it is not useful any more. |
762 | 2 | ccv_array_t* forward_pass_inputs = visited_backward_execs; |
763 | 2 | int any_deleted; |
764 | 6 | do { |
765 | | // Build a map of still active inputs. |
766 | 6 | ccv_array_clear(forward_pass_inputs); |
767 | 60 | for (j = 0; j < build.graph_exec_symbols->rnum; j++54 ) |
768 | 54 | { |
769 | 54 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
770 | 54 | if (symbol->d < 0) |
771 | 8 | continue; |
772 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
773 | 0 | continue; |
774 | 46 | int* const inputs = exec_info[symbol->d].inputs; |
775 | 46 | const int input_size = exec_info[symbol->d].input_size; |
776 | 135 | for (k = 0; k < input_size; k++89 ) |
777 | 89 | { |
778 | 89 | int d = inputs[k]; |
779 | 89 | if (d < 0) |
780 | 0 | continue; |
781 | 89 | ccv_array_add_unique_int(forward_pass_inputs, d); |
782 | 89 | if (tensor_symbol_info[d].alias_ref > 0) |
783 | 0 | { |
784 | 0 | d = tensor_symbol_info[d].alias_ref - 1; |
785 | 0 | ccv_array_add_unique_int(forward_pass_inputs, d); |
786 | 0 | } |
787 | 89 | } |
788 | 46 | } |
789 | 6 | any_deleted = 0; |
790 | 60 | for (j = 0; j < build.graph_exec_symbols->rnum; j++54 ) |
791 | 54 | { |
792 | 54 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
793 | 54 | if (symbol->d < 0) |
794 | 8 | continue; |
795 | 46 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
796 | 0 | continue; |
797 | 46 | int* const outputs = exec_info[symbol->d].outputs; |
798 | 46 | const int output_size = exec_info[symbol->d].output_size; |
799 | 46 | int flag = 0; |
800 | 92 | for (k = 0; !flag && k < output_size52 ; k++46 ) |
801 | 46 | { |
802 | 46 | int d = outputs[k]; |
803 | 46 | if (d < 0) |
804 | 0 | continue; |
805 | 46 | flag = ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d)22 ; |
806 | 46 | if (!flag && tensor_symbol_info[d].alias_ref > 06 ) |
807 | 0 | { |
808 | 0 | d = tensor_symbol_info[d].alias_ref - 1; |
809 | 0 | flag = ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d); |
810 | 0 | } |
811 | 46 | } |
812 | 46 | if (flag) |
813 | 40 | continue; |
814 | 6 | ccv_nnc_graph_exec_symbol_free(graph, *symbol); |
815 | 6 | symbol->d = -1; |
816 | 6 | symbol->graph = 0; |
817 | 6 | any_deleted = 1; |
818 | 6 | } |
819 | 6 | } while (any_deleted); |
820 | 2 | ccv_array_clear(forward_pass_inputs); |
821 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
822 | 18 | { |
823 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
824 | 18 | if (symbol->d < 0) |
825 | 6 | continue; |
826 | 12 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
827 | 0 | continue; |
828 | 12 | int* const inputs = exec_info[symbol->d].inputs; |
829 | 12 | const int input_size = exec_info[symbol->d].input_size; |
830 | 35 | for (k = 0; k < input_size; k++23 ) |
831 | 23 | { |
832 | 23 | if (inputs[k] < 0) |
833 | 0 | continue; |
834 | 23 | ccv_array_add_unique_int(forward_pass_inputs, inputs[k]); |
835 | 23 | if (tensor_symbol_info[inputs[k]].alias_ref > 0) |
836 | 0 | ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[inputs[k]].alias_ref - 1); |
837 | 23 | } |
838 | 12 | int* const outputs = exec_info[symbol->d].outputs; |
839 | 12 | const int output_size = exec_info[symbol->d].output_size; |
840 | 24 | for (k = 0; k < output_size; k++12 ) |
841 | 12 | { |
842 | 12 | if (outputs[k] < 0) |
843 | 0 | continue; |
844 | 12 | ccv_array_add_unique_int(forward_pass_inputs, outputs[k]); |
845 | 12 | if (tensor_symbol_info[outputs[k]].alias_ref > 0) |
846 | 0 | ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[outputs[k]].alias_ref - 1); |
847 | 12 | } |
848 | 12 | } |
849 | | // Free unused tensor symbols. |
850 | 20 | for (j = 0; j < build.all_tensor_symbols->rnum; j++18 ) |
851 | 18 | { |
852 | 18 | const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.all_tensor_symbols, j)); |
853 | 18 | if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d)10 ) |
854 | 12 | continue; |
855 | 6 | if (tensor_symbol_info[symbol->d].alias_ref > 0) |
856 | 0 | { |
857 | 0 | const int d = tensor_symbol_info[symbol->d].alias_ref - 1; |
858 | 0 | if (ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d)) |
859 | 0 | continue; |
860 | 0 | } |
861 | 6 | ccv_nnc_tensor_symbol_free(graph, *symbol); |
862 | 6 | } |
863 | 20 | for (j = 0; j < build.graph_exec_symbols->rnum; j++18 ) |
864 | 18 | { |
865 | 18 | ccv_nnc_graph_exec_symbol_t* const symbol = (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, j); |
866 | 18 | if (symbol->d < 0) |
867 | 6 | continue; |
868 | 12 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_info[symbol->d].flags)) |
869 | 0 | continue; |
870 | 12 | ccv_nnc_graph_exec_symbol_set_flags(graph, *symbol, CCV_NNC_GRAPH_EXEC_DISABLE_OPT); |
871 | 12 | } |
872 | | // Free these newly created execs and tensor symbols. |
873 | 2 | ccv_array_free(build.tensor_context.tensor_symbols); |
874 | 2 | ccv_array_free(build.graph_exec_symbols); |
875 | 2 | ccv_array_free(build.all_tensor_symbols); |
876 | 2 | } |
877 | 2 | kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map); |
878 | 2 | kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols); |
879 | 2 | kh_destroy(ccv_cnnp_tensor_symbol_set, parameters_or_internals); |
880 | 2 | ccfree(max_outputs); |
881 | 2 | ccv_array_free(newly_used_outputs); |
882 | 2 | ccv_array_free(parameters); |
883 | 2 | ccv_array_free(parameter_ids); |
884 | 2 | ccv_array_free(parameter_trainables); |
885 | 2 | ccv_array_free(internals); |
886 | 2 | ccv_array_free(internal_ids); |
887 | 2 | ccfree(maskbit); |
888 | 2 | ccv_array_free(input_gradient_execs); |
889 | 2 | ccv_array_free(output_gradient_execs); |
890 | 2 | ccv_array_free(input_execs); |
891 | 2 | ccv_array_free(output_execs); |
892 | 2 | ccv_array_free(replaced_backward_execs); |
893 | 2 | ccv_array_free(visited_backward_execs); |
894 | 48 | for (i = 0; i < exec_rnum; i++46 ) |
895 | 46 | if (reversed_nodes[i].outgoings) |
896 | 40 | ccv_array_free(reversed_nodes[i].outgoings); |
897 | 2 | ccfree(reversed_nodes); |
898 | 2 | } |