/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph_while.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_symbolic_graph.h" |
6 | | |
7 | | // MARK - Level-3.5 API |
8 | | |
9 | | ccv_nnc_graph_exec_symbol_t ccv_nnc_symbolic_graph_while(ccv_nnc_symbolic_graph_t* const graph, const uint32_t cmd, ccv_nnc_symbolic_graph_t* const while_graph, const char* const name) |
10 | 23 | { |
11 | 23 | assert(cmd == CCV_NNC_GRAPH_FORWARD || cmd == CCV_NNC_GRAPH_BACKWARD); |
12 | 23 | assert(while_graph->p == 0); |
13 | 23 | assert(while_graph->p_idx == 0); |
14 | | // Added one more symbol. |
15 | 23 | ccv_nnc_graph_exec_symbol_t symbol = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(cmd, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, name); |
16 | | // Assigning graph_ref to it. |
17 | 23 | if (!graph->sub_graphs) |
18 | 20 | graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0); |
19 | 23 | ccv_array_push(graph->sub_graphs, &while_graph); |
20 | 23 | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
21 | 23 | symbol_info->flags |= CCV_NNC_GRAPH_EXEC_P_WHILE; |
22 | 23 | symbol_info->graph_ref_size = 1; |
23 | | // Note the extra allocation (the ccv_array_t only holds a pointer to ccv_nnc_symbolic_graph_t*). |
24 | | // In this way, we can get the while graph and don't have to worry about it will be an invalid pointer once |
25 | | // the array expands (another while graph allocated). |
26 | 23 | CCV_NNC_GRAPH_REF(symbol_info)[0] = graph->sub_graphs->rnum; |
27 | 23 | while_graph->p_idx = graph->sub_graphs->rnum; |
28 | 23 | while_graph->exec_idx = symbol.d + 1; |
29 | 23 | while_graph->p = graph; |
30 | 23 | return symbol; |
31 | 23 | } |
32 | | |
33 | | void ccv_nnc_symbolic_graph_set_while_expr(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_graph_while_f while_expr, const void* const while_data, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_graph_exec_symbol_t* const breakpoints, const int breakpoint_size) |
34 | 20 | { |
35 | 20 | const int exec_idx = while_graph->exec_idx - 1; |
36 | 20 | assert(exec_idx >= 0 && exec_idx < while_graph->p->exec_symbol_info->rnum); |
37 | 20 | ccv_nnc_graph_exec_symbol_info_t* const exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(while_graph->p->exec_symbol_info, exec_idx); |
38 | 20 | exec_info->p_while.data = while_data; |
39 | 20 | exec_info->p_while.expr = while_expr; |
40 | 20 | int i; |
41 | 20 | if (input_size > 0) |
42 | 18 | { |
43 | 18 | assert(inputs); |
44 | 18 | exec_info->p_while.inputs = (int*)ccmalloc(sizeof(int) * input_size); |
45 | 39 | for (i = 0; i < input_size; i++21 ) |
46 | 21 | exec_info->p_while.inputs[i] = ccv_nnc_tensor_symbol_map_raw(while_graph, inputs[i]); |
47 | 18 | exec_info->p_while.input_size = input_size; |
48 | 18 | } |
49 | 20 | if (breakpoint_size > 0) |
50 | 20 | { |
51 | 20 | assert(breakpoints); |
52 | 20 | while_graph->breakpoint_size = breakpoint_size; |
53 | 20 | while_graph->breakpoints = (ccv_nnc_graph_exec_symbol_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_t) * breakpoint_size); |
54 | 20 | memcpy(while_graph->breakpoints, breakpoints, sizeof(ccv_nnc_graph_exec_symbol_t) * breakpoint_size); |
55 | 20 | } |
56 | 20 | } |
57 | | |
58 | | void ccv_nnc_symbolic_graph_set_carry_overs(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size) |
59 | 21 | { |
60 | 21 | int i; |
61 | 45 | for (i = 0; i < symbol_map_size; i++24 ) |
62 | 24 | { |
63 | 24 | const ccv_nnc_tensor_symbol_t source = ccv_nnc_tensor_symbol_resolve(while_graph, symbol_map[i].source); |
64 | 24 | const ccv_nnc_tensor_symbol_t destination = ccv_nnc_tensor_symbol_resolve(while_graph, symbol_map[i].destination); |
65 | 24 | assert(source.graph == while_graph); |
66 | 24 | assert(destination.graph == while_graph); |
67 | 24 | assert(source.d < while_graph->tensor_symbol_info->rnum); |
68 | 24 | assert(destination.d < while_graph->tensor_symbol_info->rnum); |
69 | 24 | ccv_nnc_tensor_symbol_info_t* source_tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, source.d); |
70 | 24 | ccv_nnc_tensor_symbol_info_t* destination_tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, destination.d); |
71 | | // Don't support parameterize with alias. The reason is that to support parameterized loop (for SSA), I choose |
72 | | // to simply reuse the piece of memory (allocating the same memory region to both, therefore to enable parameter |
73 | | // passing). For alias, it is not possible because alias can pointing to the tensors with different sizes, thus, |
74 | | // these pointed tensors cannot share the same memory region. The best way for alias to be parameterized is to |
75 | | // create a new tensor of the same size, transfer value over, and parameterized on that tensor instead. |
76 | 24 | assert(!destination_tensor_symbol_info->alias_ref); |
77 | 24 | assert(!source_tensor_symbol_info->alias_ref); |
78 | 24 | destination_tensor_symbol_info->assign_ref = source.d + 1; |
79 | 24 | source_tensor_symbol_info->r_assign_ref = destination.d + 1; |
80 | 24 | } |
81 | 21 | } |
82 | | |
83 | | ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_for_while_count(const ccv_nnc_symbolic_graph_t* const while_graph) |
84 | 19 | { |
85 | 19 | return (ccv_nnc_tensor_symbol_t){ |
86 | 19 | .d = CCV_NNC_WHILE_COUNT_TENSOR_SYMBOL, |
87 | 19 | .graph = while_graph |
88 | 19 | }; |
89 | 19 | } |
90 | | |
91 | | ccv_nnc_symbolic_graph_t* ccv_nnc_symbolic_graph_from_while_symbol(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t while_symbol) |
92 | 2 | { |
93 | 2 | assert(graph->sub_graphs); |
94 | 2 | assert(while_symbol.graph == graph); |
95 | 2 | assert(while_symbol.d < graph->exec_symbol_info->rnum); |
96 | 2 | ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, while_symbol.d); |
97 | 2 | assert(CCV_NNC_GRAPH_REF(symbol_info)[0] <= graph->sub_graphs->rnum); |
98 | 2 | return *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, CCV_NNC_GRAPH_REF(symbol_info)[0] - 1); |
99 | 2 | } |