/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_symbolic_graph.h" |
6 | | |
7 | | // MARK - Level-3 API |
8 | | |
9 | | const ccv_nnc_tensor_param_t ccv_nnc_tensor_auto = {}; |
10 | | |
11 | | int ccv_nnc_is_tensor_auto(const ccv_nnc_tensor_param_t params) |
12 | 714k | { |
13 | 714k | return (memcmp(&params, &ccv_nnc_tensor_auto, sizeof(ccv_nnc_tensor_param_t)) == 0); |
14 | 714k | } |
15 | | |
16 | | ccv_nnc_symbolic_graph_t* ccv_nnc_symbolic_graph_new(void) |
17 | 2.63k | { |
18 | 2.63k | ccv_nnc_symbolic_graph_t* graph = cccalloc(1, sizeof(ccv_nnc_symbolic_graph_t)); |
19 | 2.63k | graph->tensor_symbol_info = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_info_t), 5, 0); |
20 | 2.63k | graph->exec_symbol_info = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_info_t), 5, 0); |
21 | 2.63k | graph->reuse.exec = -1; |
22 | 2.63k | graph->reuse.tensor = -1; |
23 | 2.63k | return graph; |
24 | 2.63k | } |
25 | | |
26 | | ccv_nnc_symbolic_graph_t* ccv_nnc_symbolic_graph_dup(const ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_subst_f subst) |
27 | 13 | { |
28 | 13 | ccv_nnc_symbolic_graph_t* new_graph = ccmalloc(sizeof(ccv_nnc_symbolic_graph_t)); |
29 | 13 | memcpy(new_graph, graph, sizeof(ccv_nnc_symbolic_graph_t)); |
30 | 13 | new_graph->tensor_symbol_info = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_info_t), graph->tensor_symbol_info->rnum, 0); |
31 | 13 | new_graph->tensor_symbol_info->rnum = graph->tensor_symbol_info->rnum; |
32 | 13 | memcpy(ccv_array_get(new_graph->tensor_symbol_info, 0), ccv_array_get(graph->tensor_symbol_info, 0), sizeof(ccv_nnc_tensor_symbol_info_t) * graph->tensor_symbol_info->rnum); |
33 | 13 | int i; |
34 | 91 | for (i = 0; i < new_graph->tensor_symbol_info->rnum; i++78 ) |
35 | 78 | { |
36 | 78 | ccv_nnc_tensor_symbol_info_t* symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(new_graph->tensor_symbol_info, i); |
37 | 78 | if (symbol_info->name) |
38 | 70 | { |
39 | 70 | char* const name = symbol_info->name; |
40 | 70 | const size_t len = strnlen(name, 63); |
41 | 70 | const size_t n = len + 1; |
42 | 70 | symbol_info->name = (char*)ccmalloc(n); |
43 | | // Don't use strndup because this way I can have custom allocator (for ccmalloc). |
44 | 70 | memcpy(symbol_info->name, name, n); |
45 | 70 | symbol_info->name[len] = 0; |
46 | 70 | } |
47 | 78 | if (symbol_info->s_ref) |
48 | 6 | { |
49 | 6 | ccv_array_t* const s_ref = symbol_info->s_ref; |
50 | 6 | symbol_info->s_ref = ccv_array_new(sizeof(int), s_ref->rnum, 0); |
51 | 6 | symbol_info->s_ref->rnum = s_ref->rnum; |
52 | 6 | memcpy(ccv_array_get(symbol_info->s_ref, 0), ccv_array_get(s_ref, 0), sizeof(int) * s_ref->rnum); |
53 | 6 | } |
54 | 78 | } |
55 | 13 | new_graph->exec_symbol_info = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_info_t), graph->exec_symbol_info->rnum, 0); |
56 | 13 | new_graph->exec_symbol_info->rnum = graph->exec_symbol_info->rnum; |
57 | 13 | memcpy(ccv_array_get(new_graph->exec_symbol_info, 0), ccv_array_get(graph->exec_symbol_info, 0), sizeof(ccv_nnc_graph_exec_symbol_info_t) * graph->exec_symbol_info->rnum); |
58 | 48 | for (i = 0; i < new_graph->exec_symbol_info->rnum; i++35 ) |
59 | 35 | { |
60 | 35 | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(new_graph->exec_symbol_info, i); |
61 | 35 | if (symbol_info->name) |
62 | 25 | { |
63 | 25 | char* const name = symbol_info->name; |
64 | 25 | const size_t len = strnlen(name, 63); |
65 | 25 | const size_t n = len + 1; |
66 | 25 | symbol_info->name = (char*)ccmalloc(n); |
67 | | // Don't use strndup because this way I can have custom allocator (for ccmalloc). |
68 | 25 | memcpy(symbol_info->name, name, n); |
69 | 25 | symbol_info->name[len] = 0; |
70 | 25 | } |
71 | 35 | if (symbol_info->outgoings) |
72 | 20 | { |
73 | 20 | ccv_array_t* const outgoings = symbol_info->outgoings; |
74 | 20 | symbol_info->outgoings = ccv_array_new(sizeof(int), outgoings->rnum, 0); |
75 | 20 | symbol_info->outgoings->rnum = outgoings->rnum; |
76 | 20 | memcpy(ccv_array_get(symbol_info->outgoings, 0), ccv_array_get(outgoings, 0), sizeof(int) * outgoings->rnum); |
77 | 20 | } |
78 | 35 | if (symbol_info->inputs) |
79 | 22 | { |
80 | 22 | int* const inputs = symbol_info->inputs; |
81 | 22 | symbol_info->inputs = (int*)ccmalloc(sizeof(int) * (symbol_info->input_size + symbol_info->output_size)); |
82 | 22 | symbol_info->outputs = symbol_info->inputs + symbol_info->input_size; |
83 | 22 | memcpy(symbol_info->inputs, inputs, sizeof(int) * (symbol_info->input_size + symbol_info->output_size)); |
84 | 22 | } |
85 | 35 | if (symbol_info->_heap_graph_ref) |
86 | 2 | { |
87 | 2 | int* const heap_graph_ref = symbol_info->_heap_graph_ref; |
88 | 2 | symbol_info->_heap_graph_ref = (int*)ccmalloc(sizeof(int) * symbol_info->graph_ref_size); |
89 | 2 | memcpy(symbol_info->_heap_graph_ref, heap_graph_ref, sizeof(int) * symbol_info->graph_ref_size); |
90 | 2 | } |
91 | 35 | if ((symbol_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) && symbol_info->input_size > 01 ) |
92 | 1 | { |
93 | 1 | int* const inputs = symbol_info->p_while.inputs; |
94 | 1 | symbol_info->p_while.inputs = (int*)ccmalloc(sizeof(int) * symbol_info->p_while.input_size); |
95 | 1 | memcpy(symbol_info->p_while.inputs, inputs, sizeof(int) * symbol_info->p_while.input_size); |
96 | 1 | } |
97 | 35 | } |
98 | 13 | if (graph->sources) |
99 | 13 | { |
100 | 13 | new_graph->sources = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), graph->sources->rnum, 0); |
101 | 13 | new_graph->sources->rnum = graph->sources->rnum; |
102 | 13 | memcpy(ccv_array_get(new_graph->sources, 0), ccv_array_get(graph->sources, 0), sizeof(ccv_nnc_graph_exec_symbol_t) * graph->sources->rnum); |
103 | 26 | for (i = 0; i < new_graph->sources->rnum; i++13 ) |
104 | 13 | ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(new_graph->sources, i))->graph = new_graph; |
105 | 13 | } |
106 | 13 | if (graph->destinations) |
107 | 13 | { |
108 | 13 | new_graph->destinations = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), graph->destinations->rnum, 0); |
109 | 13 | new_graph->destinations->rnum = graph->destinations->rnum; |
110 | 13 | memcpy(ccv_array_get(new_graph->destinations, 0), ccv_array_get(graph->destinations, 0), sizeof(ccv_nnc_graph_exec_symbol_t) * graph->destinations->rnum); |
111 | 26 | for (i = 0; i < new_graph->destinations->rnum; i++13 ) |
112 | 13 | ((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(new_graph->destinations, i))->graph = new_graph; |
113 | 13 | } |
114 | 13 | if (graph->breakpoints) |
115 | 13 | { |
116 | 13 | new_graph->breakpoints = (ccv_nnc_graph_exec_symbol_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_t) * graph->breakpoint_size); |
117 | 13 | memcpy(new_graph->breakpoints, graph->breakpoints, sizeof(ccv_nnc_graph_exec_symbol_t) * graph->breakpoint_size); |
118 | 26 | for (i = 0; i < graph->breakpoint_size; i++13 ) |
119 | 13 | new_graph->breakpoints[i].graph = new_graph; |
120 | 13 | } |
121 | 13 | if (graph->backward.tensor_symbol_idx) |
122 | 1 | { |
123 | 1 | new_graph->backward.tensor_symbol_idx = (int*)ccmalloc(sizeof(int) * (new_graph->backward.tensor_symbol_size + new_graph->backward.exec_symbol_size)); |
124 | 1 | if (new_graph->backward.tensor_symbol_size > 0) |
125 | 1 | memcpy(new_graph->backward.tensor_symbol_idx, graph->backward.tensor_symbol_idx, sizeof(int) * new_graph->backward.tensor_symbol_size); |
126 | 1 | new_graph->backward.exec_symbol_idx = new_graph->backward.tensor_symbol_idx + new_graph->backward.tensor_symbol_size; |
127 | 1 | if (new_graph->backward.exec_symbol_size > 0) |
128 | 1 | memcpy(new_graph->backward.exec_symbol_idx, graph->backward.exec_symbol_idx, sizeof(int) * new_graph->backward.exec_symbol_size); |
129 | 1 | } |
130 | 13 | if (subst) |
131 | 13 | { |
132 | 48 | for (i = 0; i < new_graph->exec_symbol_info->rnum; i++35 ) |
133 | 35 | { |
134 | 35 | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(new_graph->exec_symbol_info, i); |
135 | 35 | if (!CCV_NNC_GRAPH_EXEC_IS_DEAD(symbol_info->flags)) |
136 | 33 | { |
137 | 33 | symbol_info->cmd = subst((ccv_nnc_graph_exec_symbol_t){ |
138 | 33 | .d = i, |
139 | 33 | .graph = graph, |
140 | 33 | }, symbol_info->cmd); |
141 | 33 | if (symbol_info->cmd.cmd != CCV_NNC_GRAPH_FORWARD && symbol_info->cmd.cmd != CCV_NNC_GRAPH_BACKWARD) |
142 | 33 | { |
143 | 33 | symbol_info->graph_ref_size = 0; |
144 | 33 | if (symbol_info->_heap_graph_ref) |
145 | 2 | { |
146 | 2 | ccfree(symbol_info->_heap_graph_ref); |
147 | 2 | symbol_info->_heap_graph_ref = 0; |
148 | 2 | } |
149 | 33 | } |
150 | 33 | } |
151 | 35 | } |
152 | 13 | } |
153 | | // TODO: See how and if I need to dup sub-graphs. I also need to figure out what's the relationship between this graph |
154 | | // and its parent graph (or how can we use the symbol from the graph properly). |
155 | 13 | new_graph->sub_graphs = 0; |
156 | 13 | return new_graph; |
157 | 13 | } |
158 | | |
159 | | ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_new(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_param_t info, const char* const name) |
160 | 100k | { |
161 | 100k | ccv_nnc_tensor_symbol_t symbol = { |
162 | 100k | .d = graph->tensor_symbol_info->rnum, |
163 | 100k | .graph = graph |
164 | 100k | }; |
165 | 100k | ccv_nnc_tensor_symbol_info_t symbol_info = { |
166 | 100k | .info = info, |
167 | 100k | }; |
168 | 100k | if (name) |
169 | 4.79k | { |
170 | 4.79k | const size_t len = strnlen(name, 63); |
171 | 4.79k | const size_t n = len + 1; |
172 | 4.79k | symbol_info.name = (char*)ccmalloc(n); |
173 | | // Don't use strndup because this way I can have custom allocator (for ccmalloc). |
174 | 4.79k | memcpy(symbol_info.name, name, n); |
175 | 4.79k | symbol_info.name[len] = 0; |
176 | 4.79k | } |
177 | 100k | if (graph->reuse.tensor >= 0) |
178 | 16.2k | { |
179 | 16.2k | const int reuse_tensor_d = graph->reuse.tensor; |
180 | 16.2k | assert(reuse_tensor_d < graph->tensor_symbol_info->rnum); |
181 | 16.2k | *(ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, reuse_tensor_d) = symbol_info; |
182 | 16.2k | int i; |
183 | 16.2k | graph->reuse.tensor = -1; |
184 | 29.9k | for (i = reuse_tensor_d + 1; i < graph->tensor_symbol_info->rnum && graph->reuse.tensor < 022.5k ; i++13.6k ) |
185 | 13.6k | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i))->flags)) |
186 | 11.4k | graph->reuse.tensor = i; |
187 | 16.2k | symbol.d = reuse_tensor_d; |
188 | 16.2k | } else |
189 | 84.4k | ccv_array_push(graph->tensor_symbol_info, &symbol_info); |
190 | 100k | if (graph->hooks.tensor_symbol_new.func) |
191 | 47.0k | graph->hooks.tensor_symbol_new.func(graph->hooks.tensor_symbol_new.context, symbol, info, name); |
192 | 100k | return symbol; |
193 | 100k | } |
194 | | |
195 | | void* ccv_nnc_tensor_symbol_new_hook(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_tensor_symbol_new_hook_f hook, void* context, ccv_nnc_tensor_symbol_new_hook_f* previous_hook) |
196 | 11.7k | { |
197 | 11.7k | if (previous_hook) |
198 | 4 | *previous_hook = graph->hooks.tensor_symbol_new.func; |
199 | 11.7k | void* const prev = graph->hooks.tensor_symbol_new.context; |
200 | 11.7k | graph->hooks.tensor_symbol_new.func = hook; |
201 | 11.7k | graph->hooks.tensor_symbol_new.context = context; |
202 | 11.7k | return prev; |
203 | 11.7k | } |
204 | | |
205 | | ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_alias_new(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int stride[CCV_NNC_MAX_DIM_ALLOC], const ccv_nnc_tensor_param_t info, const char* const name) |
206 | 4.77k | { |
207 | 4.77k | assert(tensor_symbol.graph == graph); |
208 | 4.77k | int d = tensor_symbol.d; |
209 | 4.77k | assert(d >= 0 && d < graph->tensor_symbol_info->rnum); |
210 | 4.77k | ccv_nnc_tensor_symbol_info_t* info_d = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
211 | | // Find the root tensor that is not an alias. |
212 | 4.77k | while (info_d->alias_ref) |
213 | 1 | { |
214 | 1 | d = info_d->alias_ref - 1; |
215 | 1 | assert(d >= 0 && d < graph->tensor_symbol_info->rnum); |
216 | 1 | info_d = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
217 | 1 | } |
218 | 4.77k | ccv_nnc_tensor_symbol_t alias = { |
219 | 4.77k | .d = graph->tensor_symbol_info->rnum, |
220 | 4.77k | .graph = graph |
221 | 4.77k | }; |
222 | | // Alias comes in two shapes: its total tensor count is either strictly smaller than or equal to that of the tensor it aliases. |
223 | | // If it is not auto, check dimensions. |
224 | 4.77k | if (!ccv_nnc_is_tensor_auto(info_d->info)) |
225 | 4.77k | { assert((size_t)stride[0] * info.dim[0] <= ccv_nnc_tensor_count(info_d->info)); } |
226 | 4.77k | ccv_nnc_tensor_symbol_info_t alias_info = { |
227 | 4.77k | .alias_ref = d + 1, |
228 | 4.77k | .info = info, |
229 | 4.77k | }; |
230 | 4.77k | if (name) |
231 | 77 | { |
232 | 77 | const size_t len = strnlen(name, 63); |
233 | 77 | const size_t n = len + 1; |
234 | 77 | alias_info.name = (char*)ccmalloc(n); |
235 | | // Don't use strndup because this way I can have custom allocator (for ccmalloc). |
236 | 77 | memcpy(alias_info.name, name, n); |
237 | 77 | alias_info.name[len] = 0; |
238 | 77 | } |
239 | 4.77k | memcpy(alias_info.ofs, ofs, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC); |
240 | 4.77k | memcpy(alias_info.stride, stride, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC); |
241 | 4.77k | if (graph->reuse.tensor >= 0) |
242 | 6 | { |
243 | 6 | const int reuse_tensor_d = graph->reuse.tensor; |
244 | 6 | *(ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, reuse_tensor_d) = alias_info; |
245 | 6 | int i; |
246 | 6 | graph->reuse.tensor = -1; |
247 | 12 | for (i = reuse_tensor_d + 1; i < graph->tensor_symbol_info->rnum && graph->reuse.tensor < 0; i++6 ) |
248 | 6 | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i))->flags)) |
249 | 6 | graph->reuse.tensor = i; |
250 | 6 | alias.d = reuse_tensor_d; |
251 | 6 | } else |
252 | 4.77k | ccv_array_push(graph->tensor_symbol_info, &alias_info); |
253 | 4.77k | if (graph->hooks.tensor_symbol_alias_new.func) |
254 | 1.49k | graph->hooks.tensor_symbol_alias_new.func(graph->hooks.tensor_symbol_alias_new.context, alias, tensor_symbol, ofs, stride, info, name); |
255 | 4.77k | return alias; |
256 | 4.77k | } |
257 | | |
258 | | void* ccv_nnc_tensor_symbol_alias_new_hook(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_tensor_symbol_alias_new_hook_f hook, void* context, ccv_nnc_tensor_symbol_alias_new_hook_f* previous_hook) |
259 | 11.7k | { |
260 | 11.7k | if (previous_hook) |
261 | 4 | *previous_hook = graph->hooks.tensor_symbol_alias_new.func; |
262 | 11.7k | void* const prev = graph->hooks.tensor_symbol_alias_new.context; |
263 | 11.7k | graph->hooks.tensor_symbol_alias_new.func = hook; |
264 | 11.7k | graph->hooks.tensor_symbol_alias_new.context = context; |
265 | 11.7k | return prev; |
266 | 11.7k | } |
267 | | |
268 | | ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_alias_to(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor_symbol) |
269 | 103k | { |
270 | 103k | assert(tensor_symbol.graph == graph); |
271 | 103k | int d = tensor_symbol.d; |
272 | 103k | assert(d >= 0 && d < graph->tensor_symbol_info->rnum); |
273 | 103k | ccv_nnc_tensor_symbol_info_t* info_d = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
274 | | // Find the root tensor that is not an alias. |
275 | 105k | while (info_d->alias_ref) |
276 | 2.05k | { |
277 | 2.05k | d = info_d->alias_ref - 1; |
278 | 2.05k | assert(d >= 0 && d < graph->tensor_symbol_info->rnum); |
279 | 2.05k | info_d = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
280 | 2.05k | } |
281 | 103k | if (d != tensor_symbol.d) |
282 | 2.05k | return (ccv_nnc_tensor_symbol_t){ |
283 | 2.05k | .d = d, |
284 | 2.05k | .graph = graph |
285 | 2.05k | }; |
286 | 100k | return (ccv_nnc_tensor_symbol_t){ |
287 | 100k | .d = CCV_NNC_NO_TENSOR_SYMBOL, |
288 | 100k | .graph = 0 |
289 | 100k | }; |
290 | 103k | } |
291 | | |
292 | | // Resolve this tensor symbol to the current graph. If it cannot be found, return no symbol. |
293 | | ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_resolve(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor_symbol) |
294 | 190k | { |
295 | 190k | if (graph == tensor_symbol.graph) |
296 | 190k | return tensor_symbol; |
297 | 38 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol.graph->tensor_symbol_info, tensor_symbol.d); |
298 | 38 | assert(!symbol_info->alias_ref); |
299 | | // Find if the symbol is in the sub-graph. |
300 | 38 | const ccv_nnc_symbolic_graph_t* curr_graph = tensor_symbol.graph; |
301 | 38 | assert(tensor_symbol.d >= 0 && tensor_symbol.d < curr_graph->tensor_symbol_info->rnum); |
302 | 78 | while (38 curr_graph && curr_graph != graph57 ) |
303 | 40 | curr_graph = curr_graph->p; |
304 | 38 | if (curr_graph) |
305 | 17 | { |
306 | | // The graph is a parent of the symbol passed in. |
307 | 17 | curr_graph = tensor_symbol.graph; |
308 | 17 | ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info; |
309 | 17 | ccv_nnc_tensor_symbol_t curr_symbol = tensor_symbol; |
310 | 22 | while (curr_graph != graph) |
311 | 17 | { |
312 | 17 | ccv_nnc_symbolic_graph_t* const p = curr_graph->p; |
313 | | // Cannot find the relevant one in the parent graph, return no symbol. |
314 | 17 | if (!curr_symbol_info->p_ref) |
315 | 12 | return (ccv_nnc_tensor_symbol_t){ |
316 | 12 | .d = CCV_NNC_NO_TENSOR_SYMBOL, |
317 | 12 | .graph = graph, |
318 | 12 | }; |
319 | 5 | curr_symbol.d = curr_symbol_info->p_ref - 1; |
320 | 5 | curr_symbol.graph = p; |
321 | 5 | assert(curr_symbol.d >= 0 && curr_symbol.d < p->tensor_symbol_info->rnum); |
322 | 5 | curr_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(p->tensor_symbol_info, curr_symbol.d); |
323 | | // Move on. |
324 | 5 | curr_graph = p; |
325 | 5 | } |
326 | 5 | return curr_symbol; |
327 | 17 | } |
328 | | // Otherwise, if the symbol is in the parent graph, this is a bit more expensive because I need to keep a trace stack. |
329 | 21 | curr_graph = graph; |
330 | 21 | int d; |
331 | 42 | for (d = 0; curr_graph && curr_graph != tensor_symbol.graph; d++21 ) |
332 | 21 | curr_graph = curr_graph->p; |
333 | 21 | curr_graph = graph; |
334 | 21 | assert(d > 0); |
335 | 21 | int trace[d]; |
336 | 42 | for (d = 0; curr_graph && curr_graph != tensor_symbol.graph; d++21 ) |
337 | 21 | { |
338 | 21 | const int p_idx = curr_graph->p_idx - 1; |
339 | 21 | trace[d] = p_idx; |
340 | 21 | curr_graph = curr_graph->p; |
341 | 21 | } |
342 | | // If it is in neither the parent graph nor the sub-graph, the input is invalid. |
343 | 21 | assert(curr_graph); |
344 | 21 | curr_graph = tensor_symbol.graph; |
345 | 21 | ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info; |
346 | 21 | ccv_nnc_tensor_symbol_t curr_symbol = tensor_symbol; |
347 | | // The graph is a sub graph of the symbol passed in. |
348 | 21 | int i; |
349 | 42 | for (i = d - 1; i >= 0; i--21 ) |
350 | 21 | { |
351 | 21 | const int p_idx = trace[i]; |
352 | 21 | assert(p_idx >= 0); |
353 | | // Cannot find the relevant one in the sub-graph, return no symbol. |
354 | 21 | if (!curr_graph->sub_graphs || !curr_symbol_info->s_ref || |
355 | 21 | curr_symbol_info->s_ref->rnum != curr_graph->sub_graphs->rnum) |
356 | 0 | return (ccv_nnc_tensor_symbol_t){ |
357 | 0 | .d = CCV_NNC_NO_TENSOR_SYMBOL, |
358 | 0 | .graph = graph, |
359 | 0 | }; |
360 | 21 | assert(p_idx >= 0 && p_idx < curr_symbol_info->s_ref->rnum); |
361 | 21 | const int s_idx = *(int*)ccv_array_get(curr_symbol_info->s_ref, p_idx); |
362 | 21 | ccv_nnc_symbolic_graph_t* const s = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(curr_graph->sub_graphs, p_idx); |
363 | 21 | curr_symbol.d = s_idx - 1; |
364 | 21 | curr_symbol.graph = s; |
365 | 21 | assert(curr_symbol.d >= 0 && curr_symbol.d < s->tensor_symbol_info->rnum); |
366 | 21 | curr_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(s->tensor_symbol_info, curr_symbol.d); |
367 | | // Move on. |
368 | 21 | curr_graph = s; |
369 | 21 | } |
370 | 21 | return curr_symbol; |
371 | 21 | } |
372 | | |
373 | | // This method generates tensor symbols and their links along the way when traversing the graph. |
374 | | enum { |
375 | | MAP_TENSOR_USE_AS_INPUT, |
376 | | MAP_TENSOR_USE_AS_OUTPUT, |
377 | | }; |
378 | | |
379 | | static void _ccv_nnc_graph_exec_add_input_if_needed(ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const int d) |
380 | 52 | { |
381 | 52 | int i; |
382 | 83 | for (i = 0; i < exec_symbol_info->input_size; i++31 ) |
383 | 53 | if (exec_symbol_info->inputs[i] == d) |
384 | 22 | return; // No need to continue, this symbol already exists as input. |
385 | | // Expand the array. |
386 | 30 | if (!exec_symbol_info->input_size && !exec_symbol_info->output_size16 ) |
387 | 16 | { |
388 | 16 | exec_symbol_info->inputs = (int*)ccmalloc(sizeof(int)); |
389 | 16 | exec_symbol_info->inputs[0] = d; |
390 | 16 | exec_symbol_info->input_size = 1; |
391 | 16 | exec_symbol_info->outputs = exec_symbol_info->inputs + 1; |
392 | 16 | return; |
393 | 16 | } |
394 | 14 | exec_symbol_info->inputs = (int*)ccrealloc(exec_symbol_info->inputs, sizeof(int) * (exec_symbol_info->input_size + 1 + exec_symbol_info->output_size)); |
395 | 14 | exec_symbol_info->outputs = exec_symbol_info->inputs + exec_symbol_info->input_size; |
396 | 14 | if (exec_symbol_info->output_size) |
397 | 6 | memmove(exec_symbol_info->outputs + 1, exec_symbol_info->outputs, sizeof(int) * exec_symbol_info->output_size); |
398 | 14 | exec_symbol_info->inputs[exec_symbol_info->input_size] = d; |
399 | 14 | ++exec_symbol_info->input_size; |
400 | 14 | exec_symbol_info->outputs = exec_symbol_info->inputs + exec_symbol_info->input_size; |
401 | 14 | } |
402 | | |
403 | | static void _ccv_nnc_graph_exec_add_output_if_needed(ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const int d) |
404 | 52 | { |
405 | 52 | int i; |
406 | 69 | for (i = 0; i < exec_symbol_info->output_size; i++17 ) |
407 | 45 | if (exec_symbol_info->outputs[i] == d) |
408 | 28 | return; // No need to continue, this symbol already exists as output. |
409 | | // Expand the array. |
410 | 24 | if (!exec_symbol_info->input_size && !exec_symbol_info->output_size7 ) |
411 | 3 | { |
412 | 3 | exec_symbol_info->inputs = (int*)ccmalloc(sizeof(int)); |
413 | 3 | exec_symbol_info->outputs = exec_symbol_info->inputs; |
414 | 3 | exec_symbol_info->outputs[0] = d; |
415 | 3 | exec_symbol_info->output_size = 1; |
416 | 3 | return; |
417 | 3 | } |
418 | 21 | exec_symbol_info->inputs = (int*)ccrealloc(exec_symbol_info->inputs, sizeof(int) * (exec_symbol_info->input_size + exec_symbol_info->output_size + 1)); |
419 | 21 | exec_symbol_info->outputs = exec_symbol_info->inputs + exec_symbol_info->input_size; |
420 | 21 | exec_symbol_info->outputs[exec_symbol_info->output_size] = d; |
421 | 21 | ++exec_symbol_info->output_size; |
422 | 21 | } |
423 | | |
424 | | void ccv_nnc_tensor_symbol_pair_with(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor_symbol, const ccv_nnc_tensor_symbol_t pair_tensor_symbol) |
425 | 11 | { |
426 | 11 | assert(tensor_symbol.graph == graph); |
427 | 11 | assert(tensor_symbol.d >= 0); |
428 | 11 | assert(tensor_symbol.d < graph->tensor_symbol_info->rnum); |
429 | 11 | assert(pair_tensor_symbol.graph == graph->pair); |
430 | 11 | assert(pair_tensor_symbol.d >= 0); |
431 | 11 | assert(pair_tensor_symbol.d < graph->pair->tensor_symbol_info->rnum); |
432 | 11 | ccv_nnc_tensor_symbol_info_t* const tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor_symbol.d); |
433 | 11 | tensor_info->pair_ref = pair_tensor_symbol.d + 1; |
434 | 11 | } |
435 | | |
436 | | static int _ccv_nnc_symbolic_graph_map_tensor_symbol_no_alias(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol, const int map_use) |
437 | 68 | { |
438 | 68 | assert(graph && symbol.graph); |
439 | 68 | assert(symbol.graph != graph); |
440 | 68 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(symbol.graph->tensor_symbol_info, symbol.d); |
441 | 68 | assert(!symbol_info->alias_ref); |
442 | | // Find if the symbol is in the sub-graph. |
443 | 68 | const ccv_nnc_symbolic_graph_t* curr_graph = symbol.graph; |
444 | 68 | assert(symbol.d >= 0 && symbol.d < curr_graph->tensor_symbol_info->rnum); |
445 | 139 | while (68 curr_graph && curr_graph != graph90 ) |
446 | 71 | curr_graph = curr_graph->p; |
447 | 68 | if (curr_graph) |
448 | 19 | { |
449 | | // The graph is a parent of the symbol passed in. For this case, if we are connecting this symbol to an exec as input, |
450 | | // that means it must be an output in these sub-graphs. Otherwise, if we are connecting this symbol to an exec as output, |
451 | | // it must be an input in these sub-graphs. |
452 | 19 | curr_graph = symbol.graph; |
453 | 19 | ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info; |
454 | 19 | ccv_nnc_tensor_symbol_t curr_symbol = symbol; |
455 | 38 | while (curr_graph != graph) |
456 | 19 | { |
457 | 19 | ccv_nnc_symbolic_graph_t* const p = curr_graph->p; |
458 | | // I need to check whether the symbol already exists before creating a new one. |
459 | 19 | ccv_nnc_tensor_symbol_t new_symbol; |
460 | 19 | ccv_nnc_tensor_symbol_info_t* new_symbol_info; |
461 | 19 | if (!curr_symbol_info->p_ref) |
462 | 18 | { |
463 | 18 | new_symbol = ccv_nnc_tensor_symbol_new(p, curr_symbol_info->info, curr_symbol_info->name); |
464 | 18 | new_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(p->tensor_symbol_info, new_symbol.d); |
465 | 18 | curr_symbol_info->p_ref = new_symbol.d + 1; |
466 | 18 | new_symbol_info->s_ref = ccv_array_new(sizeof(int), p->sub_graphs->rnum, 0); |
467 | 18 | new_symbol_info->s_ref->rnum = p->sub_graphs->rnum; |
468 | 18 | ccv_array_zero(new_symbol_info->s_ref); |
469 | 18 | *(int*)ccv_array_get(new_symbol_info->s_ref, curr_graph->p_idx - 1) = curr_symbol.d + 1; |
470 | 18 | } else { |
471 | 1 | new_symbol.d = curr_symbol_info->p_ref - 1; |
472 | 1 | new_symbol.graph = p; |
473 | 1 | assert(new_symbol.d >= 0 && new_symbol.d < p->tensor_symbol_info->rnum); |
474 | 1 | new_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(p->tensor_symbol_info, new_symbol.d); |
475 | 1 | } |
476 | 19 | if (curr_graph->exec_idx) |
477 | 19 | { |
478 | | // This is a sub-graph. |
479 | 19 | assert(p); |
480 | 19 | assert(curr_graph->exec_idx > 0 && curr_graph->exec_idx <= p->exec_symbol_info->rnum); |
481 | 19 | ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(p->exec_symbol_info, curr_graph->exec_idx - 1); |
482 | 19 | switch (map_use) |
483 | 19 | { |
484 | 19 | case MAP_TENSOR_USE_AS_INPUT: |
485 | 19 | _ccv_nnc_graph_exec_add_output_if_needed(exec_symbol_info, new_symbol.d); |
486 | 19 | break; |
487 | 0 | case MAP_TENSOR_USE_AS_OUTPUT: |
488 | 0 | _ccv_nnc_graph_exec_add_input_if_needed(exec_symbol_info, new_symbol.d); |
489 | 0 | break; |
490 | 19 | } |
491 | 19 | } |
492 | | // Move on. |
493 | 19 | curr_symbol = new_symbol; |
494 | 19 | curr_symbol_info = new_symbol_info; |
495 | 19 | curr_graph = p; |
496 | 19 | } |
497 | 19 | return curr_symbol.d; |
498 | 19 | } |
499 | | // Otherwise, if the symbol is in the parent graph, this is a bit more expensive because I need to keep a trace stack. |
500 | 49 | curr_graph = graph; |
501 | 49 | int d; |
502 | 99 | for (d = 0; curr_graph && curr_graph != symbol.graph; d++50 ) |
503 | 50 | curr_graph = curr_graph->p; |
504 | 49 | curr_graph = graph; |
505 | 49 | assert(d > 0); |
506 | 49 | int trace[d]; |
507 | 99 | for (d = 0; curr_graph && curr_graph != symbol.graph; d++50 ) |
508 | 50 | { |
509 | 50 | const int p_idx = curr_graph->p_idx - 1; |
510 | 50 | trace[d] = p_idx; |
511 | 50 | curr_graph = curr_graph->p; |
512 | 50 | } |
513 | | // If it is in neither the parent graph nor the sub-graph, the input is invalid. |
514 | 49 | assert(curr_graph); |
515 | 49 | curr_graph = symbol.graph; |
516 | 49 | ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info; |
517 | 49 | ccv_nnc_tensor_symbol_t curr_symbol = symbol; |
518 | | // The graph is a sub graph of the symbol passed in. For this case, if we are connecting this symbol to an exec as input, |
519 | | // that means it must be an input in these parent graphs. Otherwise, if we are connecting this symbol to an exec as output, |
520 | | // it must be an output in these parent graphs. |
521 | 49 | int i; |
522 | 99 | for (i = d - 1; i >= 0; i--50 ) |
523 | 50 | { |
524 | 50 | const int p_idx = trace[i]; |
525 | 50 | assert(p_idx >= 0); |
526 | 50 | assert(curr_graph->sub_graphs); |
527 | 50 | if (!curr_symbol_info->s_ref) |
528 | 36 | { |
529 | 36 | curr_symbol_info->s_ref = ccv_array_new(sizeof(int), curr_graph->sub_graphs->rnum, 0); |
530 | 36 | curr_symbol_info->s_ref->rnum = curr_graph->sub_graphs->rnum; |
531 | 36 | ccv_array_zero(curr_symbol_info->s_ref); |
532 | 36 | } else if (14 curr_symbol_info->s_ref->rnum != curr_graph->sub_graphs->rnum14 ) |
533 | 8 | ccv_array_resize(curr_symbol_info->s_ref, curr_graph->sub_graphs->rnum); |
534 | 50 | assert(p_idx >= 0 && p_idx < curr_symbol_info->s_ref->rnum); |
535 | 50 | const int s_idx = *(int*)ccv_array_get(curr_symbol_info->s_ref, p_idx); |
536 | 50 | ccv_nnc_symbolic_graph_t* const s = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(curr_graph->sub_graphs, p_idx); |
537 | 50 | ccv_nnc_tensor_symbol_t new_symbol; |
538 | 50 | ccv_nnc_tensor_symbol_info_t* new_symbol_info; |
539 | | // I need to check whether the symbol already exists before creating a new one. |
540 | 50 | if (!s_idx) |
541 | 44 | { |
542 | 44 | new_symbol = ccv_nnc_tensor_symbol_new(s, symbol_info->info, symbol_info->name); |
543 | 44 | new_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(s->tensor_symbol_info, new_symbol.d); |
544 | 44 | new_symbol_info->p_ref = curr_symbol.d + 1; |
545 | 44 | *(int*)ccv_array_get(curr_symbol_info->s_ref, p_idx) = new_symbol.d + 1; |
546 | 44 | } else { |
547 | 6 | new_symbol.d = s_idx - 1; |
548 | 6 | new_symbol.graph = s; |
549 | 6 | assert(new_symbol.d >= 0 && new_symbol.d < s->tensor_symbol_info->rnum); |
550 | 6 | new_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(s->tensor_symbol_info, new_symbol.d); |
551 | 6 | } |
552 | 50 | if (s->exec_idx) |
553 | 50 | { |
554 | 50 | assert(s->p); // This is a sub-graph. |
555 | 50 | assert(s->exec_idx > 0 && s->exec_idx <= curr_graph->exec_symbol_info->rnum); |
556 | 50 | ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(curr_graph->exec_symbol_info, s->exec_idx - 1); |
557 | 50 | switch (map_use) |
558 | 50 | { |
559 | 45 | case MAP_TENSOR_USE_AS_INPUT: |
560 | 45 | _ccv_nnc_graph_exec_add_input_if_needed(exec_symbol_info, curr_symbol.d); |
561 | 45 | break; |
562 | 5 | case MAP_TENSOR_USE_AS_OUTPUT: |
563 | 5 | _ccv_nnc_graph_exec_add_output_if_needed(exec_symbol_info, curr_symbol.d); |
564 | 5 | break; |
565 | 50 | } |
566 | 50 | } |
567 | | // Move on. |
568 | 50 | curr_symbol = new_symbol; |
569 | 50 | curr_symbol_info = new_symbol_info; |
570 | 50 | curr_graph = s; |
571 | 50 | } |
572 | 49 | return curr_symbol.d; |
573 | 49 | } |
574 | | |
575 | | static int _ccv_nnc_symbolic_graph_map_tensor_symbol(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol, const int map_use) |
576 | 68 | { |
577 | 68 | assert(graph && symbol.graph); |
578 | 68 | assert(symbol.graph != graph); |
579 | 68 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(symbol.graph->tensor_symbol_info, symbol.d); |
580 | 68 | if (!symbol_info->alias_ref) |
581 | 62 | return _ccv_nnc_symbolic_graph_map_tensor_symbol_no_alias(graph, symbol, map_use); |
582 | 6 | const int d = symbol_info->alias_ref - 1; |
583 | 6 | assert(d >= 0 && d < symbol.graph->tensor_symbol_info->rnum); |
584 | 6 | const int map_d = _ccv_nnc_symbolic_graph_map_tensor_symbol_no_alias(graph, (ccv_nnc_tensor_symbol_t){ |
585 | 6 | .graph = symbol.graph, |
586 | 6 | .d = d |
587 | 6 | }, map_use); |
588 | 6 | const ccv_nnc_tensor_symbol_t alias = ccv_nnc_tensor_symbol_alias_new(graph, (ccv_nnc_tensor_symbol_t){ |
589 | 6 | .graph = graph, |
590 | 6 | .d = map_d |
591 | 6 | }, symbol_info->ofs, symbol_info->stride, symbol_info->info, symbol_info->name); |
592 | 6 | return alias.d; |
593 | 6 | } |
594 | | |
595 | | int ccv_nnc_tensor_symbol_map_raw(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol) |
596 | 146k | { |
597 | 146k | if (symbol.d >= 0) |
598 | 115k | return symbol.graph != graph ? _ccv_nnc_symbolic_graph_map_tensor_symbol(graph, symbol, MAP_TENSOR_USE_AS_INPUT)61 : symbol.d115k ; |
599 | 31.6k | if (symbol.graph == graph || symbol.d == CCV_NNC_NO_TENSOR_SYMBOL31.5k ) |
600 | 31.6k | return symbol.d; |
601 | 1 | ccv_nnc_symbolic_graph_t* curr_graph = graph; |
602 | 1 | int d; |
603 | 2 | for (d = 0; curr_graph && curr_graph != symbol.graph; d++1 ) |
604 | 1 | curr_graph = curr_graph->p; |
605 | 1 | assert(curr_graph == symbol.graph); |
606 | 1 | return CCV_NNC_ENCODE_WHILE_COUNT_SYMBOL(d); |
607 | 1 | } |
608 | | |
609 | | void ccv_nnc_tensor_symbol_hookup(ccv_nnc_symbolic_graph_t* const src_graph, ccv_nnc_symbolic_graph_t* const dest_graph, const ccv_nnc_tensor_symbol_t src_tensor_symbol, const ccv_nnc_tensor_symbol_t dest_tensor_symbol) |
610 | 35 | { |
611 | 35 | assert(src_graph != dest_graph); |
612 | 35 | assert(src_graph->p == dest_graph || dest_graph->p == src_graph); |
613 | 35 | assert(src_tensor_symbol.d >= 0); |
614 | 35 | assert(dest_tensor_symbol.d >= 0); |
615 | 35 | ccv_nnc_tensor_symbol_t tensor_symbol = src_tensor_symbol; |
616 | 35 | if (tensor_symbol.graph != src_graph) |
617 | 2 | tensor_symbol = (ccv_nnc_tensor_symbol_t){ |
618 | 2 | .graph = src_graph, |
619 | 2 | .d = _ccv_nnc_symbolic_graph_map_tensor_symbol(src_graph, tensor_symbol, MAP_TENSOR_USE_AS_INPUT), |
620 | 2 | }; |
621 | 35 | ccv_nnc_tensor_symbol_t sub_tensor_symbol = dest_tensor_symbol; |
622 | 35 | if (sub_tensor_symbol.graph != dest_graph) |
623 | 0 | sub_tensor_symbol = (ccv_nnc_tensor_symbol_t){ |
624 | 0 | .graph = dest_graph, |
625 | 0 | .d = _ccv_nnc_symbolic_graph_map_tensor_symbol(dest_graph, sub_tensor_symbol, MAP_TENSOR_USE_AS_OUTPUT), |
626 | 0 | }; |
627 | 35 | ccv_nnc_symbolic_graph_t* curr_graph = src_graph; |
628 | 70 | while (curr_graph && curr_graph != dest_graph63 ) |
629 | 35 | curr_graph = curr_graph->p; |
630 | 35 | ccv_nnc_symbolic_graph_t* graph; |
631 | 35 | ccv_nnc_symbolic_graph_t* sub_graph; |
632 | 35 | int map_use; |
633 | 35 | if (curr_graph) |
634 | 28 | { |
635 | | // src_graph is the sub graph, dest_graph is the parent graph. |
636 | 28 | graph = dest_graph; |
637 | 28 | sub_graph = src_graph; |
638 | | // Swap tensor_symbol and sub_tensor_symbol |
639 | 28 | ccv_nnc_tensor_symbol_t x; |
640 | 28 | CCV_SWAP(tensor_symbol, sub_tensor_symbol, x); |
641 | 28 | map_use = MAP_TENSOR_USE_AS_OUTPUT; |
642 | 28 | } else { |
643 | 7 | graph = src_graph; |
644 | 7 | sub_graph = dest_graph; |
645 | 7 | map_use = MAP_TENSOR_USE_AS_INPUT; |
646 | 7 | } |
647 | 35 | ccv_nnc_symbolic_graph_t* p_graph = sub_graph; |
648 | 35 | while (p_graph && p_graph->p != graph) |
649 | 0 | p_graph = p_graph->p; |
650 | 35 | assert(p_graph); |
651 | 35 | if (p_graph != sub_graph) |
652 | 0 | { |
653 | 0 | sub_tensor_symbol.d = _ccv_nnc_symbolic_graph_map_tensor_symbol(p_graph, sub_tensor_symbol, map_use); |
654 | 0 | sub_tensor_symbol.graph = p_graph; |
655 | 0 | sub_graph = p_graph; |
656 | 0 | } |
657 | 35 | assert(tensor_symbol.d < graph->tensor_symbol_info->rnum); |
658 | 35 | assert(sub_tensor_symbol.d < sub_graph->tensor_symbol_info->rnum); |
659 | 35 | ccv_nnc_tensor_symbol_info_t* const sub_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(sub_graph->tensor_symbol_info, sub_tensor_symbol.d); |
660 | 35 | sub_tensor_info->p_ref = tensor_symbol.d + 1; |
661 | 35 | ccv_nnc_tensor_symbol_info_t* const tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor_symbol.d); |
662 | 35 | if (!tensor_info->s_ref) |
663 | 15 | { |
664 | 15 | tensor_info->s_ref = ccv_array_new(sizeof(int), graph->sub_graphs->rnum, 0); |
665 | 15 | tensor_info->s_ref->rnum = graph->sub_graphs->rnum; |
666 | 15 | ccv_array_zero(tensor_info->s_ref); |
667 | 20 | } else if (tensor_info->s_ref->rnum != graph->sub_graphs->rnum) |
668 | 20 | ccv_array_resize(tensor_info->s_ref, graph->sub_graphs->rnum); |
669 | 35 | const int p_idx = sub_graph->p_idx - 1; |
670 | 35 | assert(p_idx >= 0 && p_idx < tensor_info->s_ref->rnum); |
671 | 35 | const int s_idx = *(int*)ccv_array_get(tensor_info->s_ref, p_idx); |
672 | 35 | assert(s_idx == 0); // Otherwise it is assigned before |
673 | 35 | *(int*)ccv_array_get(tensor_info->s_ref, p_idx) = sub_tensor_symbol.d + 1; |
674 | 35 | ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, sub_graph->exec_idx - 1); |
675 | 35 | switch (map_use) |
676 | 35 | { |
677 | 7 | case MAP_TENSOR_USE_AS_INPUT: |
678 | 7 | _ccv_nnc_graph_exec_add_input_if_needed(exec_symbol_info, tensor_symbol.d); |
679 | 7 | break; |
680 | 28 | case MAP_TENSOR_USE_AS_OUTPUT: |
681 | 28 | _ccv_nnc_graph_exec_add_output_if_needed(exec_symbol_info, tensor_symbol.d); |
682 | 28 | break; |
683 | 35 | } |
684 | 35 | } |
685 | | |
686 | | void ccv_nnc_tensor_symbol_set_bypasses(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size) |
687 | 11 | { |
688 | 11 | int i; |
689 | 22 | for (i = 0; i < symbol_map_size; i++11 ) |
690 | 11 | { |
691 | 11 | const ccv_nnc_tensor_symbol_t source = ccv_nnc_tensor_symbol_resolve(graph, symbol_map[i].source); |
692 | 11 | const ccv_nnc_tensor_symbol_t destination = ccv_nnc_tensor_symbol_resolve(graph, symbol_map[i].destination); |
693 | 11 | assert(source.graph == graph); |
694 | 11 | assert(destination.graph == graph); |
695 | 11 | assert(source.d < graph->tensor_symbol_info->rnum); |
696 | 11 | assert(destination.d < graph->tensor_symbol_info->rnum); |
697 | 11 | ccv_nnc_tensor_symbol_info_t* source_tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, source.d); |
698 | 11 | ccv_nnc_tensor_symbol_info_t* destination_tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, destination.d); |
699 | | // Don't support parameterize with alias. The reason is that to support parameterized loop (for SSA), I choose |
700 | | // to simply reuse the piece of memory (allocating the same memory region to both, therefore to enable parameter |
701 | | // passing). For alias, it is not possible because alias can pointing to the tensors with different sizes, thus, |
702 | | // these pointed tensors cannot share the same memory region. The best way for alias to be parameterized is to |
703 | | // create a new tensor of the same size, transfer value over, and parameterized on that tensor instead. |
704 | 11 | assert(!destination_tensor_symbol_info->alias_ref); |
705 | 11 | assert(!source_tensor_symbol_info->alias_ref); |
706 | 11 | destination_tensor_symbol_info->bypass_ref = source.d + 1; |
707 | 11 | source_tensor_symbol_info->r_bypass_ref = destination.d + 1; |
708 | 11 | } |
709 | 11 | } |
710 | | |
711 | | int ccv_nnc_tensor_symbol_set(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor, const ccv_nnc_tensor_param_t info) |
712 | 38.5k | { |
713 | 38.5k | assert(graph == tensor.graph); |
714 | 38.5k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
715 | 38.5k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
716 | 38.5k | symbol_info->info = info; |
717 | | // It also need to propagate to assign_ref if needed. |
718 | 38.5k | if (symbol_info->assign_ref) |
719 | 0 | { |
720 | 0 | ccv_nnc_tensor_symbol_info_t* const assign_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, symbol_info->assign_ref - 1); |
721 | 0 | assign_info->info = info; |
722 | 0 | } |
723 | 38.5k | return 0; |
724 | 38.5k | } |
725 | | |
726 | | ccv_nnc_tensor_param_t ccv_nnc_tensor_symbol_params(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor) |
727 | 72.6k | { |
728 | 72.6k | assert(graph == tensor.graph); |
729 | 72.6k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
730 | 72.6k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
731 | 72.6k | return symbol_info->info; |
732 | 72.6k | } |
733 | | |
734 | | const char* ccv_nnc_tensor_symbol_name(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor) |
735 | 0 | { |
736 | 0 | assert(graph == tensor.graph); |
737 | 0 | assert(tensor.d < graph->tensor_symbol_info->rnum); |
738 | 0 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
739 | 0 | return symbol_info->name; |
740 | 0 | } |
741 | | |
742 | | int ccv_nnc_tensor_symbol_alias_set(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor, const int ofs[CCV_NNC_MAX_DIM_ALLOC], const int stride[CCV_NNC_MAX_DIM_ALLOC]) |
743 | 2.00k | { |
744 | 2.00k | assert(graph == tensor.graph); |
745 | 2.00k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
746 | 2.00k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
747 | 2.00k | if (!symbol_info->alias_ref) |
748 | 0 | return -1; |
749 | 2.00k | memcpy(symbol_info->ofs, ofs, sizeof(symbol_info->ofs)); |
750 | 2.00k | memcpy(symbol_info->stride, stride, sizeof(symbol_info->stride)); |
751 | | // We don't need to propagate to assign_ref because alias cannot be loop carry-overs. |
752 | 2.00k | assert(!symbol_info->assign_ref); |
753 | 2.00k | return 0; |
754 | 2.00k | } |
755 | | |
756 | | int ccv_nnc_tensor_symbol_alias_params(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor, int ofs[CCV_NNC_MAX_DIM_ALLOC], int stride[CCV_NNC_MAX_DIM_ALLOC]) |
757 | 38.5k | { |
758 | 38.5k | assert(graph == tensor.graph); |
759 | 38.5k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
760 | 38.5k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
761 | 38.5k | if (!symbol_info->alias_ref) |
762 | 36.5k | return -1; |
763 | 2.00k | if (ofs) |
764 | 2.00k | memcpy(ofs, symbol_info->ofs, sizeof(symbol_info->ofs)); |
765 | 2.00k | if (stride) |
766 | 2.00k | memcpy(stride, symbol_info->stride, sizeof(symbol_info->stride)); |
767 | 2.00k | return 0; |
768 | 38.5k | } |
769 | | |
770 | | int ccv_nnc_tensor_symbol_set_flags(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor, const int flags) |
771 | 43.4k | { |
772 | 43.4k | assert(graph == tensor.graph); |
773 | 43.4k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
774 | 43.4k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
775 | 43.4k | symbol_info->flags = flags; |
776 | | // It also need to propagate to assign_ref if needed. |
777 | 43.4k | if (symbol_info->assign_ref) |
778 | 1 | { |
779 | 1 | ccv_nnc_tensor_symbol_info_t* const assign_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, symbol_info->assign_ref - 1); |
780 | 1 | assign_info->flags = flags; |
781 | 1 | } |
782 | 43.4k | return 0; |
783 | 43.4k | } |
784 | | |
785 | | int ccv_nnc_tensor_symbol_flags(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t tensor) |
786 | 11 | { |
787 | 11 | assert(graph == tensor.graph); |
788 | 11 | assert(tensor.d < graph->tensor_symbol_info->rnum); |
789 | 11 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
790 | 11 | return symbol_info->flags; |
791 | 11 | } |
792 | | |
793 | | void ccv_nnc_tensor_symbol_free(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_tensor_symbol_t tensor) |
794 | 69.4k | { |
795 | 69.4k | assert(graph == tensor.graph); |
796 | 69.4k | assert(tensor.d < graph->tensor_symbol_info->rnum); |
797 | 69.4k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor.d); |
798 | 69.4k | if (symbol_info->s_ref) |
799 | 0 | { |
800 | 0 | ccv_array_free(symbol_info->s_ref); |
801 | 0 | symbol_info->s_ref = 0; |
802 | 0 | } |
803 | 69.4k | if (symbol_info->name) |
804 | 8 | { |
805 | 8 | ccfree(symbol_info->name); |
806 | 8 | symbol_info->name = 0; |
807 | 8 | } |
808 | 69.4k | symbol_info->flags |= CCV_NNC_TENSOR_SYMBOL_DEAD; |
809 | 69.4k | int i; |
810 | 136k | for (i = graph->tensor_symbol_info->rnum - 1; i >= 0; i--66.7k ) |
811 | 133k | if (!CCV_NNC_TENSOR_SYMBOL_IS_DEAD(((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i))->flags)) |
812 | 66.8k | { |
813 | 66.8k | graph->tensor_symbol_info->rnum = i + 1; |
814 | 66.8k | break; |
815 | 66.8k | } |
816 | 69.4k | if (tensor.d < graph->tensor_symbol_info->rnum && |
817 | 69.4k | (55.3k tensor.d < graph->reuse.tensor55.3k || graph->reuse.tensor < 055.3k )) |
818 | 13.2k | graph->reuse.tensor = tensor.d; |
819 | 56.1k | else if (graph->reuse.tensor >= graph->tensor_symbol_info->rnum) |
820 | 8.43k | graph->reuse.tensor = -1; |
821 | 69.4k | } |
822 | | |
823 | | static void _ccv_nnc_graph_exec_symbol_set_io(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_graph_exec_symbol_info_t* const exec_info, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size) |
824 | 57.2k | { |
825 | 57.2k | exec_info->input_size = input_size; |
826 | 57.2k | exec_info->output_size = output_size; |
827 | 57.2k | if (input_size > 0 || output_size > 04.40k ) |
828 | 57.2k | { |
829 | 57.2k | if (!exec_info->inputs) |
830 | 56.5k | exec_info->inputs = ccmalloc(sizeof(int) * (input_size + output_size)); |
831 | 636 | else |
832 | 636 | exec_info->inputs = ccrealloc(exec_info->inputs, sizeof(int) * (input_size + output_size)); |
833 | 57.2k | exec_info->outputs = exec_info->inputs + input_size; |
834 | 57.2k | } |
835 | 57.2k | int i; |
836 | 57.2k | int tensor_memory = 0, tensor_formats = 0, tensor_datatypes = 0, tensor_auto = 0; |
837 | 204k | for (i = 0; i < input_size; i++146k ) |
838 | 146k | { |
839 | 146k | const int d = ccv_nnc_tensor_symbol_map_raw(graph, inputs[i]); |
840 | 146k | exec_info->inputs[i] = d; |
841 | 146k | if (d >= 0) |
842 | 115k | { |
843 | 115k | const ccv_nnc_tensor_symbol_info_t* const tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
844 | 115k | tensor_auto = tensor_auto || ccv_nnc_is_tensor_auto(tensor_info->info)115k ; |
845 | 115k | tensor_memory |= CCV_TENSOR_GET_MEMORY(tensor_info->info.type), tensor_formats |= tensor_info->info.format, tensor_datatypes |= CCV_GET_DATA_TYPE(tensor_info->info.datatype); |
846 | 115k | } |
847 | 146k | } |
848 | 138k | for (i = 0; i < output_size; i++80.8k ) |
849 | 80.8k | { |
850 | 80.8k | const int d = (outputs[i].graph != graph && outputs[i].d >= 09.34k ) ? _ccv_nnc_symbolic_graph_map_tensor_symbol(graph, outputs[i], MAP_TENSOR_USE_AS_OUTPUT)5 : outputs[i].d80.8k ; |
851 | 80.8k | exec_info->outputs[i] = d; |
852 | 80.8k | if (d >= 0) |
853 | 71.5k | { |
854 | 71.5k | const ccv_nnc_tensor_symbol_info_t* const tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, d); |
855 | 71.5k | tensor_auto = tensor_auto || ccv_nnc_is_tensor_auto(tensor_info->info)71.5k ; |
856 | 71.5k | tensor_memory |= CCV_TENSOR_GET_MEMORY(tensor_info->info.type), tensor_formats |= tensor_info->info.format, tensor_datatypes |= CCV_GET_DATA_TYPE(tensor_info->info.datatype); |
857 | 71.5k | } |
858 | 80.8k | } |
859 | | // If there is no auto tensor, we try to find backend (we don't know which backend if the tensor is auto). |
860 | 57.2k | if (!tensor_auto) |
861 | 57.2k | exec_info->cmd.backend = ccv_nnc_cmd_find_backend(exec_info->cmd, tensor_memory, tensor_formats, tensor_datatypes); |
862 | 57.2k | } |
863 | | |
864 | | ccv_nnc_graph_exec_symbol_t ccv_nnc_graph_exec_symbol_new(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) |
865 | 56.6k | { |
866 | 56.6k | ccv_nnc_graph_exec_symbol_t symbol = { |
867 | 56.6k | .d = graph->exec_symbol_info->rnum, |
868 | 56.6k | .graph = graph |
869 | 56.6k | }; |
870 | 56.6k | ccv_nnc_graph_exec_symbol_info_t symbol_info = { |
871 | 56.6k | .cmd = cmd, |
872 | 56.6k | .hint = ccv_nnc_no_hint, |
873 | 56.6k | }; |
874 | 56.6k | if (name) |
875 | 4.49k | { |
876 | 4.49k | const size_t len = strnlen(name, 63); |
877 | 4.49k | const size_t n = len + 1; |
878 | 4.49k | symbol_info.name = (char*)ccmalloc(n); |
879 | | // Don't use strndup because this way I can have custom allocator (for ccmalloc). |
880 | 4.49k | memcpy(symbol_info.name, name, n); |
881 | 4.49k | symbol_info.name[len] = 0; |
882 | 4.49k | } |
883 | 56.6k | _ccv_nnc_graph_exec_symbol_set_io(graph, &symbol_info, inputs, input_size, outputs, output_size); |
884 | 56.6k | if (graph->reuse.exec >= 0) |
885 | 9.61k | { |
886 | 9.61k | const int reuse_exec_d = graph->reuse.exec; |
887 | 9.61k | assert(reuse_exec_d < graph->exec_symbol_info->rnum); |
888 | 9.61k | *(ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, reuse_exec_d) = symbol_info; |
889 | 9.61k | int i; |
890 | 9.61k | graph->reuse.exec = -1; |
891 | 14.4k | for (i = reuse_exec_d + 1; i < graph->exec_symbol_info->rnum && graph->reuse.exec < 07.22k ; i++4.81k ) |
892 | 4.81k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i))->flags)) |
893 | 4.80k | graph->reuse.exec = i; |
894 | 9.61k | symbol.d = reuse_exec_d; |
895 | 9.61k | } else |
896 | 47.0k | ccv_array_push(graph->exec_symbol_info, &symbol_info); |
897 | 56.6k | if (graph->hooks.graph_exec_symbol_new.func) |
898 | 35.1k | graph->hooks.graph_exec_symbol_new.func(graph->hooks.graph_exec_symbol_new.context, symbol, cmd, inputs, input_size, outputs, output_size, name); |
899 | 56.6k | return symbol; |
900 | 56.6k | } |
901 | | |
902 | | void* ccv_nnc_graph_exec_symbol_new_hook(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_graph_exec_symbol_new_hook_f hook, void* context, ccv_nnc_graph_exec_symbol_new_hook_f* previous_hook) |
903 | 16.1k | { |
904 | 16.1k | if (previous_hook) |
905 | 7 | *previous_hook = graph->hooks.graph_exec_symbol_new.func; |
906 | 16.1k | void* const prev = graph->hooks.graph_exec_symbol_new.context; |
907 | 16.1k | graph->hooks.graph_exec_symbol_new.func = hook; |
908 | 16.1k | graph->hooks.graph_exec_symbol_new.context = context; |
909 | 16.1k | return prev; |
910 | 16.1k | } |
911 | | |
912 | | void ccv_nnc_graph_exec_symbol_set_io(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size) |
913 | 636 | { |
914 | 636 | assert(exec.graph == graph); |
915 | 636 | assert(exec.d >= 0); |
916 | 636 | assert(exec.d < graph->exec_symbol_info->rnum); |
917 | 636 | ccv_nnc_graph_exec_symbol_info_t* const exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
918 | 636 | _ccv_nnc_graph_exec_symbol_set_io(graph, exec_info, inputs, input_size, outputs, output_size); |
919 | 636 | } |
920 | | |
921 | | void ccv_nnc_graph_exec_symbol_pair_with(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec_symbol, const ccv_nnc_graph_exec_symbol_t pair_exec_symbol) |
922 | 19.1k | { |
923 | 19.1k | assert(exec_symbol.graph == graph); |
924 | 19.1k | assert(exec_symbol.d >= 0); |
925 | 19.1k | assert(exec_symbol.d < graph->exec_symbol_info->rnum); |
926 | 19.1k | assert(pair_exec_symbol.graph == graph || pair_exec_symbol.graph == graph->pair); |
927 | 19.1k | assert(pair_exec_symbol.d >= 0); |
928 | 19.1k | if (pair_exec_symbol.graph == graph) |
929 | 19.1k | { assert(pair_exec_symbol.d < graph->exec_symbol_info->rnum); } |
930 | 4 | else |
931 | 4 | { assert(pair_exec_symbol.d < graph->pair->exec_symbol_info->rnum); } |
932 | 19.1k | ccv_nnc_graph_exec_symbol_info_t* const exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec_symbol.d); |
933 | 19.1k | exec_info->pair_ref = pair_exec_symbol.d + 1; |
934 | 19.1k | } |
935 | | |
936 | | void ccv_nnc_graph_exec_symbol_set(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec, const ccv_nnc_cmd_t cmd) |
937 | 43.7k | { |
938 | 43.7k | assert(graph == exec.graph); |
939 | 43.7k | assert(exec.d < graph->exec_symbol_info->rnum); |
940 | 43.7k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
941 | 43.7k | symbol_info->cmd = cmd; |
942 | 43.7k | } |
943 | | |
944 | | void ccv_nnc_graph_exec_symbol_set_flags(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec, const int flags) |
945 | 1.02k | { |
946 | 1.02k | assert(graph == exec.graph); |
947 | 1.02k | assert(exec.d < graph->exec_symbol_info->rnum); |
948 | 1.02k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
949 | 1.02k | assert(!(flags & 0xffff)); // the pass-in flag shouldn't set the lower 16-bit. |
950 | 1.02k | symbol_info->flags = flags | (symbol_info->flags & 0xffff); |
951 | 1.02k | } |
952 | | |
953 | | int ccv_nnc_graph_exec_symbol_flags(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec) |
954 | 0 | { |
955 | 0 | assert(graph == exec.graph); |
956 | 0 | assert(exec.d < graph->exec_symbol_info->rnum); |
957 | 0 | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
958 | 0 | return (symbol_info->flags & 0xffff0000); |
959 | 0 | } |
960 | | |
961 | | ccv_nnc_cmd_t ccv_nnc_graph_exec_symbol_cmd(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec) |
962 | 51.6k | { |
963 | 51.6k | assert(graph == exec.graph); |
964 | 51.6k | assert(exec.d < graph->exec_symbol_info->rnum); |
965 | 51.6k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
966 | 51.6k | return symbol_info->cmd; |
967 | 51.6k | } |
968 | | |
969 | | const char* ccv_nnc_graph_exec_symbol_name(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec) |
970 | 0 | { |
971 | 0 | assert(graph == exec.graph); |
972 | 0 | assert(exec.d < graph->exec_symbol_info->rnum); |
973 | 0 | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
974 | 0 | return symbol_info->name; |
975 | 0 | } |
976 | | |
977 | | void ccv_nnc_graph_exec_symbol_set_hint(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec, const ccv_nnc_hint_t hint) |
978 | 20.8k | { |
979 | 20.8k | assert(graph == exec.graph); |
980 | 20.8k | assert(exec.d < graph->exec_symbol_info->rnum); |
981 | 20.8k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
982 | 20.8k | symbol_info->hint = hint; |
983 | 20.8k | } |
984 | | |
985 | | int ccv_nnc_graph_exec_symbol_concat(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t source, const ccv_nnc_graph_exec_symbol_t destination) |
986 | 74.4k | { |
987 | 74.4k | assert(graph == source.graph); |
988 | 74.4k | assert(graph == destination.graph); |
989 | 74.4k | assert(source.d < graph->exec_symbol_info->rnum); |
990 | 74.4k | assert(destination.d < graph->exec_symbol_info->rnum); |
991 | 74.4k | ccv_nnc_graph_exec_symbol_info_t* src_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, source.d); |
992 | 74.4k | if (!src_symbol_info->outgoings) |
993 | 47.2k | src_symbol_info->outgoings = ccv_array_new(sizeof(int32_t), 1, 0); |
994 | 27.2k | else { |
995 | 27.2k | int i; |
996 | | // Check if this is already connected, if so, skip. |
997 | 46.1k | for (i = 0; i < src_symbol_info->outgoings->rnum; i++18.8k ) |
998 | 30.6k | if (*(int*)ccv_array_get(src_symbol_info->outgoings, i) == destination.d) |
999 | 11.7k | return -1; |
1000 | 27.2k | } |
1001 | 62.7k | ccv_array_push(src_symbol_info->outgoings, &destination.d); |
1002 | 62.7k | return 0; |
1003 | 74.4k | } |
1004 | | |
1005 | | void ccv_nnc_graph_exec_symbol_io(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t symbol, const int** const inputs, int* const input_size, const int** const outputs, int* const output_size) |
1006 | 87.0k | { |
1007 | 87.0k | assert(graph == symbol.graph); |
1008 | 87.0k | assert(symbol.d < graph->exec_symbol_info->rnum); |
1009 | 87.0k | const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
1010 | 87.0k | if (inputs) |
1011 | 61.6k | *inputs = symbol_info->inputs; |
1012 | 87.0k | if (input_size) |
1013 | 69.5k | *input_size = symbol_info->input_size; |
1014 | 87.0k | if (outputs) |
1015 | 74.4k | *outputs = symbol_info->outputs; |
1016 | 87.0k | if (output_size) |
1017 | 82.3k | *output_size = symbol_info->output_size; |
1018 | 87.0k | } |
1019 | | |
1020 | | void ccv_nnc_graph_exec_symbol_replace_io(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_tensor_symbol_t old_symbol, const ccv_nnc_tensor_symbol_t new_symbol) |
1021 | 4 | { |
1022 | 4 | assert(graph == symbol.graph); |
1023 | 4 | assert(symbol.d < graph->exec_symbol_info->rnum); |
1024 | 4 | assert(graph == old_symbol.graph); |
1025 | 4 | assert(old_symbol.d < graph->tensor_symbol_info->rnum); |
1026 | 4 | assert(graph == new_symbol.graph); |
1027 | 4 | assert(new_symbol.d < graph->tensor_symbol_info->rnum); |
1028 | 4 | const ccv_nnc_tensor_symbol_info_t* const old_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, old_symbol.d); |
1029 | 4 | const ccv_nnc_tensor_symbol_info_t* const new_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, new_symbol.d); |
1030 | 4 | if (old_tensor_info != new_tensor_info) |
1031 | 4 | { |
1032 | | // These need to be the same, otherwise we need to find the backend again for this exec. See _ccv_nnc_graph_exec_symbol_set_io |
1033 | 4 | assert(ccv_nnc_is_tensor_auto(old_tensor_info->info) == ccv_nnc_is_tensor_auto(new_tensor_info->info)); |
1034 | 4 | assert(old_tensor_info->info.type == new_tensor_info->info.type); |
1035 | 4 | assert(old_tensor_info->info.format == new_tensor_info->info.format); |
1036 | 4 | assert(old_tensor_info->info.datatype == new_tensor_info->info.datatype); |
1037 | 4 | } |
1038 | 4 | const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
1039 | 4 | int i; |
1040 | 12 | for (i = 0; i < symbol_info->input_size; i++8 ) |
1041 | 8 | if (symbol_info->inputs[i] == old_symbol.d) |
1042 | 4 | symbol_info->inputs[i] = new_symbol.d; |
1043 | 8 | for (i = 0; i < symbol_info->output_size; i++4 ) |
1044 | 4 | if (symbol_info->outputs[i] == old_symbol.d) |
1045 | 0 | symbol_info->outputs[i] = new_symbol.d; |
1046 | 4 | } |
1047 | | |
1048 | | void ccv_nnc_graph_exec_symbol_to(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t symbol, const int** const tos, int* const to_size) |
1049 | 37.6k | { |
1050 | 37.6k | assert(graph == symbol.graph); |
1051 | 37.6k | assert(symbol.d < graph->exec_symbol_info->rnum); |
1052 | 37.6k | assert(tos); |
1053 | 37.6k | assert(to_size); |
1054 | 37.6k | const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
1055 | 37.6k | if (!symbol_info->outgoings) |
1056 | 4.51k | { |
1057 | 4.51k | *tos = 0; |
1058 | 4.51k | *to_size = 0; |
1059 | 4.51k | return; |
1060 | 4.51k | } |
1061 | 33.1k | *to_size = symbol_info->outgoings->rnum; |
1062 | 33.1k | *tos = (int*)ccv_array_get(symbol_info->outgoings, 0); |
1063 | 33.1k | } |
1064 | | |
1065 | | int ccv_nnc_graph_exec_symbol_count(const ccv_nnc_symbolic_graph_t* const graph) |
1066 | 22.0k | { |
1067 | 22.0k | return graph->exec_symbol_info->rnum; |
1068 | 22.0k | } |
1069 | | |
1070 | | int ccv_nnc_symbolic_graph_active_symbol_count(const ccv_nnc_symbolic_graph_t* const graph, const int type) |
1071 | 451 | { |
1072 | 451 | assert(type == CCV_NNC_SYMBOL_TENSOR || type == CCV_NNC_SYMBOL_GRAPH_EXEC); |
1073 | 451 | if (type == CCV_NNC_SYMBOL_GRAPH_EXEC) |
1074 | 422 | { |
1075 | 422 | int i, count = graph->exec_symbol_info->rnum; |
1076 | 864 | for (i = 0; i < graph->exec_symbol_info->rnum; i++442 ) |
1077 | 442 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i))->flags)) |
1078 | 416 | --count; |
1079 | 422 | return count; |
1080 | 422 | } else if (29 type == CCV_NNC_SYMBOL_TENSOR29 ) { |
1081 | 29 | int i, count = graph->tensor_symbol_info->rnum; |
1082 | 134 | for (i = 0; i < graph->tensor_symbol_info->rnum; i++105 ) |
1083 | 105 | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i))->flags)) |
1084 | 54 | --count; |
1085 | 29 | return count; |
1086 | 29 | } |
1087 | 0 | return 0; |
1088 | 451 | } |
1089 | | |
1090 | | int ccv_nnc_tensor_symbol_count(const ccv_nnc_symbolic_graph_t* const graph) |
1091 | 84 | { |
1092 | 84 | return graph->tensor_symbol_info->rnum; |
1093 | 84 | } |
1094 | | |
1095 | | static inline void _ccv_nnc_graph_exec_symbol_free(ccv_nnc_graph_exec_symbol_info_t* const symbol_info, const int zeroing) |
1096 | 56.7k | { |
1097 | 56.7k | if (symbol_info->name) |
1098 | 4.52k | ccfree(symbol_info->name); |
1099 | 56.7k | if (symbol_info->_heap_graph_ref) |
1100 | 7 | ccfree(symbol_info->_heap_graph_ref); |
1101 | 56.7k | ccv_array_t* outgoings = symbol_info->outgoings; |
1102 | 56.7k | if (outgoings) |
1103 | 47.2k | ccv_array_free(outgoings); |
1104 | | // We allocate inputs & outputs in continuous fashion, therefore, only need to free the input array. |
1105 | 56.7k | if (symbol_info->inputs) |
1106 | 56.6k | ccfree(symbol_info->inputs); |
1107 | 56.7k | if (symbol_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) |
1108 | 24 | if (symbol_info->p_while.inputs) |
1109 | 19 | ccfree(symbol_info->p_while.inputs); |
1110 | 56.7k | if (zeroing) |
1111 | 43.5k | { |
1112 | 43.5k | symbol_info->name = 0; |
1113 | 43.5k | symbol_info->_heap_graph_ref = 0; |
1114 | 43.5k | symbol_info->outgoings = 0; |
1115 | 43.5k | symbol_info->inputs = 0; |
1116 | 43.5k | symbol_info->input_size = 0; |
1117 | 43.5k | symbol_info->outputs = 0; |
1118 | 43.5k | symbol_info->output_size = 0; |
1119 | 43.5k | } |
1120 | 56.7k | } |
1121 | | |
1122 | | void ccv_nnc_graph_exec_symbol_free(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t symbol) |
1123 | 43.5k | { |
1124 | 43.5k | assert(graph == symbol.graph); |
1125 | 43.5k | assert(symbol.d < graph->exec_symbol_info->rnum); |
1126 | | // If any of the exec symbol have reference to it, has to remove that. |
1127 | 43.5k | int i, j, k; |
1128 | 43.5k | ccv_nnc_graph_exec_symbol_info_t* const free_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
1129 | 334k | for (i = 0; i < graph->exec_symbol_info->rnum; i++290k ) |
1130 | 290k | if (i != symbol.d) |
1131 | 247k | { |
1132 | 247k | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1133 | 247k | if (symbol_info->outgoings) |
1134 | 349k | for (j = 0; 171k j < symbol_info->outgoings->rnum; j++177k ) |
1135 | 227k | if (*(int*)ccv_array_get(symbol_info->outgoings, j) == symbol.d) |
1136 | 49.4k | { |
1137 | 49.4k | if (j < symbol_info->outgoings->rnum - 1) |
1138 | 65 | *(int*)ccv_array_get(symbol_info->outgoings, j) = *(int*)ccv_array_get(symbol_info->outgoings, symbol_info->outgoings->rnum - 1); |
1139 | 49.4k | --symbol_info->outgoings->rnum; |
1140 | 49.4k | if (free_symbol_info->outgoings) |
1141 | 61.8k | for (k = 0; 35.5k k < free_symbol_info->outgoings->rnum; k++26.3k ) |
1142 | 26.3k | ccv_array_add_unique_int(symbol_info->outgoings, *(int*)ccv_array_get(free_symbol_info->outgoings, k)); |
1143 | 49.4k | break; |
1144 | 49.4k | } |
1145 | 247k | } |
1146 | | // Deallocate any memory for exec symbol. |
1147 | 43.5k | _ccv_nnc_graph_exec_symbol_free(free_symbol_info, 1); |
1148 | 43.5k | free_symbol_info->flags = CCV_NNC_GRAPH_EXEC_DEAD; // Mark this as dead. |
1149 | | // If everything from symbol.d to the end of the graph is dead, we can reclaim this memory. |
1150 | 87.0k | for (i = graph->exec_symbol_info->rnum - 1; i >= 0; i--43.5k ) |
1151 | 82.2k | if (!CCV_NNC_GRAPH_EXEC_IS_DEAD(((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i))->flags)) |
1152 | 38.6k | { |
1153 | 38.6k | graph->exec_symbol_info->rnum = i + 1; |
1154 | 38.6k | break; |
1155 | 38.6k | } |
1156 | | // Loop over sources and destinations to remove this. |
1157 | 43.5k | if (graph->sources) |
1158 | 114 | for (i = 0; 42 i < graph->sources->rnum; i++72 ) |
1159 | 80 | if (*(int*)ccv_array_get(graph->sources, i) == symbol.d) |
1160 | 8 | { |
1161 | 8 | if (i < graph->sources->rnum - 1) |
1162 | 1 | *(int*)ccv_array_get(graph->sources, i) = *(int*)ccv_array_get(graph->sources, graph->sources->rnum - 1); |
1163 | 8 | --graph->sources->rnum; |
1164 | 8 | break; |
1165 | 8 | } |
1166 | 43.5k | if (graph->destinations) |
1167 | 85 | for (i = 0; 42 i < graph->destinations->rnum; i++43 ) |
1168 | 54 | if (*(int*)ccv_array_get(graph->destinations, i) == symbol.d) |
1169 | 11 | { |
1170 | 11 | if (i < graph->destinations->rnum - 1) |
1171 | 4 | *(int*)ccv_array_get(graph->destinations, i) = *(int*)ccv_array_get(graph->destinations, graph->destinations->rnum - 1); |
1172 | 11 | --graph->destinations->rnum; |
1173 | 11 | break; |
1174 | 11 | } |
1175 | 43.5k | if (symbol.d < graph->exec_symbol_info->rnum && |
1176 | 43.5k | (27.8k symbol.d < graph->reuse.exec27.8k || graph->reuse.exec < 027.8k )) |
1177 | 9.28k | graph->reuse.exec = symbol.d; |
1178 | 34.2k | else if (graph->reuse.exec >= graph->exec_symbol_info->rnum) |
1179 | 4.43k | graph->reuse.exec = -1; |
1180 | 43.5k | } |
1181 | | |
1182 | | int ccv_nnc_graph_exec_symbol_disjoin(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t source, const ccv_nnc_graph_exec_symbol_t destination) |
1183 | 10 | { |
1184 | 10 | assert(graph == source.graph); |
1185 | 10 | assert(graph == destination.graph); |
1186 | 10 | assert(source.d < graph->exec_symbol_info->rnum); |
1187 | 10 | assert(destination.d < graph->exec_symbol_info->rnum); |
1188 | 10 | ccv_nnc_graph_exec_symbol_info_t* src_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, source.d); |
1189 | 10 | if (!src_symbol_info->outgoings) |
1190 | 0 | return -1; |
1191 | 10 | int i; |
1192 | | // Check if this is already disjoined, if so, skip. |
1193 | 10 | for (i = 0; i < src_symbol_info->outgoings->rnum; i++0 ) |
1194 | 10 | if (*(int*)ccv_array_get(src_symbol_info->outgoings, i) == destination.d) |
1195 | 10 | { |
1196 | 10 | if (i < src_symbol_info->outgoings->rnum - 1) |
1197 | 1 | *(int*)ccv_array_get(src_symbol_info->outgoings, i) = *(int*)ccv_array_get(src_symbol_info->outgoings, src_symbol_info->outgoings->rnum - 1); |
1198 | 10 | --src_symbol_info->outgoings->rnum; |
1199 | 10 | return 0; |
1200 | 10 | } |
1201 | 0 | return -1; |
1202 | 10 | } |
1203 | | |
// Bit tests on an autogen flags value; each expands to the masked bits (non-zero when the flag is set).
#define CCV_NNC_IS_AUTOGEN_ALL_EXECS(x) ((x) & CCV_NNC_AUTOGEN_ALL_EXECS)
#define CCV_NNC_IS_AUTOGEN_SOURCES_AND_DESTINATIONS(x) ((x) & CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS)
1206 | | |
1207 | | int ccv_nnc_over_tensor_symbol_aliases(const ccv_nnc_tensor_symbol_info_t* const tensor_a, const ccv_nnc_tensor_symbol_info_t* const tensor_b) |
1208 | 84 | { |
1209 | 84 | int i; |
1210 | 84 | const int* stride = tensor_a->stride; |
1211 | | // Only can compare if the stride is the same, otherwise, we can only assume it overlaps. |
1212 | 84 | if (memcmp(stride, tensor_b->stride, sizeof(int) * CCV_NNC_MAX_DIM_ALLOC) != 0) |
1213 | 0 | return 1; |
1214 | 84 | const int* ofs = tensor_a->ofs; |
1215 | 84 | const int* dim = tensor_a->info.dim; |
1216 | 200 | for (i = 0; i < CCV_NNC_MAX_DIM_ALLOC && dim[i] && tensor_b->info.dim[i]147 ; i++116 ) |
1217 | 147 | if (ccv_min(ofs[i] + dim[i], tensor_b->ofs[i] + tensor_b->info.dim[i]) <= ccv_max(ofs[i], tensor_b->ofs[i])) |
1218 | 31 | return 0; // Cannot overlap. |
1219 | 53 | return 1; |
1220 | 84 | } |
1221 | | |
1222 | | int ccv_nnc_graph_exec_symbol_autogen(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const execs, const int exec_size, const int flags) |
1223 | 11.8k | { |
1224 | 11.8k | int i, j, x, y; |
1225 | 11.9k | for (i = 0; i < exec_size; i++83 ) |
1226 | 83 | if (execs[i].graph == graph) |
1227 | 83 | { |
1228 | 83 | assert(execs[i].d >= 0); |
1229 | 83 | assert(execs[i].d < graph->exec_symbol_info->rnum); |
1230 | 83 | } |
1231 | 11.8k | if (!CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) && exec_size4.63k ) |
1232 | 16 | { assert(execs); } |
1233 | 11.8k | const int exec_total_size = CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) ? graph->exec_symbol_info->rnum7.18k : exec_size4.63k ; |
1234 | 30.6k | for (i = 0; i < exec_total_size; i++18.8k ) |
1235 | 18.8k | { |
1236 | 18.8k | if (!CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) && execs[i].graph != graph83 ) |
1237 | 0 | continue; |
1238 | 18.8k | int idx = CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) ? i18.7k : execs[i].d83 ; |
1239 | 18.8k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, idx); |
1240 | 18.8k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(symbol_info->flags)) |
1241 | 5 | continue; |
1242 | | // Autogen for sub-graphs. |
1243 | 18.8k | if (CCV_NNC_GRAPH_REF(symbol_info)[0]) |
1244 | 26 | ccv_nnc_graph_exec_symbol_autogen(*(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, CCV_NNC_GRAPH_REF(symbol_info)[0] - 1), execs, exec_size, flags); |
1245 | 18.8k | } |
1246 | 30.6k | for (i = 0; i < exec_total_size; i++18.8k ) |
1247 | 18.8k | { |
1248 | 18.8k | if (!CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) && execs[i].graph != graph83 ) |
1249 | 0 | continue; |
1250 | 18.8k | int a_idx = CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) ? i18.7k : execs[i].d83 ; |
1251 | 18.8k | ccv_nnc_graph_exec_symbol_info_t* a_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, a_idx); |
1252 | 18.8k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(a_symbol_info->flags)) |
1253 | 5 | continue; |
1254 | 124k | for (j = i + 1; 18.8k j < exec_total_size; j++105k ) |
1255 | 105k | { |
1256 | 105k | if (!CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) && execs[j].graph != graph269 ) |
1257 | 0 | continue; |
1258 | 105k | int b_idx = CCV_NNC_IS_AUTOGEN_ALL_EXECS(flags) ? j105k : execs[j].d269 ; |
1259 | | // Skip if they are the same. |
1260 | 105k | if (a_idx == b_idx) |
1261 | 0 | continue; |
1262 | 105k | ccv_nnc_graph_exec_symbol_info_t* b_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, b_idx); |
1263 | 105k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(b_symbol_info->flags)) |
1264 | 9 | continue; |
1265 | 105k | int b_to_a = 0; |
1266 | 393k | for (x = 0; x < a_symbol_info->input_size && !b_to_a289k ; x++287k ) |
1267 | 287k | { |
1268 | 287k | int a = a_symbol_info->inputs[x]; |
1269 | 287k | if (a < 0) |
1270 | 37.5k | continue; |
1271 | | // Handle alias as well. |
1272 | 250k | ccv_nnc_tensor_symbol_info_t* a_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, a); |
1273 | 250k | if (a_tensor_info->alias_ref) |
1274 | 17.0k | a = a_tensor_info->alias_ref - 1; |
1275 | 700k | for (y = 0; y < b_symbol_info->output_size && !b_to_a450k ; y++450k ) |
1276 | 450k | { |
1277 | 450k | int b = b_symbol_info->outputs[y]; |
1278 | 450k | if (b < 0) |
1279 | 7.91k | continue; |
1280 | 442k | ccv_nnc_tensor_symbol_info_t* b_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, b); |
1281 | 442k | if (b_tensor_info->alias_ref) |
1282 | 14.1k | b = b_tensor_info->alias_ref - 1; |
1283 | 442k | if (a == b && // This two have matching inputs and outputs. |
1284 | 442k | (1.45k !a_tensor_info->alias_ref1.45k || |
1285 | 1.45k | !b_tensor_info->alias_ref7 || // If any of them are not alias, the must overlap, you can concatenate. |
1286 | 1.45k | ccv_nnc_over_tensor_symbol_aliases(a_tensor_info, b_tensor_info)7 )) // Otherwise, we explicitly check whether it overlaps, if it does, concatenate. |
1287 | 1.44k | b_to_a = 1; |
1288 | 442k | } |
1289 | 250k | } |
1290 | 105k | if (b_to_a) |
1291 | 1.44k | { |
1292 | 1.44k | if (execs) |
1293 | 0 | ccv_nnc_graph_exec_symbol_concat(graph, execs[j], execs[i]); |
1294 | 1.44k | else |
1295 | 1.44k | ccv_nnc_graph_exec_symbol_concat(graph, |
1296 | 1.44k | (ccv_nnc_graph_exec_symbol_t) { |
1297 | 1.44k | .d = j, |
1298 | 1.44k | .graph = graph |
1299 | 1.44k | }, (ccv_nnc_graph_exec_symbol_t) { |
1300 | 1.44k | .d = i, |
1301 | 1.44k | .graph = graph |
1302 | 1.44k | } |
1303 | 1.44k | ); |
1304 | 1.44k | } |
1305 | 105k | int a_to_b = 0; |
1306 | 290k | for (x = 0; x < a_symbol_info->output_size && !a_to_b185k ; x++184k ) |
1307 | 184k | { |
1308 | 184k | int a = a_symbol_info->outputs[x]; |
1309 | 184k | if (a < 0) |
1310 | 3.28k | continue; |
1311 | | // Handle alias as well. |
1312 | 181k | ccv_nnc_tensor_symbol_info_t* a_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, a); |
1313 | 181k | if (a_tensor_info->alias_ref) |
1314 | 5.27k | a = a_tensor_info->alias_ref - 1; |
1315 | 709k | for (y = 0; y < b_symbol_info->input_size && !a_to_b532k ; y++528k ) |
1316 | 528k | { |
1317 | 528k | int b = b_symbol_info->inputs[y]; |
1318 | 528k | if (b < 0) |
1319 | 73.2k | continue; |
1320 | 455k | ccv_nnc_tensor_symbol_info_t* b_tensor_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, b); |
1321 | 455k | if (b_tensor_info->alias_ref) |
1322 | 13.5k | b = b_tensor_info->alias_ref - 1; |
1323 | 455k | if (a == b && // This two have matching inputs and outputs. |
1324 | 455k | (8.47k !a_tensor_info->alias_ref8.47k || |
1325 | 8.47k | !b_tensor_info->alias_ref59 || // If any of them are not alias, the must overlap, you can concatenate. |
1326 | 8.47k | ccv_nnc_over_tensor_symbol_aliases(a_tensor_info, b_tensor_info)23 )) // Otherwise, we explicitly check whether it overlaps, if it does, concatenate. |
1327 | 8.46k | a_to_b = 1; |
1328 | 455k | } |
1329 | 181k | } |
1330 | 105k | if (a_to_b) |
1331 | 8.46k | { |
1332 | 8.46k | if (execs) |
1333 | 79 | ccv_nnc_graph_exec_symbol_concat(graph, execs[i], execs[j]); |
1334 | 8.38k | else |
1335 | 8.38k | ccv_nnc_graph_exec_symbol_concat(graph, |
1336 | 8.38k | (ccv_nnc_graph_exec_symbol_t) { |
1337 | 8.38k | .d = i, |
1338 | 8.38k | .graph = graph |
1339 | 8.38k | }, (ccv_nnc_graph_exec_symbol_t) { |
1340 | 8.38k | .d = j, |
1341 | 8.38k | .graph = graph |
1342 | 8.38k | } |
1343 | 8.38k | ); |
1344 | 8.46k | } |
1345 | 105k | } |
1346 | 18.8k | } |
1347 | | // If flag says so, loop over to find sources / destinations too. |
1348 | 11.8k | if (CCV_NNC_IS_AUTOGEN_SOURCES_AND_DESTINATIONS(flags)) |
1349 | 9.56k | { |
1350 | 9.56k | uint8_t* flags = (uint8_t*)cccalloc(sizeof(uint8_t), graph->exec_symbol_info->rnum); |
1351 | 28.7k | for (i = 0; i < graph->exec_symbol_info->rnum; i++19.1k ) |
1352 | 19.1k | { |
1353 | 19.1k | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1354 | 19.1k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(symbol_info->flags)) |
1355 | 19 | { |
1356 | 19 | flags[i] = 3; // Skip. |
1357 | 19 | continue; |
1358 | 19 | } |
1359 | 19.1k | if (symbol_info->outgoings && symbol_info->outgoings->rnum8.94k ) |
1360 | 8.93k | { |
1361 | 8.93k | flags[i] |= 2; |
1362 | 20.4k | for (j = 0; j < symbol_info->outgoings->rnum; j++11.5k ) |
1363 | 11.5k | flags[*(int*)ccv_array_get(symbol_info->outgoings, j)] |= 1; |
1364 | 8.93k | } |
1365 | 19.1k | } |
1366 | 9.56k | if (!graph->sources) |
1367 | 2.55k | graph->sources = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1368 | 7.01k | else |
1369 | 7.01k | ccv_array_clear(graph->sources); |
1370 | 9.56k | if (!graph->destinations) |
1371 | 2.55k | graph->destinations = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1372 | 7.01k | else |
1373 | 7.01k | ccv_array_clear(graph->destinations); |
1374 | 28.7k | for (i = 0; i < graph->exec_symbol_info->rnum; i++19.1k ) |
1375 | 19.1k | { |
1376 | 19.1k | if (flags[i] == 3) |
1377 | 4.20k | continue; |
1378 | 14.9k | ccv_nnc_graph_exec_symbol_t exec = { |
1379 | 14.9k | .d = i, |
1380 | 14.9k | .graph = graph, |
1381 | 14.9k | }; |
1382 | 14.9k | if (!(flags[i] & 1)) |
1383 | 9.84k | ccv_array_push(graph->sources, &exec); |
1384 | 14.9k | if (!(flags[i] & 2)) |
1385 | 10.1k | ccv_array_push(graph->destinations, &exec); |
1386 | 14.9k | } |
1387 | 9.56k | ccfree(flags); |
1388 | 9.56k | } |
1389 | 11.8k | return 0; |
1390 | 11.8k | } |
1391 | | |
1392 | | ccv_nnc_graph_exec_symbol_t* ccv_nnc_symbolic_graph_sources(const ccv_nnc_symbolic_graph_t* const graph) |
1393 | 7.32k | { |
1394 | 7.32k | return graph->sources ? (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->sources, 0) : 00 ; |
1395 | 7.32k | } |
1396 | | |
1397 | | void ccv_nnc_symbolic_graph_add_source(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t source) |
1398 | 17 | { |
1399 | 17 | if (!graph->sources) |
1400 | 0 | graph->sources = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1401 | 17 | assert(source.graph == graph); |
1402 | 17 | ccv_array_push(graph->sources, &source); |
1403 | 17 | } |
1404 | | |
1405 | | void ccv_nnc_symbolic_graph_set_sources(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size) |
1406 | 15 | { |
1407 | 15 | if (!graph->sources) |
1408 | 11 | graph->sources = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1409 | 4 | else |
1410 | 4 | ccv_array_clear(graph->sources); |
1411 | 15 | int i; |
1412 | 30 | for (i = 0; i < source_size; i++15 ) |
1413 | 15 | ccv_nnc_symbolic_graph_add_source(graph, sources[i]); |
1414 | 15 | } |
1415 | | |
1416 | | int ccv_nnc_symbolic_graph_source_size(const ccv_nnc_symbolic_graph_t* const graph) |
1417 | 7.32k | { |
1418 | 7.32k | return graph->sources ? graph->sources->rnum : 00 ; |
1419 | 7.32k | } |
1420 | | |
1421 | | ccv_nnc_graph_exec_symbol_t* ccv_nnc_symbolic_graph_destinations(const ccv_nnc_symbolic_graph_t* const graph) |
1422 | 9.54k | { |
1423 | 9.54k | return graph->destinations ? (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->destinations, 0) : 00 ; |
1424 | 9.54k | } |
1425 | | |
1426 | | void ccv_nnc_symbolic_graph_add_destination(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t destination) |
1427 | 2.69k | { |
1428 | 2.69k | if (!graph->destinations) |
1429 | 0 | graph->destinations = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1430 | 2.69k | assert(destination.graph == graph); |
1431 | 2.69k | ccv_array_push(graph->destinations, &destination); |
1432 | 2.69k | } |
1433 | | |
1434 | | void ccv_nnc_symbolic_graph_set_destinations(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size) |
1435 | 2.25k | { |
1436 | 2.25k | if (!graph->destinations) |
1437 | 11 | graph->destinations = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), 0, 0); |
1438 | 2.24k | else |
1439 | 2.24k | ccv_array_clear(graph->destinations); |
1440 | 2.25k | int i; |
1441 | 4.93k | for (i = 0; i < destination_size; i++2.67k ) |
1442 | 2.67k | if (destinations[i].d >= 0) |
1443 | 2.66k | ccv_nnc_symbolic_graph_add_destination(graph, destinations[i]); |
1444 | 2.25k | } |
1445 | | |
1446 | | int ccv_nnc_symbolic_graph_destination_size(const ccv_nnc_symbolic_graph_t* const graph) |
1447 | 9.55k | { |
1448 | 9.55k | return graph->destinations ? graph->destinations->rnum : 00 ; |
1449 | 9.55k | } |
1450 | | |
1451 | | static void _ccv_nnc_symbolic_graph_dot_exec_symbol(const int index, const ccv_nnc_graph_exec_symbol_info_t* const symbol_info, const int flags, FILE* out) |
1452 | 1.28k | { |
1453 | 1.28k | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1454 | 1.25k | fputc('{', out); |
1455 | 1.28k | if (symbol_info->name) |
1456 | 689 | fputs(symbol_info->name, out); |
1457 | 594 | else |
1458 | 594 | fprintf(out, "node%d", index); |
1459 | 1.28k | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1460 | 1.25k | { |
1461 | 1.25k | fputs("|Command: ", out); |
1462 | 1.25k | fputs(ccv_nnc_cmd_name(symbol_info->cmd.cmd), out); |
1463 | 1.25k | fputc('}', out); |
1464 | 1.25k | } |
1465 | 1.28k | } |
1466 | | |
1467 | | static void _ccv_nnc_symbolic_graph_dot_tensor_symbol(const int index, const ccv_nnc_tensor_symbol_info_t* const symbol_info, const ccv_nnc_tensor_symbol_info_t* const alias_info, const int html_like, const int flags, FILE* out) |
1468 | 4.00k | { |
1469 | | // if it has an alias pointer, or, it is a long form. |
1470 | 4.00k | if ((flags == CCV_NNC_LONG_DOT_GRAPH || alias_info79 ) && !html_like3.95k ) |
1471 | 3.87k | fputc('{', out); |
1472 | 4.00k | if (symbol_info->name) |
1473 | 1.95k | fputs(symbol_info->name, out); |
1474 | 2.04k | else |
1475 | 2.04k | fprintf(out, "tensor%d", index); |
1476 | 4.00k | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1477 | 3.92k | { |
1478 | 3.92k | int flag = -1; |
1479 | 3.92k | if (symbol_info->flags & CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS) |
1480 | 16 | flag = fputs(" (0", out); // Output if it is zero init'ed. |
1481 | 3.91k | else if (symbol_info->flags & CCV_NNC_TENSOR_SYMBOL_INIT_ONES) |
1482 | 8 | flag = fputs(" (1", out); // Output if it is one init'ed. |
1483 | 3.92k | if (symbol_info->flags & CCV_NNC_TENSOR_SYMBOL_TAPE_VAR) |
1484 | 16 | flag = (flag >= 0) ? fputs(",t", out)0 : fputs(" (t", out); // Output is a tape variable |
1485 | 3.92k | if (CCV_TENSOR_GET_MEMORY(symbol_info->info.type) == CCV_TENSOR_GPU_MEMORY && |
1486 | 3.92k | CCV_TENSOR_GET_DEVICE1.06k (symbol_info->info.type) != CCV_COMPUTE_DEVICE_ANY1.06k ) |
1487 | 1.06k | flag = (flag >= 0) ? fprintf(out, ",d%d", 8 CCV_TENSOR_GET_DEVICE_ID8 (symbol_info->info.type)) : fprintf(out, " (d%d", 1.06k CCV_TENSOR_GET_DEVICE_ID1.06k (symbol_info->info.type)); |
1488 | 3.92k | if (flag >= 0) |
1489 | 1.10k | fputs(")", out); |
1490 | 3.92k | } |
1491 | 4.00k | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1492 | 3.92k | { |
1493 | 3.92k | int i; |
1494 | 3.92k | if (html_like) |
1495 | 86 | fprintf(out, "</td><td>%d", symbol_info->info.dim[0]); |
1496 | 3.84k | else |
1497 | 3.84k | fprintf(out, "|%d", symbol_info->info.dim[0]); |
1498 | 9.07k | for (i = 1; i < CCV_NNC_MAX_DIM_ALLOC && symbol_info->info.dim[i]; i++5.14k ) |
1499 | 5.14k | fprintf(out, "x%d", symbol_info->info.dim[i]); |
1500 | 3.92k | } |
1501 | 4.00k | if (alias_info) |
1502 | 198 | { |
1503 | 198 | if (html_like) |
1504 | 0 | fputs("</td><td border=\"0\">as. ", out); |
1505 | 198 | else |
1506 | 198 | fputs("|as. ", out); |
1507 | 198 | if (alias_info->name) |
1508 | 107 | fputs(alias_info->name, out); |
1509 | 91 | else |
1510 | 91 | fprintf(out, "tensor%d", symbol_info->alias_ref - 1); |
1511 | 198 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1512 | 166 | { |
1513 | 166 | int flag = -1; |
1514 | 166 | if (alias_info->flags & CCV_NNC_TENSOR_SYMBOL_INIT_ZEROS) |
1515 | 7 | flag = fputs(" (0", out); // Output if it is zero init'ed. |
1516 | 159 | else if (alias_info->flags & CCV_NNC_TENSOR_SYMBOL_INIT_ONES) |
1517 | 0 | flag = fputs(" (1", out); // Output if it is one init'ed. |
1518 | 166 | if (alias_info->flags & CCV_NNC_TENSOR_SYMBOL_TAPE_VAR) |
1519 | 0 | flag = (flag >= 0) ? fputs(",t", out) : fputs(" (t", out); // Output is a tape variable |
1520 | 166 | if (CCV_TENSOR_GET_MEMORY(alias_info->info.type) == CCV_TENSOR_GPU_MEMORY && |
1521 | 166 | CCV_TENSOR_GET_DEVICE12 (alias_info->info.type) != CCV_COMPUTE_DEVICE_ANY12 ) |
1522 | 12 | flag = (flag >= 0) ? fprintf(out, ",d%d", 0 CCV_TENSOR_GET_DEVICE_ID0 (alias_info->info.type)) : fprintf(out, " (d%d", CCV_TENSOR_GET_DEVICE_ID(alias_info->info.type)); |
1523 | 166 | if (flag >= 0) |
1524 | 19 | fputs(")", out); |
1525 | 166 | } |
1526 | 198 | } |
1527 | 4.00k | if ((flags == CCV_NNC_LONG_DOT_GRAPH || alias_info79 ) && !html_like3.95k ) |
1528 | 3.87k | fputc('}', out); |
1529 | 4.00k | } |
1530 | | |
1531 | | static void _ccv_nnc_symbolic_graph_dot_node(const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const int index, const ccv_array_t* const tensor_symbol_info, const int flags, FILE* out) |
1532 | 1.28k | { |
1533 | 1.28k | fprintf(out, "node%d [shape=record,label=\"", index); |
1534 | 1.28k | _ccv_nnc_symbolic_graph_dot_exec_symbol(index, exec_symbol_info, flags, out); |
1535 | 1.28k | int i; |
1536 | 1.28k | if (exec_symbol_info->input_size > 0) |
1537 | 1.20k | { |
1538 | 1.20k | fputs("|{Input", out); |
1539 | 4.12k | for (i = 0; i < exec_symbol_info->input_size; i++2.91k ) |
1540 | 2.91k | { |
1541 | 2.91k | if (exec_symbol_info->inputs[i] >= 0) |
1542 | 2.36k | { |
1543 | 2.36k | fputc('|', out); |
1544 | 2.36k | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->inputs[i]); |
1545 | 2.36k | const ccv_nnc_tensor_symbol_info_t* const alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)121 ccv_array_get121 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 02.23k ; |
1546 | 2.36k | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->inputs[i], tensor_symbol, alias_symbol, 0, flags, out); |
1547 | 2.36k | } else |
1548 | 552 | fputs("|-", out); |
1549 | 2.91k | } |
1550 | 1.20k | fputc('}', out); |
1551 | 1.20k | } |
1552 | 1.28k | if (exec_symbol_info->output_size > 0) |
1553 | 1.26k | { |
1554 | 1.26k | fputs("|{Output", out); |
1555 | 2.90k | for (i = 0; i < exec_symbol_info->output_size; i++1.64k ) |
1556 | 1.64k | { |
1557 | 1.64k | if (exec_symbol_info->outputs[i] >= 0) |
1558 | 1.55k | { |
1559 | 1.55k | fputc('|', out); |
1560 | 1.55k | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->outputs[i]); |
1561 | 1.55k | const ccv_nnc_tensor_symbol_info_t* const alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)77 ccv_array_get77 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 01.48k ; |
1562 | 1.55k | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->outputs[i], tensor_symbol, alias_symbol, 0, flags, out); |
1563 | 1.55k | } else |
1564 | 82 | fputs("|-", out); |
1565 | 1.64k | } |
1566 | 1.26k | fputc('}', out); |
1567 | 1.26k | } |
1568 | 1.28k | fputs("\"];\n", out); |
1569 | 1.28k | } |
1570 | | |
1571 | | static void _ccv_nnc_symbolic_graph_dot_while_label(const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const int index, const ccv_array_t* const tensor_symbol_info, const ccv_nnc_symbolic_graph_t* const while_graph, const int flags, FILE* out) |
1572 | 21 | { |
1573 | 21 | int i; |
1574 | 21 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1575 | 21 | fputs("<table border=\"0\" cellborder=\"1\" cellspacing=\"0\"><tr><td colspan=\"3\" border=\"0\"><b>", out); |
1576 | 0 | else |
1577 | 0 | fputs("<table border=\"0\" cellborder=\"1\" cellspacing=\"0\"><tr><td colspan=\"2\" border=\"0\"><b>", out); |
1578 | 21 | if (exec_symbol_info->name) |
1579 | 21 | fputs(exec_symbol_info->name, out); |
1580 | 0 | else |
1581 | 0 | fprintf(out, "while%d", index); |
1582 | 21 | fputs(" </b>Command: ", out); |
1583 | 21 | fputs(ccv_nnc_cmd_name(exec_symbol_info->cmd.cmd), out); |
1584 | 21 | fputs("</td></tr>", out); |
1585 | 21 | const int p_idx = while_graph->p_idx - 1; |
1586 | 21 | assert(p_idx >= 0); |
1587 | 21 | if (exec_symbol_info->input_size > 0) |
1588 | 16 | { |
1589 | 16 | fprintf(out, "<tr><td rowspan=\"%d\">Input</td>", exec_symbol_info->input_size); |
1590 | 39 | for (i = 0; i < exec_symbol_info->input_size; i++23 ) |
1591 | 23 | { |
1592 | 23 | if (i > 0) |
1593 | 7 | fputs("<tr>", out); |
1594 | 23 | if (exec_symbol_info->inputs[i] >= 0) |
1595 | 23 | { |
1596 | 23 | fputs("<td>", out); |
1597 | 23 | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->inputs[i]); |
1598 | 23 | const ccv_nnc_tensor_symbol_info_t* const alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)0 ccv_array_get0 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 0; |
1599 | 23 | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->inputs[i], tensor_symbol, alias_symbol, 1, flags, out); |
1600 | 23 | fputs("</td><td border=\"0\">=> ", out); |
1601 | 23 | const int s_idx = *(int*)ccv_array_get(tensor_symbol->s_ref, p_idx) - 1; |
1602 | 23 | assert(s_idx >= 0); |
1603 | 23 | const ccv_nnc_tensor_symbol_info_t* const sub_tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, s_idx); |
1604 | 23 | if (sub_tensor_symbol->name) |
1605 | 21 | fputs(sub_tensor_symbol->name, out); |
1606 | 2 | else |
1607 | 2 | fprintf(out, "tensor%d", s_idx); |
1608 | 23 | fputs("</td></tr>", out); |
1609 | 23 | } else { |
1610 | 0 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1611 | 0 | fputs("<td colspan=\"3\">-</td></tr>", out); |
1612 | 0 | else |
1613 | 0 | fputs("<td colspan=\"2\">-</td></tr>", out); |
1614 | 0 | } |
1615 | 23 | } |
1616 | 16 | } |
1617 | 21 | if (exec_symbol_info->output_size > 0) |
1618 | 15 | { |
1619 | 15 | fprintf(out, "<tr><td rowspan=\"%d\">Output</td>", exec_symbol_info->output_size); |
1620 | 38 | for (i = 0; i < exec_symbol_info->output_size; i++23 ) |
1621 | 23 | { |
1622 | 23 | if (i > 0) |
1623 | 8 | fputs("<tr>", out); |
1624 | 23 | if (exec_symbol_info->outputs[i] >= 0) |
1625 | 23 | { |
1626 | 23 | fputs("<td>", out); |
1627 | 23 | ccv_nnc_tensor_symbol_info_t* tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->outputs[i]); |
1628 | 23 | ccv_nnc_tensor_symbol_info_t* alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)0 ccv_array_get0 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 0; |
1629 | 23 | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->outputs[i], tensor_symbol, alias_symbol, 1, flags, out); |
1630 | 23 | fputs("</td><td border=\"0\">=> ", out); |
1631 | 23 | const int s_idx = *(int*)ccv_array_get(tensor_symbol->s_ref, p_idx) - 1; |
1632 | 23 | assert(s_idx >= 0); |
1633 | 23 | const ccv_nnc_tensor_symbol_info_t* const sub_tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, s_idx); |
1634 | 23 | if (sub_tensor_symbol->name) |
1635 | 22 | fputs(sub_tensor_symbol->name, out); |
1636 | 1 | else |
1637 | 1 | fprintf(out, "tensor%d", s_idx); |
1638 | 23 | fputs("</td></tr>", out); |
1639 | 23 | } else { |
1640 | 0 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1641 | 0 | fputs("<td colspan=\"3\">-</td></tr>", out); |
1642 | 0 | else |
1643 | 0 | fputs("<td colspan=\"2\">-</td></tr>", out); |
1644 | 0 | } |
1645 | 23 | } |
1646 | 15 | } |
1647 | 127 | for (i = 0; 21 i < while_graph->tensor_symbol_info->rnum; i++106 ) |
1648 | 106 | { |
1649 | 106 | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, i); |
1650 | 106 | if (tensor_symbol_info->assign_ref) |
1651 | 24 | { |
1652 | 24 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1653 | 24 | fputs("<tr><td colspan=\"3\" border=\"0\">", out); |
1654 | 0 | else |
1655 | 0 | fputs("<tr><td colspan=\"2\" border=\"0\">", out); |
1656 | 24 | const ccv_nnc_tensor_symbol_info_t* const assign_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, tensor_symbol_info->assign_ref - 1); |
1657 | 24 | if (assign_symbol_info->name) |
1658 | 22 | fputs(assign_symbol_info->name, out); |
1659 | 2 | else |
1660 | 2 | fprintf(out, "tensor%d", tensor_symbol_info->assign_ref - 1); |
1661 | 24 | fputs(" -> ", out); |
1662 | 24 | if (tensor_symbol_info->name) |
1663 | 22 | fputs(tensor_symbol_info->name, out); |
1664 | 2 | else |
1665 | 2 | fprintf(out, "tensor%d", i); |
1666 | 24 | fputs("</td></tr>", out); |
1667 | 24 | } |
1668 | 106 | } |
1669 | 21 | fputs("</table>", out); |
1670 | 21 | } |
1671 | | |
1672 | | static void _ccv_nnc_symbolic_graph_dot_case_of_label(const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const int index, const ccv_array_t* const tensor_symbol_info, const int flags, FILE* out) |
1673 | 11 | { |
1674 | 11 | int i; |
1675 | 11 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1676 | 11 | fputs("<table border=\"0\" cellborder=\"1\" cellspacing=\"0\"><tr><td colspan=\"3\" border=\"0\"><b>", out); |
1677 | 0 | else |
1678 | 0 | fputs("<table border=\"0\" cellborder=\"1\" cellspacing=\"0\"><tr><td colspan=\"2\" border=\"0\"><b>", out); |
1679 | 11 | if (exec_symbol_info->name) |
1680 | 11 | fputs(exec_symbol_info->name, out); |
1681 | 0 | else |
1682 | 0 | fprintf(out, "caseof%d", index); |
1683 | 11 | fputs(" </b>Command: ", out); |
1684 | 11 | fputs(ccv_nnc_cmd_name(exec_symbol_info->cmd.cmd), out); |
1685 | 11 | fputs("</td></tr>", out); |
1686 | 11 | if (exec_symbol_info->input_size > 0) |
1687 | 11 | { |
1688 | 11 | fprintf(out, "<tr><td rowspan=\"%d\">Input</td>", exec_symbol_info->input_size); |
1689 | 38 | for (i = 0; i < exec_symbol_info->input_size; i++27 ) |
1690 | 27 | { |
1691 | 27 | if (i > 0) |
1692 | 16 | fputs("<tr>", out); |
1693 | 27 | if (exec_symbol_info->inputs[i] >= 0) |
1694 | 27 | { |
1695 | 27 | fputs("<td>", out); |
1696 | 27 | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->inputs[i]); |
1697 | 27 | const ccv_nnc_tensor_symbol_info_t* const alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)0 ccv_array_get0 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 0; |
1698 | 27 | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->inputs[i], tensor_symbol, alias_symbol, 1, flags, out); |
1699 | 27 | fputs("</td></tr>", out); |
1700 | 27 | } else { |
1701 | 0 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1702 | 0 | fputs("<td colspan=\"2\">-</td></tr>", out); |
1703 | 0 | else |
1704 | 0 | fputs("<td colspan=\"1\">-</td></tr>", out); |
1705 | 0 | } |
1706 | 27 | } |
1707 | 11 | } |
1708 | 11 | if (exec_symbol_info->output_size > 0) |
1709 | 11 | { |
1710 | 11 | fprintf(out, "<tr><td rowspan=\"%d\">Output</td>", exec_symbol_info->output_size); |
1711 | 24 | for (i = 0; i < exec_symbol_info->output_size; i++13 ) |
1712 | 13 | { |
1713 | 13 | if (i > 0) |
1714 | 2 | fputs("<tr>", out); |
1715 | 13 | if (exec_symbol_info->outputs[i] >= 0) |
1716 | 13 | { |
1717 | 13 | fputs("<td>", out); |
1718 | 13 | ccv_nnc_tensor_symbol_info_t* tensor_symbol = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(tensor_symbol_info, exec_symbol_info->outputs[i]); |
1719 | 13 | ccv_nnc_tensor_symbol_info_t* alias_symbol = tensor_symbol->alias_ref ? (ccv_nnc_tensor_symbol_info_t*)0 ccv_array_get0 (tensor_symbol_info, tensor_symbol->alias_ref - 1) : 0; |
1720 | 13 | _ccv_nnc_symbolic_graph_dot_tensor_symbol(exec_symbol_info->outputs[i], tensor_symbol, alias_symbol, 1, flags, out); |
1721 | 13 | fputs("</td></tr>", out); |
1722 | 13 | } else { |
1723 | 0 | if (flags == CCV_NNC_LONG_DOT_GRAPH) |
1724 | 0 | fputs("<td colspan=\"2\">-</td></tr>", out); |
1725 | 0 | else |
1726 | 0 | fputs("<td colspan=\"1\">-</td></tr>", out); |
1727 | 0 | } |
1728 | 13 | } |
1729 | 11 | } |
1730 | 11 | fputs("</table>", out); |
1731 | 11 | } |
1732 | | |
1733 | | static void _ccv_nnc_symbolic_graph_dot_sub_graphs(const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info, const ccv_array_t* const tensor_symbol_info, const ccv_array_t* const sub_graphs, const int flags, FILE* out, int* c) |
1734 | 32 | { |
1735 | 32 | int i, j, k; |
1736 | | // Output this node info within this subgraph. |
1737 | 32 | if (exec_symbol_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) |
1738 | 21 | { |
1739 | 21 | fprintf(out, "subgraph cluster%d {\nstyle=\"rounded\";\nnode%d [style=invisible];\nlabel=<", *c, *c); |
1740 | 21 | const ccv_nnc_symbolic_graph_t* const while_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(sub_graphs, CCV_NNC_GRAPH_REF(exec_symbol_info)[0] - 1); |
1741 | 21 | _ccv_nnc_symbolic_graph_dot_while_label(exec_symbol_info, *c, tensor_symbol_info, while_graph, flags, out); |
1742 | 21 | } else if (11 exec_symbol_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF11 ) { |
1743 | 11 | fprintf(out, "subgraph cluster%d {\nstyle=\"rounded\";\nnode%d [style=invisible];\nlabel=<", *c, *c); |
1744 | 11 | _ccv_nnc_symbolic_graph_dot_case_of_label(exec_symbol_info, *c, tensor_symbol_info, flags, out); |
1745 | 11 | } |
1746 | 32 | fputs(">;\n", out); |
1747 | 32 | ++(*c); |
1748 | 81 | for (k = 0; k < exec_symbol_info->graph_ref_size; k++49 ) |
1749 | 49 | { |
1750 | 49 | if (exec_symbol_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) |
1751 | 28 | { |
1752 | 28 | fprintf(out, "subgraph cluster%d {\nstyle=\"rounded\";\nnode%d [style=invisible];\nlabel=\"\"\n", *c, *c); |
1753 | 28 | ++(*c); |
1754 | 28 | } |
1755 | 49 | const ccv_nnc_symbolic_graph_t* const graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(sub_graphs, CCV_NNC_GRAPH_REF(exec_symbol_info)[k] - 1); |
1756 | 49 | int* node_id = (int*)ccmalloc(sizeof(int) * graph->exec_symbol_info->rnum); |
1757 | 144 | for (i = 0; i < graph->exec_symbol_info->rnum; i++95 ) |
1758 | 95 | { |
1759 | 95 | node_id[i] = *c; |
1760 | 95 | const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1761 | | // Skip the dead one. |
1762 | 95 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_symbol_info->flags)) |
1763 | 2 | continue; |
1764 | 93 | if (exec_symbol_info->graph_ref_size) |
1765 | 3 | _ccv_nnc_symbolic_graph_dot_sub_graphs(exec_symbol_info, graph->tensor_symbol_info, graph->sub_graphs, flags, out, c); |
1766 | 90 | else { |
1767 | 90 | _ccv_nnc_symbolic_graph_dot_node(exec_symbol_info, *c, graph->tensor_symbol_info, flags, out); |
1768 | 90 | ++(*c); |
1769 | 90 | } |
1770 | 93 | } |
1771 | | // Output connections. |
1772 | 144 | for (i = 0; i < graph->exec_symbol_info->rnum; i++95 ) |
1773 | 95 | { |
1774 | 95 | const ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1775 | | // Skip the dead one. |
1776 | 95 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_symbol_info->flags)) |
1777 | 2 | continue; |
1778 | 93 | if (exec_symbol_info->outgoings) |
1779 | 90 | for (j = 0; 45 j < exec_symbol_info->outgoings->rnum; j++45 ) |
1780 | 45 | { |
1781 | 45 | const int outgoing_idx = *(int*)ccv_array_get(exec_symbol_info->outgoings, j); |
1782 | 45 | const ccv_nnc_graph_exec_symbol_info_t* const outgoing_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, outgoing_idx); |
1783 | | // If both are sub-graphs, have both tail and head specified. |
1784 | 45 | if (CCV_NNC_GRAPH_REF(exec_symbol_info)[0] && CCV_NNC_GRAPH_REF1 (outgoing_symbol_info)[0]1 ) |
1785 | 0 | fprintf(out, "node%d -> node%d [ltail=cluster%d,lhead=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[i], node_id[outgoing_idx]); |
1786 | 45 | else if (CCV_NNC_GRAPH_REF(exec_symbol_info)[0] && !1 CCV_NNC_GRAPH_REF1 (outgoing_symbol_info)[0]) |
1787 | 1 | fprintf(out, "node%d -> node%d [ltail=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[i]); |
1788 | 44 | else if (!CCV_NNC_GRAPH_REF(exec_symbol_info)[0] && CCV_NNC_GRAPH_REF(outgoing_symbol_info)[0]) |
1789 | 3 | fprintf(out, "node%d -> node%d [lhead=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[outgoing_idx]); |
1790 | 41 | else |
1791 | 41 | fprintf(out, "node%d -> node%d;\n", node_id[i], node_id[outgoing_idx]); |
1792 | 45 | } |
1793 | 93 | } |
1794 | 49 | fputs("}\n", out); |
1795 | 49 | ccfree(node_id); |
1796 | 49 | } |
1797 | | // Extra subgraph cluster. |
1798 | 32 | if (exec_symbol_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) |
1799 | 11 | fputs("}\n", out); |
1800 | 32 | } |
1801 | | |
1802 | | void ccv_nnc_symbolic_graph_dot(const ccv_nnc_symbolic_graph_t* const graph, const int flags, FILE* out) |
1803 | 644 | { |
1804 | 644 | fputs("digraph G {\ncompound=true;\n", out); |
1805 | 644 | int i, j; |
1806 | 644 | int c = 0; |
1807 | 644 | int* node_id = (int*)ccmalloc(sizeof(int) * graph->exec_symbol_info->rnum); |
1808 | | // Output styles. |
1809 | 2.29k | for (i = 0; i < graph->exec_symbol_info->rnum; i++1.65k ) |
1810 | 1.65k | { |
1811 | 1.65k | node_id[i] = c; |
1812 | 1.65k | const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1813 | | // Skip the dead one. |
1814 | 1.65k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_symbol_info->flags)) |
1815 | 433 | continue; |
1816 | 1.22k | if (exec_symbol_info->graph_ref_size) |
1817 | 29 | _ccv_nnc_symbolic_graph_dot_sub_graphs(exec_symbol_info, graph->tensor_symbol_info, graph->sub_graphs, flags, out, &c); |
1818 | 1.19k | else { |
1819 | 1.19k | _ccv_nnc_symbolic_graph_dot_node(exec_symbol_info, c, graph->tensor_symbol_info, flags, out); |
1820 | 1.19k | ++c; |
1821 | 1.19k | } |
1822 | 1.22k | } |
1823 | | // Output connections. |
1824 | 2.29k | for (i = 0; i < graph->exec_symbol_info->rnum; i++1.65k ) |
1825 | 1.65k | { |
1826 | 1.65k | const ccv_nnc_graph_exec_symbol_info_t* exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
1827 | | // Skip the dead one. |
1828 | 1.65k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_symbol_info->flags)) |
1829 | 433 | continue; |
1830 | 1.22k | if (exec_symbol_info->outgoings) |
1831 | 2.30k | for (j = 0; 928 j < exec_symbol_info->outgoings->rnum; j++1.37k ) |
1832 | 1.37k | { |
1833 | 1.37k | const int outgoing_idx = *(int*)ccv_array_get(exec_symbol_info->outgoings, j); |
1834 | 1.37k | const ccv_nnc_graph_exec_symbol_info_t* const outgoing_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, outgoing_idx); |
1835 | | // If both are sub-graphs, have both tail and head specified. |
1836 | 1.37k | if (exec_symbol_info->graph_ref_size && outgoing_symbol_info->graph_ref_size16 ) |
1837 | 2 | fprintf(out, "node%d -> node%d [ltail=cluster%d,lhead=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[i], node_id[outgoing_idx]); |
1838 | 1.37k | else if (exec_symbol_info->graph_ref_size && !outgoing_symbol_info->graph_ref_size14 ) |
1839 | 14 | fprintf(out, "node%d -> node%d [ltail=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[i]); |
1840 | 1.36k | else if (!exec_symbol_info->graph_ref_size && outgoing_symbol_info->graph_ref_size) |
1841 | 4 | fprintf(out, "node%d -> node%d [lhead=cluster%d];\n", node_id[i], node_id[outgoing_idx], node_id[outgoing_idx]); |
1842 | 1.35k | else |
1843 | 1.35k | fprintf(out, "node%d -> node%d;\n", node_id[i], node_id[outgoing_idx]); |
1844 | 1.37k | } |
1845 | 1.22k | } |
1846 | 644 | fputs("}\n", out); |
1847 | 644 | ccfree(node_id); |
1848 | 644 | } |
1849 | | |
1850 | | void ccv_nnc_symbolic_graph_format(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size, const ccv_nnc_symbolic_graph_format_f format_fn, void* const context) |
1851 | 2 | { |
1852 | 2 | assert((sources && source_size) || (!sources && !source_size)); |
1853 | 2 | const ccv_nnc_graph_exec_symbol_t* const graph_sources = sources ? sources1 : (ccv_nnc_graph_exec_symbol_t*)1 ccv_array_get1 (graph->sources, 0); |
1854 | 2 | const int graph_source_size = source_size ? source_size1 : graph->sources->rnum1 ; |
1855 | 2 | assert((destinations && destination_size) || (!destinations && !destination_size)); |
1856 | 2 | const ccv_nnc_graph_exec_symbol_t* const graph_destinations = destinations ? destinations1 : (ccv_nnc_graph_exec_symbol_t*)1 ccv_array_get1 (graph->destinations, 0); |
1857 | 2 | const ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0); |
1858 | 2 | const int graph_destination_size = destination_size ? destination_size1 : graph->destinations->rnum1 ; |
1859 | 4 | ccv_nnc_graph_visit_t* const visit = ccv_nnc_graph_visit_new2 (graph, exec_symbol_info, graph->exec_symbol_info->rnum, graph_sources, graph_source_size, graph_destinations, graph_destination_size, 0); |
1860 | 0 | int outgoing_edge_count = 0; |
1861 | 11 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node) { |
1862 | 11 | outgoing_edge_count += node->outgoings ? node->outgoings->rnum8 : 03 ; |
1863 | 11 | } ccv_nnc_graph_visit_endfor |
1864 | 4 | int* const incoming_counts = (int*)ccmalloc2 (sizeof(int) * (graph->exec_symbol_info->rnum * 2 + outgoing_edge_count)); |
1865 | 4 | memset(incoming_counts, 0, sizeof(int) * graph->exec_symbol_info->rnum); |
1866 | 4 | int i; |
1867 | 11 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node) { |
1868 | 11 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
1869 | 0 | continue; |
1870 | 11 | if (node->outgoings && node->outgoings->rnum8 ) { |
1871 | 25 | for (i = 0; i < node->outgoings->rnum; i++17 ) |
1872 | 17 | ++incoming_counts[*(int*)ccv_array_get(node->outgoings, i)]; |
1873 | 8 | } |
1874 | 11 | } ccv_nnc_graph_visit_endfor |
1875 | 4 | int* const incoming_offsets = incoming_counts + graph->exec_symbol_info->rnum; |
1876 | 4 | int incoming_edge_count = 0; |
1877 | 11 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx) { |
1878 | 11 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
1879 | 0 | continue; |
1880 | 11 | incoming_offsets[idx] = incoming_edge_count; |
1881 | 11 | incoming_edge_count += incoming_counts[idx]; |
1882 | 11 | } ccv_nnc_graph_visit_endfor |
1883 | 4 | assert(incoming_edge_count <= outgoing_edge_count); |
1884 | 2 | memset(incoming_counts, 0, sizeof(int) * graph->exec_symbol_info->rnum); |
1885 | 2 | int* const incoming_edges = incoming_offsets + graph->exec_symbol_info->rnum; |
1886 | 11 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx) { |
1887 | 11 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
1888 | 0 | continue; |
1889 | 11 | if (node->outgoings && node->outgoings->rnum8 ) { |
1890 | 25 | for (i = 0; i < node->outgoings->rnum; i++17 ) |
1891 | 17 | { |
1892 | 17 | const int d = *(int*)ccv_array_get(node->outgoings, i); |
1893 | 17 | incoming_edges[incoming_offsets[d] + incoming_counts[d]] = idx; |
1894 | 17 | ++incoming_counts[d]; |
1895 | 17 | } |
1896 | 8 | } |
1897 | 11 | } ccv_nnc_graph_visit_endfor |
1898 | 11 | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node, idx) { |
1899 | 11 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags)) |
1900 | 0 | continue; |
1901 | 11 | format_fn(graph, idx, node->name, node->cmd, node->flags, incoming_edges + incoming_offsets[idx], incoming_counts[idx], node->outgoings ? (int*)8 ccv_array_get8 (node->outgoings, 0) : 03 , node->outgoings ? node->outgoings->rnum8 : 03 , node->inputs, node->input_size, node->outputs, node->output_size, context); |
1902 | 11 | } ccv_nnc_graph_visit_endfor |
1903 | 2 | ccv_nnc_graph_visit_free(visit); |
1904 | 2 | ccfree(incoming_counts); |
1905 | 2 | } |
1906 | | |
1907 | | void ccv_nnc_symbolic_graph_free(ccv_nnc_symbolic_graph_t* const graph) |
1908 | 2.65k | { |
1909 | 2.65k | int i; |
1910 | 15.9k | for (i = 0; i < graph->exec_symbol_info->rnum; i++13.2k ) |
1911 | 13.2k | _ccv_nnc_graph_exec_symbol_free((ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i), 0); |
1912 | 38.8k | for (i = 0; i < graph->tensor_symbol_info->rnum; i++36.1k ) |
1913 | 36.1k | { |
1914 | 36.1k | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i); |
1915 | 36.1k | if (symbol_info->name) |
1916 | 4.93k | ccfree(symbol_info->name); |
1917 | 36.1k | if (symbol_info->s_ref) |
1918 | 74 | ccv_array_free(symbol_info->s_ref); |
1919 | 36.1k | } |
1920 | 2.65k | if (graph->sub_graphs) |
1921 | 29 | { |
1922 | 80 | for (i = 0; i < graph->sub_graphs->rnum; i++51 ) |
1923 | 51 | ccv_nnc_symbolic_graph_free(*(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i)); |
1924 | 29 | ccv_array_free(graph->sub_graphs); |
1925 | 29 | } |
1926 | 2.65k | if (graph->sources) |
1927 | 2.57k | ccv_array_free(graph->sources); |
1928 | 2.65k | if (graph->destinations) |
1929 | 2.57k | ccv_array_free(graph->destinations); |
1930 | 2.65k | if (graph->breakpoints) |
1931 | 33 | ccfree(graph->breakpoints); |
1932 | 2.65k | ccv_array_free(graph->tensor_symbol_info); |
1933 | 2.65k | ccv_array_free(graph->exec_symbol_info); |
1934 | 2.65k | if (graph->backward.tensor_symbol_idx) |
1935 | 2.37k | ccfree(graph->backward.tensor_symbol_idx); |
1936 | 2.65k | if (graph->data_parallel.tensor_symbol_idx) |
1937 | 17 | ccfree(graph->data_parallel.tensor_symbol_idx); |
1938 | 2.65k | if (graph->data_parallel.exec_symbol_idx) |
1939 | 17 | ccfree(graph->data_parallel.exec_symbol_idx); |
1940 | 2.65k | ccfree(graph); |
1941 | 2.65k | } |
1942 | | |
1943 | | void ccv_nnc_symbolic_graph_symbol_infer(const ccv_nnc_symbolic_graph_t* const symbolic_graph, const ccv_nnc_graph_visit_t* const visit, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size, const ccv_nnc_tensor_symbol_info_t* const p_tensor_symbol_info, const int p_tensor_symbol_info_size, ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info, ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info) |
1944 | 26.4k | { |
1945 | 26.4k | if (ccv_array_get(symbolic_graph->tensor_symbol_info, 0) != tensor_symbol_info) |
1946 | 17.6k | memcpy(tensor_symbol_info, ccv_array_get(symbolic_graph->tensor_symbol_info, 0), sizeof(ccv_nnc_tensor_symbol_info_t) * symbolic_graph->tensor_symbol_info->rnum); |
1947 | 26.4k | if (ccv_array_get(symbolic_graph->exec_symbol_info, 0) != exec_symbol_info) |
1948 | 17.6k | memcpy(exec_symbol_info, ccv_array_get(symbolic_graph->exec_symbol_info, 0), sizeof(ccv_nnc_graph_exec_symbol_info_t) * symbolic_graph->exec_symbol_info->rnum); |
1949 | 26.4k | int i; |
1950 | 26.4k | if (p_tensor_symbol_info) |
1951 | 417 | for (i = 0; 64 i < symbolic_graph->tensor_symbol_info->rnum; i++353 ) |
1952 | 353 | if (tensor_symbol_info[i].p_ref) |
1953 | 132 | { |
1954 | 132 | const int p_ref = tensor_symbol_info[i].p_ref - 1; |
1955 | 132 | assert(p_ref < p_tensor_symbol_info_size); |
1956 | 132 | tensor_symbol_info[i].info = p_tensor_symbol_info[p_ref].info; |
1957 | | // I don't need to copy over stride and ofs for alias. |
1958 | 132 | } |
1959 | 26.4k | int max_input_size = 0, max_output_size = 0; |
1960 | | // Materialize auto hints. |
1961 | 169k | for (i = 0; i < symbolic_graph->exec_symbol_info->rnum; i++143k ) |
1962 | 143k | { |
1963 | 143k | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(exec_symbol_info[i].flags)) |
1964 | 24 | continue; |
1965 | 143k | max_input_size = ccv_max(max_input_size, exec_symbol_info[i].input_size); |
1966 | 143k | max_output_size = ccv_max(max_output_size, exec_symbol_info[i].output_size); |
1967 | | // If there is no hint and we have input and output tensor specified. |
1968 | 143k | if (ccv_nnc_is_no_hint(exec_symbol_info[i].hint) && |
1969 | 143k | exec_symbol_info[i].input_size > 0121k && exec_symbol_info[i].inputs[0] >= 0116k && !ccv_nnc_is_tensor_auto(tensor_symbol_info[exec_symbol_info[i].inputs[0]].info)116k && |
1970 | 143k | exec_symbol_info[i].output_size > 0116k && exec_symbol_info[i].outputs[0] >= 0116k && !ccv_nnc_is_tensor_auto(tensor_symbol_info[exec_symbol_info[i].outputs[0]].info)112k ) |
1971 | 111k | exec_symbol_info[i].hint = ccv_nnc_hint_auto(exec_symbol_info[i].cmd.info, tensor_symbol_info[exec_symbol_info[i].inputs[0]].info, tensor_symbol_info[exec_symbol_info[i].outputs[0]].info); |
1972 | 143k | } |
1973 | | |
1974 | 26.4k | ccv_nnc_tensor_param_t input_params[ccv_max(1, max_input_size)]; |
1975 | 26.4k | ccv_nnc_tensor_param_t output_params[ccv_max(1, max_output_size)]; |
1976 | | |
1977 | | // Materialize auto tensors. This need to go with the topological order. |
1978 | | // TODO: Need to proper handle sub-graphs (thus, run sub-graph to figure out the tensor properties). |
1979 | 123k | ccv_nnc_graph_visit_for(visit, exec_symbol_info, node) { |
1980 | 123k | if (node->input_size > 0 && node->output_size > 0118k ) |
1981 | 118k | { |
1982 | 464k | for (i = 0; i < node->input_size; i++345k ) |
1983 | 345k | input_params[i] = node->inputs[i] >= 0 ? tensor_symbol_info[node->inputs[i]].info272k : ccv_nnc_tensor_auto73.0k ; |
1984 | | // output_params will be initialized to tensor_auto inside the ccv_nnc_hint_tensor_auto method. |
1985 | 118k | ccv_nnc_hint_tensor_auto(node->cmd, input_params, node->input_size, node->hint, output_params, node->output_size); |
1986 | 305k | for (i = 0; i < node->output_size; i++186k ) |
1987 | | /* Only assign the output parameters if the symbol itself is auto. */ |
1988 | 186k | if (node->outputs[i] >= 0 && ccv_nnc_is_tensor_auto(tensor_symbol_info[node->outputs[i]].info)174k ) |
1989 | 100 | tensor_symbol_info[node->outputs[i]].info = output_params[i]; |
1990 | 118k | } |
1991 | 123k | } ccv_nnc_graph_visit_endfor |
1992 | | // If still point to any device, assign default device 00 to it. |
1993 | 358k | for (i = 0; i < symbolic_graph->tensor_symbol_info->rnum; i++332k ) |
1994 | 332k | if (CCV_TENSOR_GET_DEVICE(tensor_symbol_info[i].info.type) == CCV_COMPUTE_DEVICE_ANY) |
1995 | 128k | tensor_symbol_info[i].info.type = (~CCV_COMPUTE_DEVICE_ANY & tensor_symbol_info[i].info.type) | CCV_COMPUTE_DEVICE_000; |
1996 | 26.4k | } |
1997 | | |
1998 | | void ccv_nnc_symbolic_graph_tensor_auto(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size) |
1999 | 8.84k | { |
2000 | 8.84k | assert((sources && source_size) || (!sources && !source_size)); |
2001 | 8.84k | const ccv_nnc_graph_exec_symbol_t* const graph_sources = sources ? sources0 : (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->sources, 0); |
2002 | 8.84k | const int graph_source_size = source_size ? source_size0 : graph->sources->rnum; |
2003 | 8.84k | assert((destinations && destination_size) || (!destinations && !destination_size)); |
2004 | 8.84k | const ccv_nnc_graph_exec_symbol_t* const graph_destinations = destinations ? destinations0 : (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->destinations, 0); |
2005 | 8.84k | const int graph_destination_size = destination_size ? destination_size0 : graph->destinations->rnum; |
2006 | 17.6k | ccv_nnc_graph_visit_t* const visit = ccv_nnc_graph_visit_new8.84k (graph, (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0), graph->exec_symbol_info->rnum, graph_sources, graph_source_size, graph_destinations, graph_destination_size, 0); |
2007 | 8.84k | ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0); |
2008 | | // Some more clever things we can do here: |
2009 | | // 1. If there is a backward symbol for it, copy over the parameters. |
2010 | 17.6k | int i; |
2011 | 67.9k | for (i = 0; i < graph->backward.tensor_symbol_size; i++59.0k ) |
2012 | 59.0k | { |
2013 | 59.0k | const int d = graph->backward.tensor_symbol_idx[i]; |
2014 | 59.0k | if (d >= 0) |
2015 | 34.0k | { |
2016 | 34.0k | tensor_symbol_info[d].info = tensor_symbol_info[i].info; |
2017 | 34.0k | memcpy(tensor_symbol_info[d].stride, tensor_symbol_info[i].stride, sizeof(tensor_symbol_info[i].stride)); |
2018 | 34.0k | memcpy(tensor_symbol_info[d].ofs, tensor_symbol_info[i].ofs, sizeof(tensor_symbol_info[i].ofs)); |
2019 | 34.0k | } |
2020 | 59.0k | } |
2021 | 17.6k | ccv_nnc_graph_exec_symbol_info_t* const exec_symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get8.84k (graph->exec_symbol_info, 0); |
2022 | | // 2. If there is a copy (because the data parallel setting), copy over the info. |
2023 | 17.6k | const int parallel_count = graph->data_parallel.count; |
2024 | 17.6k | if (parallel_count > 18.84k ) |
2025 | 16 | { |
2026 | 16 | int j; |
2027 | 15.8k | for (i = 0; i < graph->data_parallel.tensor_symbol_size; i++15.8k ) |
2028 | 15.8k | { |
2029 | 15.8k | const int device_id = CCV_TENSOR_GET_DEVICE_ID(tensor_symbol_info[i].info.type); |
2030 | 63.2k | for (j = 0; j < parallel_count - 1; j++47.4k ) |
2031 | 47.4k | { |
2032 | 47.4k | const int d = graph->data_parallel.tensor_symbol_idx[i * (parallel_count - 1) + j]; |
2033 | 47.4k | if (d >= 0) |
2034 | 17.4k | { |
2035 | 17.4k | tensor_symbol_info[d].info = tensor_symbol_info[i].info; |
2036 | 17.4k | if (j + 1 != device_id) |
2037 | 17.4k | CCV_TENSOR_SET_DEVICE_ID(tensor_symbol_info[d].info.type, j + 1); // Set the device id. |
2038 | 0 | else |
2039 | 0 | CCV_TENSOR_SET_DEVICE_ID(tensor_symbol_info[d].info.type, 0); |
2040 | 17.4k | memcpy(tensor_symbol_info[d].stride, tensor_symbol_info[i].stride, sizeof(tensor_symbol_info[i].stride)); |
2041 | 17.4k | memcpy(tensor_symbol_info[d].ofs, tensor_symbol_info[i].ofs, sizeof(tensor_symbol_info[i].ofs)); |
2042 | 17.4k | } |
2043 | 47.4k | } |
2044 | 15.8k | } |
2045 | 2.16k | for (i = 0; i < graph->data_parallel.exec_symbol_size; i++2.14k ) |
2046 | 8.57k | for (j = 0; 2.14k j < parallel_count - 1; j++6.43k ) |
2047 | 6.43k | { |
2048 | 6.43k | const int d = graph->data_parallel.exec_symbol_idx[i * (parallel_count - 1) + j]; |
2049 | 6.43k | if (d >= 0) |
2050 | 6.43k | exec_symbol_info[d].cmd = exec_symbol_info[i].cmd; |
2051 | 6.43k | } |
2052 | 16 | } |
2053 | 17.6k | ccv_nnc_symbolic_graph_symbol_infer(graph, visit, graph_sources, graph_source_size, graph_destinations, graph_destination_size, 0, 0, tensor_symbol_info, exec_symbol_info); |
2054 | 17.6k | ccv_nnc_graph_visit_free(visit); |
2055 | 17.6k | } |
2056 | | |
2057 | | void ccv_nnc_symbolic_graph_sources_to_destinations(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t* const sources, const int source_size, const ccv_nnc_graph_exec_symbol_t* const destinations, const int destination_size, uint64_t* const bitmask) |
2058 | 13 | { |
2059 | 13 | assert(sources && source_size); |
2060 | 13 | assert(destinations && destination_size); |
2061 | 13 | int i; |
2062 | 45 | for (i = 0; i < source_size; i++32 ) |
2063 | 32 | { |
2064 | 32 | assert(sources[i].graph == graph); |
2065 | 32 | assert(sources[i].d >= 0 && sources[i].d < graph->exec_symbol_info->rnum); |
2066 | 32 | } |
2067 | 26 | for (i = 0; 13 i < destination_size; i++13 ) |
2068 | 13 | { |
2069 | 13 | assert(destinations[i].graph == graph); |
2070 | 13 | assert(destinations[i].d >= 0 && destinations[i].d < graph->exec_symbol_info->rnum); |
2071 | 13 | } |
2072 | 13 | ccv_sparse_matrix_t* const exec_dep = ccv_sparse_matrix_new(graph->exec_symbol_info->rnum, graph->exec_symbol_info->rnum, CCV_8U | CCV_C1, CCV_SPARSE_ROW_MAJOR, 0); |
2073 | 13 | ccv_array_t* const ws = ccv_array_new(sizeof(int), source_size, 0); |
2074 | 45 | for (i = 0; i < source_size; i++32 ) |
2075 | 32 | ccv_array_push(ws, &sources[i].d); |
2076 | 13 | int* buf = (int*)ccmalloc(sizeof(int) * graph->exec_symbol_info->rnum); |
2077 | 13 | int buf_size; |
2078 | 13 | #define for_block(x, val) \ |
2079 | 15 | do { \ |
2080 | 15 | if (((uint8_t*)val)[0] != 0) \ |
2081 | 15 | buf[buf_size++] = x; \ |
2082 | 15 | } while (0) |
2083 | 13 | const uint8_t one = 1; |
2084 | 55 | for (i = 0; i < ws->rnum; i++42 ) |
2085 | 42 | { |
2086 | 42 | int j; |
2087 | 42 | const int d = *(int*)ccv_array_get(ws, i); |
2088 | 42 | int flag = 0; |
2089 | 84 | for (j = 0; !flag && j < destination_size79 ; j++42 ) |
2090 | 42 | flag = (d == destinations[j].d); |
2091 | 42 | if (flag) |
2092 | 5 | continue; |
2093 | 37 | buf_size = 0; /* save all its parent deps to this buffer */ |
2094 | 37 | ccv_sparse_matrix_vector_t* vector = ccv_get_sparse_matrix_vector(exec_dep, d); |
2095 | 37 | if (vector) |
2096 | 15 | CCV_SPARSE_VECTOR_FOREACH11 (exec_dep, vector, for_block); |
2097 | 37 | ccv_nnc_graph_exec_symbol_info_t* const info = ccv_array_get(graph->exec_symbol_info, d); |
2098 | 37 | if (info->outgoings && info->outgoings->rnum > 028 ) |
2099 | 25 | { |
2100 | 25 | ccv_array_t* const outgoings = info->outgoings; |
2101 | 50 | for (j = 0; j < outgoings->rnum; j++25 ) |
2102 | 25 | { |
2103 | 25 | const int outgoing_d = *(int*)ccv_array_get(outgoings, j); |
2104 | 25 | int k; |
2105 | 25 | int flag = 0; |
2106 | 50 | for (k = 0; !flag && k < destination_size35 ; k++25 ) |
2107 | 25 | flag = (outgoing_d == destinations[k].d); |
2108 | | // We cannot avoid the ones that visited, because these may not contain all the deps. |
2109 | 25 | if (!flag) |
2110 | 10 | ccv_array_push(ws, &outgoing_d); |
2111 | 25 | ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, d, &one); |
2112 | 31 | for (k = 0; k < buf_size; k++6 ) |
2113 | 6 | ccv_set_sparse_matrix_cell(exec_dep, outgoing_d, buf[k], &one); |
2114 | 25 | } |
2115 | 25 | } |
2116 | 37 | } |
2117 | 13 | ccfree(buf); |
2118 | 13 | ccv_array_free(ws); |
2119 | | // Use exec_dep to fill the bitmask |
2120 | 45 | for (i = 0; i < source_size; i++32 ) |
2121 | 32 | { |
2122 | 32 | const int d = sources[i].d; |
2123 | 32 | int j; |
2124 | 32 | int flag = 0; |
2125 | 64 | for (j = 0; !flag && j < destination_size44 ; j++32 ) |
2126 | 32 | if (d == destinations[j].d) { |
2127 | 5 | flag = 1; |
2128 | 27 | } else { |
2129 | 27 | ccv_numeric_data_t cell = ccv_get_sparse_matrix_cell(exec_dep, destinations[j].d, d); |
2130 | 27 | flag = (cell.u8 && cell.u8[0] != 015 ); |
2131 | 27 | } |
2132 | 32 | if (flag) |
2133 | 20 | bitmask[i >> 6] |= ((uint64_t)1 << (i & 63)); |
2134 | 12 | else |
2135 | 12 | bitmask[i >> 6] &= ~((uint64_t)1 << (i & 63)); |
2136 | 32 | } |
2137 | 13 | ccv_matrix_free(exec_dep); |
2138 | 13 | } |