lib/nnc/ccv_nnc_graph_while.c
#include "ccv_nnc.h"
#include "ccv_nnc_easy.h"
#include "ccv_nnc_internal.h"
#include "ccv_internal.h"
#include "_ccv_nnc_graph.h"

// MARK - Level-3.5 API

void ccv_nnc_tensor_multiview(ccv_nnc_tensor_t* data[], const uint8_t kind, const uint16_t repeat, const ccv_nnc_graph_t* const graph, ccv_nnc_tensor_multiview_t* const tensor_multiview)
{
	assert(kind == CCV_NNC_MULTIVIEW_K0N || kind == CCV_NNC_MULTIVIEW_K1N);
	assert(repeat > 0);
	tensor_multiview->type = CCV_TENSOR_MULTIVIEW;
	tensor_multiview->kind = kind;
	tensor_multiview->repeat = repeat;
	tensor_multiview->anchor = (intptr_t)graph;
	tensor_multiview->it = 0;
	tensor_multiview->p = 0;
	tensor_multiview->offset = 0;
	tensor_multiview->sp = 0;
	tensor_multiview->_heap_data = (repeat + kind <= sizeof(tensor_multiview->_inline_data) / sizeof(tensor_multiview->_inline_data[0])) ? 0 : ccmalloc(sizeof(ccv_nnc_tensor_t*) * (repeat + kind));
	int i;
	// Currently, only CCV_NNC_MULTIVIEW_K12 uses 3 tensors.
	for (i = 0; i < repeat + kind; i++)
	{
		CCV_NNC_MULTIVIEW_DATA(tensor_multiview)[i] = data[i];
		ccv_nnc_tensor_multiview_t* const mv = (ccv_nnc_tensor_multiview_t*)data[i];
		if (data[i] != CCV_NNC_TENSOR_PLACEHOLDER && CCV_IS_TENSOR_MULTIVIEW(mv))
			mv->p = tensor_multiview;
	}
}
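
// A minimal usage sketch (not part of this file; identifiers are hypothetical):
// wrap two caller-owned tensors into a multiview anchored at a while graph, so
// the loop alternates between them. With CCV_NNC_MULTIVIEW_K0N, iteration i is
// expected to resolve to data[i % repeat].
//
//	ccv_nnc_tensor_t* views[2] = { tensor_even, tensor_odd };
//	ccv_nnc_tensor_multiview_t mv;
//	ccv_nnc_tensor_multiview(views, CCV_NNC_MULTIVIEW_K0N, 2, while_graph, &mv);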

void ccv_nnc_tensor_multiview_free(const ccv_nnc_tensor_multiview_t tensor_multiview)
{
	if (tensor_multiview.sp)
		ccv_array_free(tensor_multiview.sp);
	if (tensor_multiview._heap_data)
		ccfree(tensor_multiview._heap_data);
}

void ccv_nnc_tensor_synchronize_to_multiview(ccv_nnc_tensor_multiview_t* const tensor_multiview, ccv_nnc_tensor_t* const tensor)
{
	if (!tensor_multiview->sp)
		tensor_multiview->sp = ccv_array_new(sizeof(ccv_nnc_tensor_t*), 0, 0);
	ccv_array_push(tensor_multiview->sp, &tensor);
}
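
// Sketch of registering a tensor to be rebased on every view flip (assumed
// usage; `mv` is the hypothetical multiview from the sketch above):
//
//	ccv_nnc_tensor_t follower = ccv_nnc_tensor(0, CPU_TENSOR_NHWC(32F, 1), 0);
//	ccv_nnc_tensor_synchronize_to_multiview(&mv, &follower);
//	// follower.data is rewired by ccv_nnc_tensor_multiview_synchronize below.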

void ccv_nnc_tensor_multiview_synchronize(ccv_nnc_tensor_multiview_t* const tensor_multiview)
{
	assert(tensor_multiview->it && !CCV_IS_TENSOR_MULTIVIEW(tensor_multiview->it));
	// Update the pointer on tv only if it is not a single tensor pointer.
	// TODO: This will not work with fat pointers (MPS).
	ccv_numeric_data_t data = tensor_multiview->it->data;
	off_t dataof = tensor_multiview->it->dataof;
	ccv_nnc_tensor_data_add(tensor_multiview->it->info, -tensor_multiview->offset, &data, &dataof);
	const ccv_nnc_tensor_multiview_t* c = tensor_multiview;
	int i;
	do {
		if (c->sp)
			for (i = 0; i < c->sp->rnum; i++)
			{
				ccv_nnc_tensor_t* const tensor = *(ccv_nnc_tensor_t**)ccv_array_get(c->sp, i);
				if (CCV_IS_TENSOR_VIEW(tensor))
				{
					ccv_nnc_tensor_view_t* const tensor_view = (ccv_nnc_tensor_view_t*)tensor;
					ccv_nnc_tensor_data(tensor_view->info, data.u8, tensor_view->off + dataof, &tensor_view->data, &tensor_view->dataof);
				} else {
					tensor->data = data;
					tensor->dataof = dataof;
				}
			}
		c = c->p;
	} while (c);
}
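
// How a loop interpreter is expected to drive this (a sketch, assuming the
// resolved view is a plain tensor, not a nested multiview, per the assert):
//
//	mv.it = CCV_NNC_MULTIVIEW_DATA(&mv)[iter % mv.repeat]; // K0N resolution
//	ccv_nnc_tensor_multiview_synchronize(&mv);
//	// Every tensor registered on mv, or on any multiview up its p chain, now
//	// points at the data of mv.it (minus the recorded offset).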

ccv_nnc_graph_exec_t ccv_nnc_graph_while(ccv_nnc_graph_t* const graph, const uint32_t cmd, ccv_nnc_graph_t* const while_graph)
{
	assert(cmd == CCV_NNC_GRAPH_FORWARD || cmd == CCV_NNC_GRAPH_BACKWARD);
	ccv_nnc_graph_exec_t while_exec = ccv_nnc_graph_exec_new(graph, ccv_nnc_cmd(cmd, 0, CMD_GENERIC(), 0), ccv_nnc_no_hint, 0, 0, 0, 0);
	ccv_nnc_graph_exec_info_t* while_exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(graph->exec_info, while_exec.d);
	while_exec_info->flags |= CCV_NNC_GRAPH_EXEC_P_WHILE;
	if (!graph->sub_graphs)
		graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_graph_t*), 1, 0);
	int i;
	if (while_graph->tensor_wraps_refs)
	{
		// Copy wraps from sub graph to parent graph.
		if (!graph->tensor_wraps_refs)
			graph->tensor_wraps_refs = ccv_array_new(sizeof(ccv_nnc_graph_tensor_wraps_ref_t), while_graph->tensor_wraps_refs->rnum, 0);
		for (i = 0; i < while_graph->tensor_wraps_refs->rnum; i++)
			ccv_array_push(graph->tensor_wraps_refs, ccv_array_get(while_graph->tensor_wraps_refs, i));
	}
	ccv_array_push(graph->sub_graphs, &while_graph);
	while_graph->p = graph;
	while_graph->p_idx = graph->sub_graphs->rnum;
	while_graph->exec_idx = while_exec.d + 1;
	while_exec_info->graph_ref_size = 1;
	CCV_NNC_GRAPH_REF(while_exec_info)[0] = graph->sub_graphs->rnum;
	return while_exec;
}
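
// Construction sketch (hypothetical caller code): attach a sub graph as the
// body of a while loop on the parent graph.
//
//	ccv_nnc_graph_t* const graph = ccv_nnc_graph_new();
//	ccv_nnc_graph_t* const while_graph = ccv_nnc_graph_new();
//	const ccv_nnc_graph_exec_t loop = ccv_nnc_graph_while(graph, CCV_NNC_GRAPH_FORWARD, while_graph);
//	// `loop` is the parent-graph exec that runs while_graph; pair it with
//	// ccv_nnc_graph_set_while_expr below to supply the loop condition.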

ccv_nnc_graph_t* ccv_nnc_graph_from_while_exec(const ccv_nnc_graph_t* const graph, ccv_nnc_graph_exec_t exec)
{
	assert(exec.graph == graph);
	assert(exec.d < graph->exec_info->rnum);
	assert(graph->sub_graphs);
	ccv_nnc_graph_exec_info_t* exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(graph->exec_info, exec.d);
	assert(CCV_NNC_GRAPH_REF(exec_info)[0]);
	const int graph_ref = CCV_NNC_GRAPH_REF(exec_info)[0] - 1;
	assert(graph_ref < graph->sub_graphs->rnum);
	ccv_nnc_graph_t* sub_graph = *(ccv_nnc_graph_t**)ccv_array_get(graph->sub_graphs, graph_ref);
	return sub_graph;
}

ccv_nnc_tensor_t ccv_nnc_tensor_for_while_count(const ccv_nnc_graph_t* const while_graph)
{
	return ccv_nnc_tensor(&while_graph->while_count, CPU_TENSOR_NHWC(64S, 1), 0);
}
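
// The returned struct aliases while_graph->while_count as a 1-element int64
// tensor; no allocation happens. Sketch (assumed usage):
//
//	ccv_nnc_tensor_t count = ccv_nnc_tensor_for_while_count(while_graph);
//	// count.data.i64[0] reads the current iteration of while_graph.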

void ccv_nnc_graph_set_while_expr(ccv_nnc_graph_t* const while_graph, const ccv_nnc_graph_while_f while_expr, const void* const while_data, ccv_nnc_tensor_t* const* const inputs, const int input_size, const ccv_nnc_graph_exec_t* const breakpoints, const int breakpoint_size)
{
	assert(while_graph->p);
	const int exec_idx = while_graph->exec_idx - 1;
	assert(exec_idx >= 0 && exec_idx < while_graph->p->exec_info->rnum);
	ccv_nnc_graph_exec_info_t* const exec_info = (ccv_nnc_graph_exec_info_t*)ccv_array_get(while_graph->p->exec_info, exec_idx);
	assert(!exec_info->p_while.expr);
	exec_info->p_while.expr = while_expr;
	exec_info->p_while.data = while_data;
	if (input_size > 0)
	{
		exec_info->p_while.input_size = input_size;
		exec_info->p_while.inputs = (ccv_nnc_tensor_t**)ccmalloc(sizeof(ccv_nnc_tensor_t*) * input_size);
		memcpy(exec_info->p_while.inputs, inputs, sizeof(ccv_nnc_tensor_t*) * input_size);
		// Register for unwrapping.
		if (ccv_nnc_tensors_have_wraps(exec_info->p_while.inputs, input_size))
		{
			ccv_nnc_graph_tensor_wrap_array_t* const tensor_wrap_array = ccv_nnc_get_tensor_wrap_array(while_graph, input_size, &exec_info->p_while.tensor_wraps_ref);
			ccv_nnc_set_tensor_wraps(tensor_wrap_array->tensor_wraps, exec_info->p_while.inputs, input_size);
			assert(exec_info->p_while.tensor_wraps_ref);
			ccv_nnc_graph_register_tensor_wraps(while_graph, exec_info->p_while.tensor_wraps_ref - 1);
		}
	}
	assert(breakpoint_size > 0);
	while_graph->breakpoint_size = breakpoint_size;
	while_graph->breakpoints = (ccv_nnc_graph_exec_t*)((while_graph->breakpoints) ? ccrealloc(while_graph->breakpoints, sizeof(ccv_nnc_graph_exec_t) * breakpoint_size) : ccmalloc(sizeof(ccv_nnc_graph_exec_t) * breakpoint_size));
	memcpy(while_graph->breakpoints, breakpoints, sizeof(ccv_nnc_graph_exec_t) * breakpoint_size);
}
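
// A minimal condition sketch: run the loop 5 times by reading the while-count
// tensor. Names here are illustrative, not part of this file; the callback
// shape follows ccv_nnc_graph_while_f.
//
//	static int while_5(ccv_nnc_tensor_t* const* const inputs, const int input_size, const void* const data)
//	{
//		return inputs[0]->data.i64[0] < 5;
//	}
//
//	ccv_nnc_tensor_t count = ccv_nnc_tensor_for_while_count(while_graph);
//	ccv_nnc_tensor_t* expr_inputs[] = { &count };
//	ccv_nnc_graph_set_while_expr(while_graph, while_5, 0, expr_inputs, 1, &breakpoint, 1);
//	// `breakpoint` is the exec in while_graph at which the condition is checked.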