/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_graph_case_of.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_graph.h" |
6 | | |
7 | | // MARK - Level-3.5 API |
8 | | |
9 | | ccv_nnc_graph_exec_t ccv_nnc_graph_case_of_new(ccv_nnc_graph_t* const graph, const uint32_t cmd, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size) |
10 | 12 | { |
11 | 12 | assert(cmd == CCV_NNC_GRAPH_FORWARD || cmd == CCV_NNC_GRAPH_BACKWARD); |
12 | 12 | ccv_nnc_graph_exec_t exec = ccv_nnc_graph_exec_new(graph, ccv_nnc_cmd(cmd, 0, CMD_GENERIC(), 0), ccv_nnc_no_hint, inputs, input_size, outputs, output_size); |
13 | 12 | ccv_nnc_graph_exec_info_t* const exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(graph->exec_info, exec.d); |
14 | 12 | exec_info->flags |= CCV_NNC_GRAPH_EXEC_CASE_OF; |
15 | 12 | int i, j; |
16 | 25 | for (i = 0; i < output_size; i++13 ) |
17 | 13 | if (outputs[i] && ((ccv_nnc_tensor_multiview_t*)outputs[i])->anchor == CCV_NNC_MULTIVIEW_PHI) |
18 | 30 | for (j = 0; 10 j < ((ccv_nnc_tensor_multiview_t*)outputs[i])->kind + ((ccv_nnc_tensor_multiview_t*)outputs[i])->repeat; j++20 ) |
19 | 20 | { |
20 | 20 | ccv_nnc_tensor_t* const mv = (ccv_nnc_tensor_t*)CCV_NNC_MULTIVIEW_DATA((ccv_nnc_tensor_multiview_t*)outputs[i])[j]->alias_ref; |
21 | 20 | if (mv && CCV_IS_TENSOR_MULTIVIEW2 (mv)) |
22 | 2 | ccv_nnc_graph_exec_add_as_affected(graph, exec, mv); |
23 | 20 | } |
24 | 12 | return exec; |
25 | 12 | } |
26 | | |
27 | | void ccv_nnc_graph_set_case_of_expr(ccv_nnc_graph_t* const graph, const ccv_nnc_graph_exec_t exec, ccv_nnc_graph_case_of_f case_of, const void* case_of_data, const int offset) |
28 | 12 | { |
29 | 12 | assert(exec.graph == graph); |
30 | 12 | assert(exec.d >= 0); |
31 | 12 | assert(exec.d < graph->exec_info->rnum); |
32 | 12 | ccv_nnc_graph_exec_info_t* const exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(graph->exec_info, exec.d); |
33 | 12 | assert(exec_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF); |
34 | 12 | exec_info->case_of.data = case_of_data; |
35 | 12 | exec_info->case_of.expr = case_of; |
36 | 12 | exec_info->case_of.offset = offset; |
37 | 12 | } |
38 | | |
39 | | void ccv_nnc_graph_set_case_of(ccv_nnc_graph_t* const graph, const ccv_nnc_graph_exec_t exec, ccv_nnc_graph_t* const case_graph, const int case_of) |
40 | 32 | { |
41 | 32 | assert(exec.graph == graph); |
42 | 32 | assert(exec.d >= 0); |
43 | 32 | assert(exec.d < graph->exec_info->rnum); |
44 | 32 | ccv_nnc_graph_exec_info_t* const exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(graph->exec_info, exec.d); |
45 | 32 | assert(exec_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF); |
46 | 32 | if (!graph->sub_graphs) |
47 | 10 | graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_graph_t*), 1, 0); |
48 | 32 | int i; |
49 | 32 | if (case_graph->tensor_wraps_refs) |
50 | 7 | { |
51 | | // Copy wraps from sub graph to parent graph. |
52 | 7 | if (!graph->tensor_wraps_refs) |
53 | 0 | graph->tensor_wraps_refs = ccv_array_new(sizeof(ccv_nnc_graph_tensor_wraps_ref_t), case_graph->tensor_wraps_refs->rnum, 0); |
54 | 16 | for (i = 0; i < case_graph->tensor_wraps_refs->rnum; i++9 ) |
55 | 9 | ccv_array_push(graph->tensor_wraps_refs, ccv_array_get(case_graph->tensor_wraps_refs, i)); |
56 | 7 | } |
57 | 32 | ccv_array_push(graph->sub_graphs, &case_graph); |
58 | 32 | case_graph->p = graph; |
59 | 32 | case_graph->p_idx = graph->sub_graphs->rnum; |
60 | 32 | case_graph->exec_idx = exec.d + 1; |
61 | | // If case_of is larger than the inline graph_ref, we need to allocate. |
62 | 32 | if (case_of >= sizeof(exec_info->_inline_graph_ref) / sizeof(exec_info->_inline_graph_ref[0])) |
63 | 10 | { |
64 | 10 | if (!exec_info->_heap_graph_ref) |
65 | 8 | { |
66 | 8 | exec_info->_heap_graph_ref = cccalloc(case_of + 1, sizeof(int)); |
67 | | // Copy from inline data. |
68 | 8 | memcpy(exec_info->_heap_graph_ref, exec_info->_inline_graph_ref, sizeof(exec_info->_inline_graph_ref)); |
69 | 8 | exec_info->graph_ref_size = case_of + 1; |
70 | 8 | } else if (2 exec_info->graph_ref_size <= case_of2 ) { |
71 | 2 | exec_info->_heap_graph_ref = ccrealloc(exec_info->_heap_graph_ref, sizeof(int) * (case_of + 1)); |
72 | | // Reset the newly allocated ones to 0. |
73 | 2 | memset(exec_info->_heap_graph_ref + exec_info->graph_ref_size, 0, sizeof(int) * (case_of + 1 - exec_info->graph_ref_size)); |
74 | 2 | exec_info->graph_ref_size = case_of + 1; |
75 | 2 | } |
76 | 10 | } else |
77 | 22 | exec_info->graph_ref_size = ccv_max(exec_info->graph_ref_size, case_of + 1); |
78 | | // Set the branch with the graph. |
79 | 32 | CCV_NNC_GRAPH_REF(exec_info)[case_of] = graph->sub_graphs->rnum; |
80 | 32 | } |