/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph_case_of.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_symbolic_graph.h" |
6 | | |
7 | | // MARK - Level-3.5 API |
8 | | |
9 | | ccv_nnc_graph_exec_symbol_t ccv_nnc_symbolic_graph_case_of_new(ccv_nnc_symbolic_graph_t* const graph, const uint32_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size, const char* const name) |
10 | 11 | { |
11 | 11 | assert(cmd == CCV_NNC_GRAPH_FORWARD || cmd == CCV_NNC_GRAPH_BACKWARD); |
12 | | // A case_if statement must have meaningful outputs / inputs. |
13 | 11 | assert(symbol_map_size > 0); |
14 | 11 | ccv_nnc_tensor_symbol_t all_inputs[symbol_map_size * 2 + input_size]; |
15 | 11 | ccv_nnc_tensor_symbol_t* const outputs = all_inputs + (symbol_map_size + input_size); |
16 | 11 | int i; |
17 | 22 | for (i = 0; i < symbol_map_size; i++11 ) |
18 | 11 | all_inputs[i] = symbol_map[i].source, outputs[i] = symbol_map[i].destination; |
19 | 21 | for (i = symbol_map_size; i < symbol_map_size + input_size; i++10 ) |
20 | 10 | all_inputs[i] = inputs[i - symbol_map_size]; |
21 | | // Added one more symbol. |
22 | 11 | const ccv_nnc_graph_exec_symbol_t symbol = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(cmd, 0, CMD_GENERIC(), 0), all_inputs, symbol_map_size + input_size, outputs, symbol_map_size, name); |
23 | 11 | ccv_nnc_tensor_symbol_set_bypasses(graph, symbol_map, symbol_map_size); |
24 | 11 | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
25 | 11 | symbol_info->flags |= CCV_NNC_GRAPH_EXEC_CASE_OF; |
26 | | // We are still free to add more inputs to this graph, it is OK, we are covered by the argument.offset / size. |
27 | 11 | symbol_info->case_of.argument.offset = symbol_map_size; |
28 | 11 | symbol_info->case_of.argument.size = input_size; |
29 | 11 | return symbol; |
30 | 11 | } |
31 | | |
32 | | void ccv_nnc_symbolic_graph_set_case_of_expr(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t exec, ccv_nnc_graph_case_of_f case_of, const void* case_of_data) |
33 | 11 | { |
34 | 11 | assert(exec.graph == graph); |
35 | 11 | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, exec.d); |
36 | 11 | symbol_info->case_of.expr = case_of; |
37 | 11 | symbol_info->case_of.data = case_of_data; |
38 | 11 | } |
39 | | |
40 | | void ccv_nnc_symbolic_graph_set_case_of(ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t symbol, ccv_nnc_symbolic_graph_t* const case_graph, const int case_of, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size) |
41 | 28 | { |
42 | 28 | assert(symbol.graph == graph); |
43 | 28 | assert(symbol.d >= 0); |
44 | 28 | assert(symbol.d < graph->exec_symbol_info->rnum); |
45 | 28 | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d); |
46 | 28 | assert(symbol_map_size <= symbol_info->output_size); |
47 | 28 | assert(symbol_info->flags == CCV_NNC_GRAPH_EXEC_CASE_OF); |
48 | 28 | if (!graph->sub_graphs) |
49 | 9 | graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0); |
50 | 28 | ccv_array_push(graph->sub_graphs, &case_graph); |
51 | 28 | case_graph->p_idx = graph->sub_graphs->rnum; |
52 | 28 | case_graph->exec_idx = symbol.d + 1; |
53 | 28 | case_graph->p = graph; |
54 | | // If case_of is larger than the inline graph_ref, we need to allocate. |
55 | 28 | if (case_of >= sizeof(symbol_info->_inline_graph_ref) / sizeof(symbol_info->_inline_graph_ref[0])) |
56 | 8 | { |
57 | 8 | if (!symbol_info->_heap_graph_ref) |
58 | 7 | { |
59 | 7 | symbol_info->_heap_graph_ref = cccalloc(case_of + 1, sizeof(int)); |
60 | | // Copy from inline data. |
61 | 7 | memcpy(symbol_info->_heap_graph_ref, symbol_info->_inline_graph_ref, sizeof(symbol_info->_inline_graph_ref)); |
62 | 7 | symbol_info->graph_ref_size = case_of + 1; |
63 | 7 | } else if (1 symbol_info->graph_ref_size <= case_of1 ) { |
64 | 1 | symbol_info->_heap_graph_ref = ccrealloc(symbol_info->_heap_graph_ref, sizeof(int) * (case_of + 1)); |
65 | | // Reset the newly allocated ones to 0. |
66 | 1 | memset(symbol_info->_heap_graph_ref + symbol_info->graph_ref_size, 0, sizeof(int) * (case_of + 1 - symbol_info->graph_ref_size)); |
67 | 1 | symbol_info->graph_ref_size = case_of + 1; |
68 | 1 | } |
69 | 8 | } else |
70 | 20 | symbol_info->graph_ref_size = ccv_max(symbol_info->graph_ref_size, case_of + 1); |
71 | | // Set the branch with the graph. |
72 | 28 | CCV_NNC_GRAPH_REF(symbol_info)[case_of] = graph->sub_graphs->rnum; |
73 | 28 | int i; |
74 | 56 | for (i = 0; i < symbol_map_size; i++28 ) |
75 | 28 | ccv_nnc_tensor_symbol_hookup(case_graph, graph, symbol_map[i].source, symbol_map[i].destination); |
76 | 28 | } |