/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/ccv_nnc_symbolic_graph_io.c
Line | Count | Source |
1 | | #include "ccv_nnc.h" |
2 | | #include "ccv_nnc_easy.h" |
3 | | #include "ccv_nnc_internal.h" |
4 | | #include "ccv_internal.h" |
5 | | #include "_ccv_nnc_symbolic_graph.h" |
6 | | #include "3rdparty/sqlite3/sqlite3.h" |
7 | | #ifdef HAVE_CUDA |
8 | | #include "gpu/ccv_nnc_compat.h" |
9 | | #endif |
10 | | |
11 | | // MARK - Level-3 API |
12 | | |
// SQLITE_ENFORCE(stmt) evaluates `stmt` exactly once in every build configuration.
// The sqlite3_* call it wraps must still run when NDEBUG disables assert(), so release
// builds only discard the result; debug builds additionally assert that it is true.
#if defined(NDEBUG)
#define SQLITE_ENFORCE(stmt) ((void)(stmt))
#else
#define SQLITE_ENFORCE(stmt) assert(stmt)
#endif
18 | | |
19 | | static int _ccv_nnc_symbolic_graph_index_in_repo(const ccv_nnc_symbolic_graph_t* const graph, const ccv_array_t* const repo) |
20 | 5 | { |
21 | 5 | if (!graph) |
22 | 2 | return -1; |
23 | 3 | int i; |
24 | 3 | for (i = 0; i < repo->rnum; i++0 ) |
25 | 3 | if (*(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, i) == graph) |
26 | 3 | return i; |
27 | 0 | return -1; |
28 | 3 | } |
29 | | |
30 | | static void _ccv_nnc_symbolic_graph_write(const ccv_nnc_symbolic_graph_t* const graph, const ccv_array_t* const repo, const int graph_idx, sqlite3_stmt* const tensor_symbol_insert_stmt, sqlite3_stmt* const exec_symbol_insert_stmt, sqlite3_stmt* const graph_insert_stmt, ccv_array_t* const ws) |
31 | 1 | { |
32 | 1 | int i; |
33 | 6 | for (i = 0; i < graph->tensor_symbol_info->rnum; i++5 ) |
34 | 5 | { |
35 | 5 | const ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i); |
36 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 1, i); |
37 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 2, graph_idx); |
38 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 3, symbol_info->assign_ref); |
39 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 4, symbol_info->r_assign_ref); |
40 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 5, symbol_info->bypass_ref); |
41 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 6, symbol_info->r_bypass_ref); |
42 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 7, symbol_info->p_ref); |
43 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 8, symbol_info->alias_ref); |
44 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 9, symbol_info->pair_ref); |
45 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 10, symbol_info->flags); |
46 | 5 | sqlite3_bind_blob(tensor_symbol_insert_stmt, 11, symbol_info->ofs, sizeof(symbol_info->ofs), 0); |
47 | 5 | sqlite3_bind_blob(tensor_symbol_insert_stmt, 12, symbol_info->stride, sizeof(symbol_info->stride), 0); |
48 | 5 | if (symbol_info->s_ref) |
49 | 0 | sqlite3_bind_blob(tensor_symbol_insert_stmt, 13, ccv_array_get(symbol_info->s_ref, 0), sizeof(int) * symbol_info->s_ref->rnum, 0); |
50 | 5 | else |
51 | 5 | sqlite3_bind_null(tensor_symbol_insert_stmt, 13); |
52 | 5 | if (symbol_info->name) |
53 | 5 | sqlite3_bind_text(tensor_symbol_insert_stmt, 14, symbol_info->name, -1, 0); |
54 | 0 | else |
55 | 0 | sqlite3_bind_null(tensor_symbol_insert_stmt, 14); |
56 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 15, symbol_info->info.type); |
57 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 16, symbol_info->info.format); |
58 | 5 | sqlite3_bind_int(tensor_symbol_insert_stmt, 17, symbol_info->info.datatype); |
59 | 5 | sqlite3_bind_blob(tensor_symbol_insert_stmt, 18, symbol_info->info.dim, sizeof(symbol_info->info.dim), 0); |
60 | 5 | SQLITE_ENFORCE(SQLITE_DONE == sqlite3_step(tensor_symbol_insert_stmt)); |
61 | 5 | sqlite3_reset(tensor_symbol_insert_stmt); |
62 | 5 | sqlite3_clear_bindings(tensor_symbol_insert_stmt); |
63 | 5 | } |
64 | 4 | for (i = 0; 1 i < graph->exec_symbol_info->rnum; i++3 ) |
65 | 3 | { |
66 | 3 | const ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
67 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 1, i); |
68 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 2, graph_idx); |
69 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 3, symbol_info->input_size); |
70 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 4, symbol_info->output_size); |
71 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 5, symbol_info->graph_ref_size); |
72 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 6, symbol_info->flags); |
73 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 7, symbol_info->pair_ref); |
74 | 3 | if (symbol_info->input_size) |
75 | 3 | sqlite3_bind_blob(exec_symbol_insert_stmt, 8, symbol_info->inputs, sizeof(int) * symbol_info->input_size, 0); |
76 | 3 | if (symbol_info->output_size) |
77 | 3 | sqlite3_bind_blob(exec_symbol_insert_stmt, 9, symbol_info->outputs, sizeof(int) * symbol_info->output_size, 0); |
78 | 3 | if (symbol_info->outgoings && symbol_info->outgoings->rnum2 ) |
79 | 2 | sqlite3_bind_blob(exec_symbol_insert_stmt, 10, ccv_array_get(symbol_info->outgoings, 0), sizeof(int) * symbol_info->outgoings->rnum, 0); |
80 | 3 | if (symbol_info->name) |
81 | 3 | sqlite3_bind_text(exec_symbol_insert_stmt, 11, symbol_info->name, -1, 0); |
82 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 12, symbol_info->cmd.cmd); |
83 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 13, symbol_info->cmd.backend); |
84 | 3 | sqlite3_bind_int(exec_symbol_insert_stmt, 14, symbol_info->cmd.algorithm); |
85 | 3 | sqlite3_bind_blob(exec_symbol_insert_stmt, 15, &symbol_info->cmd.info, sizeof(symbol_info->cmd.info), 0); |
86 | 3 | sqlite3_bind_blob(exec_symbol_insert_stmt, 16, &symbol_info->hint, sizeof(symbol_info->hint), 0); |
87 | 3 | if (symbol_info->graph_ref_size) |
88 | 0 | sqlite3_bind_blob(exec_symbol_insert_stmt, 17, CCV_NNC_GRAPH_REF(symbol_info), sizeof(int) * symbol_info->graph_ref_size, 0); |
89 | 3 | if (symbol_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) |
90 | 0 | { |
91 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 18, 0); |
92 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 19, symbol_info->case_of.flags); |
93 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 20, symbol_info->case_of.argument.offset); |
94 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 21, symbol_info->case_of.argument.size); |
95 | 0 | } |
96 | 3 | if (symbol_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) |
97 | 0 | { |
98 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 22, 0); |
99 | 0 | sqlite3_bind_int(exec_symbol_insert_stmt, 23, symbol_info->p_while.input_size); |
100 | 0 | if (symbol_info->p_while.input_size) |
101 | 0 | sqlite3_bind_blob(exec_symbol_insert_stmt, 24, symbol_info->p_while.inputs, sizeof(int) * symbol_info->p_while.input_size, 0); |
102 | 0 | } |
103 | 3 | SQLITE_ENFORCE(SQLITE_DONE == sqlite3_step(exec_symbol_insert_stmt)); |
104 | 3 | sqlite3_reset(exec_symbol_insert_stmt); |
105 | 3 | sqlite3_clear_bindings(exec_symbol_insert_stmt); |
106 | 3 | } |
107 | 1 | ccv_array_clear(ws); |
108 | 1 | sqlite3_bind_int(graph_insert_stmt, 1, graph_idx); |
109 | 1 | sqlite3_bind_int(graph_insert_stmt, 2, graph->tensor_symbol_info->rnum); |
110 | 1 | sqlite3_bind_int(graph_insert_stmt, 3, graph->exec_symbol_info->rnum); |
111 | 1 | if (graph->sources && graph->sources->rnum) |
112 | 3 | for (i = 0; 1 i < graph->sources->rnum; i++2 ) |
113 | 2 | ccv_array_push(ws, &((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->sources, i))->d); |
114 | 1 | if (graph->destinations && graph->destinations->rnum) |
115 | 2 | for (i = 0; 1 i < graph->destinations->rnum; i++1 ) |
116 | 1 | ccv_array_push(ws, &((ccv_nnc_graph_exec_symbol_t*)ccv_array_get(graph->destinations, i))->d); |
117 | 1 | if (graph->sub_graphs && graph->sub_graphs->rnum0 ) |
118 | 0 | for (i = 0; i < graph->sub_graphs->rnum; i++) |
119 | 0 | { |
120 | 0 | const int sub_graph_idx = _ccv_nnc_symbolic_graph_index_in_repo(*(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i), repo); |
121 | 0 | ccv_array_push(ws, &sub_graph_idx); |
122 | 0 | } |
123 | 1 | if (graph->breakpoint_size && graph->breakpoints0 ) |
124 | 0 | for (i = 0; i < graph->breakpoint_size; i++) |
125 | 0 | ccv_array_push(ws, &graph->breakpoints[i].d); |
126 | 1 | const int* pos = (int*)ccv_array_get(ws, 0); |
127 | 1 | if (graph->sources && graph->sources->rnum) |
128 | 1 | { |
129 | 1 | sqlite3_bind_blob(graph_insert_stmt, 4, pos, sizeof(int) * graph->sources->rnum, 0); |
130 | 1 | pos += graph->sources->rnum; |
131 | 1 | } |
132 | 1 | if (graph->destinations && graph->destinations->rnum) |
133 | 1 | { |
134 | 1 | sqlite3_bind_blob(graph_insert_stmt, 5, pos, sizeof(int) * graph->destinations->rnum, 0); |
135 | 1 | pos += graph->destinations->rnum; |
136 | 1 | } |
137 | 1 | if (graph->sub_graphs && graph->sub_graphs->rnum0 ) |
138 | 0 | { |
139 | 0 | sqlite3_bind_blob(graph_insert_stmt, 6, pos, sizeof(int) * graph->sub_graphs->rnum, 0); |
140 | 0 | pos += graph->sub_graphs->rnum; |
141 | 0 | } |
142 | 1 | sqlite3_bind_int(graph_insert_stmt, 7, _ccv_nnc_symbolic_graph_index_in_repo(graph->pair, repo)); |
143 | 1 | sqlite3_bind_int(graph_insert_stmt, 8, _ccv_nnc_symbolic_graph_index_in_repo(graph->p, repo)); |
144 | 1 | sqlite3_bind_int(graph_insert_stmt, 9, graph->p_idx); |
145 | 1 | sqlite3_bind_int(graph_insert_stmt, 10, graph->exec_idx); |
146 | 1 | sqlite3_bind_int(graph_insert_stmt, 11, graph->breakpoint_size); |
147 | 1 | if (graph->breakpoint_size && graph->breakpoints0 ) |
148 | 0 | sqlite3_bind_blob(graph_insert_stmt, 12, pos, sizeof(int) * graph->breakpoint_size, 0); |
149 | 1 | sqlite3_bind_int(graph_insert_stmt, 13, graph->backward.tensor_symbol_size); |
150 | 1 | if (graph->backward.tensor_symbol_size) |
151 | 0 | sqlite3_bind_blob(graph_insert_stmt, 14, graph->backward.tensor_symbol_idx, sizeof(int) * graph->backward.tensor_symbol_size, 0); |
152 | 1 | sqlite3_bind_int(graph_insert_stmt, 15, graph->backward.exec_symbol_size); |
153 | 1 | if (graph->backward.exec_symbol_size) |
154 | 0 | sqlite3_bind_blob(graph_insert_stmt, 16, graph->backward.exec_symbol_idx, sizeof(int) * graph->backward.exec_symbol_size, 0); |
155 | 1 | sqlite3_bind_int(graph_insert_stmt, 17, graph->data_parallel.count); |
156 | 1 | sqlite3_bind_int(graph_insert_stmt, 18, graph->data_parallel.tensor_symbol_size); |
157 | 1 | if (graph->data_parallel.tensor_symbol_idx) |
158 | 0 | sqlite3_bind_blob(graph_insert_stmt, 19, graph->data_parallel.tensor_symbol_idx, sizeof(int) * graph->data_parallel.tensor_symbol_size, 0); |
159 | 1 | sqlite3_bind_int(graph_insert_stmt, 20, graph->data_parallel.exec_symbol_size); |
160 | 1 | if (graph->data_parallel.exec_symbol_idx) |
161 | 0 | sqlite3_bind_blob(graph_insert_stmt, 21, graph->data_parallel.exec_symbol_idx, sizeof(int) * graph->data_parallel.exec_symbol_size, 0); |
162 | 1 | SQLITE_ENFORCE(SQLITE_DONE == sqlite3_step(graph_insert_stmt)); |
163 | 1 | sqlite3_reset(graph_insert_stmt); |
164 | 1 | sqlite3_clear_bindings(graph_insert_stmt); |
165 | 1 | } |
166 | | |
167 | | static void _ccv_nnc_symbolic_graph_push_repo(const ccv_nnc_symbolic_graph_t* const graph, ccv_array_t* const repo) |
168 | 1 | { |
169 | 1 | ccv_array_push(repo, &graph); |
170 | 1 | int i; |
171 | 1 | if (graph->sub_graphs && graph->sub_graphs->rnum0 ) |
172 | 0 | for (i = 0; i < graph->sub_graphs->rnum; i++) |
173 | 0 | { |
174 | 0 | const ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i); |
175 | 0 | if (sub_graph) |
176 | 0 | _ccv_nnc_symbolic_graph_push_repo(sub_graph, repo); |
177 | 0 | } |
178 | 1 | } |
179 | | |
180 | | void ccv_nnc_symbolic_graph_write(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_bind_t* const tensor_binds, const int tensor_bind_size, const char* const fn) |
181 | 1 | { |
182 | 1 | sqlite3* conn = 0; |
183 | 1 | if (SQLITE_OK != sqlite3_open(fn, &conn)) |
184 | 0 | return; |
185 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, "BEGIN", 0, 0, 0)); |
186 | 1 | const char tensor_symbol_create_table_qs[] = "CREATE TABLE IF NOT EXISTS tensor_symbol " |
187 | 1 | "(id INTEGER, graph INTEGER, assign_ref INTEGER, r_assign_ref INTEGER, " |
188 | 1 | "bypass_ref INTEGER, r_bypass_ref INTEGER, p_ref INTEGER, alias_ref INTEGER, pair_ref INTEGER, " |
189 | 1 | "flags INTEGER, ofs BLOB, stride BLOB, s_ref BLOB, name TEXT, type INTEGER, format INTEGER, " |
190 | 1 | "datatype INTEGER, dim BLOB, PRIMARY KEY (id, graph))"; |
191 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, tensor_symbol_create_table_qs, 0, 0, 0)); |
192 | 1 | const char tensor_symbol_insert_qs[] = |
193 | 1 | "REPLACE INTO tensor_symbol " |
194 | 1 | "(id, graph, assign_ref, r_assign_ref, bypass_ref, r_bypass_ref, p_ref, alias_ref, pair_ref, flags, " |
195 | 1 | "ofs, stride, s_ref, name, type, format, datatype, dim) VALUES " |
196 | 1 | "($id, $graph, $assign_ref, $r_assign_ref, $bypass_ref, $r_bypass_ref, $p_ref, $alias_ref, $pair_ref, " |
197 | 1 | "$flags, $ofs, $stride, $s_ref, $name, $type, $format, $datatype, $dim)"; |
198 | 1 | sqlite3_stmt* tensor_symbol_insert_stmt = 0; |
199 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, tensor_symbol_insert_qs, sizeof(tensor_symbol_insert_qs), &tensor_symbol_insert_stmt, 0)); |
200 | | |
201 | 1 | const char exec_symbol_create_table_qs[] = "CREATE TABLE IF NOT EXISTS graph_exec_symbol " |
202 | 1 | "(id INTEGER, graph INTEGER, input_size INTEGER, output_size INTEGER, graph_ref_size INTEGER, " |
203 | 1 | "flags INTEGER, pair_ref INTEGER, inputs BLOB, outputs BLOB, outgoings BLOB, name TEXT, " |
204 | 1 | "cmd_cmd INTEGER, cmd_backend INTEGER, cmd_algorithm INTEGER, cmd_info BLOB, hint BLOB, graph_ref BLOB, " |
205 | 1 | "case_of_expr INTEGER, case_of_flags INTEGER, case_of_argument_offset INTEGER, case_of_argument_size INTEGER, " |
206 | 1 | "p_while_expr INTEGER, p_while_input_size INTEGER, p_while_inputs BLOB, PRIMARY KEY (id, graph))"; |
207 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, exec_symbol_create_table_qs, 0, 0, 0)); |
208 | 1 | const char exec_symbol_insert_qs[] = |
209 | 1 | "REPLACE INTO graph_exec_symbol " |
210 | 1 | "(id, graph, input_size, output_size, graph_ref_size, flags, pair_ref, inputs, outputs, outgoings, " |
211 | 1 | "name, cmd_cmd, cmd_backend, cmd_algorithm, cmd_info, hint, graph_ref, case_of_expr, case_of_flags, " |
212 | 1 | "case_of_argument_offset, case_of_argument_size, p_while_expr, p_while_input_size, p_while_inputs) " |
213 | 1 | "VALUES ($id, $graph, $input_size, $output_size, $graph_ref_size, $flags, $pair_ref, $inputs, $outputs, " |
214 | 1 | "$outgoings, $name, $cmd_cmd, $cmd_backend, $cmd_algorithm, $cmd_info, $hint, $graph_ref, $case_of_expr, " |
215 | 1 | "$case_of_flags, $case_of_argument_offset, $case_of_argument_size, $p_while_expr, $p_while_input_size, " |
216 | 1 | "$p_while_inputs)"; |
217 | 1 | sqlite3_stmt* exec_symbol_insert_stmt = 0; |
218 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, exec_symbol_insert_qs, sizeof(exec_symbol_insert_qs), &exec_symbol_insert_stmt, 0)); |
219 | | |
220 | 1 | const char graph_create_table_qs[] = "CREATE TABLE IF NOT EXISTS graph " |
221 | 1 | "(graph INTEGER PRIMARY KEY, tensor_symbol_size INTEGER, exec_symbol_size INTEGER, sources BLOB, " |
222 | 1 | "destinations BLOB, sub_graphs BLOB, pair INTEGER, p INTEGER, p_idx INTEGER, exec_idx INTEGER, " |
223 | 1 | "breakpoint_size INTEGER, breakpoints BLOB, backward_tensor_symbol_size INTEGER, " |
224 | 1 | "backward_tensor_symbol_idx BLOB, backward_exec_symbol_size INTEGER, backward_exec_symbol_idx BLOB, " |
225 | 1 | "parallel_count INTEGER, parallel_tensor_symbol_size INTEGER, parallel_tensor_symbol_idx BLOB, " |
226 | 1 | "parallel_exec_symbol_size INTEGER, parallel_exec_symbol_idx BLOB)"; |
227 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, graph_create_table_qs, 0, 0, 0)); |
228 | 1 | const char graph_insert_qs[] = |
229 | 1 | "REPLACE INTO graph " |
230 | 1 | "(graph, tensor_symbol_size, exec_symbol_size, sources, destinations, sub_graphs, pair, p, p_idx, " |
231 | 1 | "exec_idx, breakpoint_size, breakpoints, backward_tensor_symbol_size, " |
232 | 1 | "backward_tensor_symbol_idx, backward_exec_symbol_size, backward_exec_symbol_idx, " |
233 | 1 | "parallel_count, parallel_tensor_symbol_size, parallel_tensor_symbol_idx, " |
234 | 1 | "parallel_exec_symbol_size, parallel_exec_symbol_idx) VALUES " |
235 | 1 | "($graph, $tensor_symbol_size, $exec_symbol_size, $sources, $destinations, $sub_graphs, $pair, $p, $p_idx, " |
236 | 1 | "$exec_idx, $breakpoint_size, $breakpoints, $backward_tensor_symbol_size, " |
237 | 1 | "$backward_tensor_symbol_idx, $backward_exec_symbol_size, $backward_exec_symbol_idx, " |
238 | 1 | "$parallel_count, $parallel_tensor_symbol_size, $parallel_tensor_symbol_idx, " |
239 | 1 | "$parallel_exec_symbol_size, $parallel_exec_symbol_idx)"; |
240 | 1 | sqlite3_stmt* graph_insert_stmt = 0; |
241 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, graph_insert_qs, sizeof(graph_insert_qs), &graph_insert_stmt, 0)); |
242 | 1 | ccv_array_t* const repo = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0); |
243 | 1 | _ccv_nnc_symbolic_graph_push_repo(graph, repo); |
244 | 1 | ccv_array_t* const ws = ccv_array_new(sizeof(int), 1, 0); |
245 | 1 | int i; |
246 | 2 | for (i = 0; i < repo->rnum; i++1 ) |
247 | 1 | _ccv_nnc_symbolic_graph_write(*(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, i), repo, i, |
248 | 1 | tensor_symbol_insert_stmt, exec_symbol_insert_stmt, graph_insert_stmt, ws); |
249 | 1 | ccv_array_free(ws); |
250 | 1 | sqlite3_finalize(tensor_symbol_insert_stmt); |
251 | 1 | sqlite3_finalize(exec_symbol_insert_stmt); |
252 | 1 | sqlite3_finalize(graph_insert_stmt); |
253 | | // Write tensor binds. |
254 | 1 | const char tensor_bind_create_table_qs[] = "CREATE TABLE IF NOT EXISTS tensor_bind " |
255 | 1 | "(id INTEGER, graph INTEGER, type INTEGER, format INTEGER, datatype INTEGER, " |
256 | 1 | "dim BLOB, data BLOB, PRIMARY KEY (id, graph))"; |
257 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, tensor_bind_create_table_qs, 0, 0, 0)); |
258 | | // Remove everything in that table. |
259 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, "DELETE FROM tensor_bind", 0, 0, 0)); |
260 | 1 | const char tensor_bind_insert_qs[] = |
261 | 1 | "REPLACE INTO tensor_bind " |
262 | 1 | "(id, graph, type, format, datatype, dim, data) VALUES (" |
263 | 1 | "$id, $graph, $type, $format, $datatype, $dim, $data)"; |
264 | 1 | sqlite3_stmt* tensor_bind_insert_stmt = 0; |
265 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, tensor_bind_insert_qs, sizeof(tensor_bind_insert_qs), &tensor_bind_insert_stmt, 0)); |
266 | 1 | #ifdef HAVE_CUDA |
267 | 1 | size_t workspace_size = 0; |
268 | 1 | void* workspace = 0; |
269 | 1 | #endif |
270 | 4 | for (i = 0; i < tensor_bind_size; i++3 ) |
271 | 3 | { |
272 | 3 | const int graph_idx = _ccv_nnc_symbolic_graph_index_in_repo(tensor_binds[i].symbol.graph, repo); |
273 | 3 | if (graph_idx < 0) |
274 | 0 | continue; |
275 | 3 | sqlite3_bind_int(tensor_bind_insert_stmt, 1, tensor_binds[i].symbol.d); |
276 | 3 | sqlite3_bind_int(tensor_bind_insert_stmt, 2, graph_idx); |
277 | 3 | if (tensor_binds[i].tensor) |
278 | 0 | { |
279 | 0 | const ccv_nnc_tensor_t* const tensor = tensor_binds[i].tensor; |
280 | 0 | assert(!CCV_IS_TENSOR_VIEW(tensor)); |
281 | 0 | sqlite3_bind_int(tensor_bind_insert_stmt, 3, tensor->info.type); |
282 | 0 | sqlite3_bind_int(tensor_bind_insert_stmt, 4, tensor->info.format); |
283 | 0 | sqlite3_bind_int(tensor_bind_insert_stmt, 5, tensor->info.datatype); |
284 | 0 | sqlite3_bind_blob(tensor_bind_insert_stmt, 6, tensor->info.dim, sizeof(tensor->info.dim), 0); |
285 | 0 | const size_t data_size = ccv_nnc_tensor_data_size(tensor->info); |
286 | 0 | #ifdef HAVE_CUDA |
287 | 0 | if (CCV_TENSOR_GET_MEMORY(tensor->info.type) == CCV_TENSOR_GPU_MEMORY) |
288 | 0 | { |
289 | 0 | if (!workspace) |
290 | 0 | { |
291 | 0 | workspace = ccmalloc(data_size); |
292 | 0 | workspace_size = data_size; |
293 | 0 | } else if (data_size > workspace_size) { |
294 | 0 | workspace = ccrealloc(workspace, data_size); |
295 | 0 | workspace_size = data_size; |
296 | 0 | } |
297 | 0 | cumemcpy(workspace, CCV_TENSOR_CPU_MEMORY, tensor->data.u8, tensor->info.type, data_size); |
298 | 0 | sqlite3_bind_blob(tensor_bind_insert_stmt, 7, workspace, data_size, 0); |
299 | 0 | } else |
300 | 0 | sqlite3_bind_blob(tensor_bind_insert_stmt, 7, tensor->data.u8, data_size, 0); |
301 | | #else |
302 | | sqlite3_bind_blob(tensor_bind_insert_stmt, 7, tensor->data.u8, data_size, 0); |
303 | | #endif |
304 | 3 | } else { |
305 | 3 | assert(tensor_binds[i].symbol.d >= 0); |
306 | 3 | const ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, tensor_binds[i].symbol.d); |
307 | 3 | sqlite3_bind_int(tensor_bind_insert_stmt, 3, symbol_info->info.type); |
308 | 3 | sqlite3_bind_int(tensor_bind_insert_stmt, 4, symbol_info->info.format); |
309 | 3 | sqlite3_bind_int(tensor_bind_insert_stmt, 5, symbol_info->info.datatype); |
310 | 3 | sqlite3_bind_blob(tensor_bind_insert_stmt, 6, symbol_info->info.dim, sizeof(symbol_info->info.dim), 0); |
311 | 3 | } |
312 | 3 | sqlite3_step(tensor_bind_insert_stmt); |
313 | 3 | sqlite3_reset(tensor_bind_insert_stmt); |
314 | 3 | sqlite3_clear_bindings(tensor_bind_insert_stmt); |
315 | 3 | } |
316 | 1 | sqlite3_finalize(tensor_bind_insert_stmt); |
317 | 1 | #ifdef HAVE_CUDA |
318 | 1 | if (workspace) |
319 | 0 | ccfree(workspace); |
320 | 1 | #endif |
321 | 1 | ccv_array_free(repo); |
322 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_exec(conn, "COMMIT", 0, 0, 0)); |
323 | 1 | sqlite3_close(conn); |
324 | 1 | } |
325 | | |
326 | | static ccv_nnc_symbolic_graph_t* _ccv_nnc_symbolic_graph_get(const ccv_array_t* const repo, const ccv_nnc_symbolic_graph_t* const pos) |
327 | 0 | { |
328 | 0 | const int idx = (uintptr_t)pos >> 1; |
329 | 0 | assert(idx < repo->rnum); |
330 | 0 | return *(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, idx); |
331 | 0 | } |
332 | | |
// Non-zero when `ptr` carries the low-bit tag that _ccv_nnc_symbolic_graph_pos sets on encoded
// graph indices. This presumably relies on real graph pointers having the low bit clear.
#define CCV_NNC_IS_SYMBOLIC_GRAPH_POS(ptr) (((uintptr_t)(ptr)) & (uintptr_t)1)
334 | | |
335 | | static ccv_nnc_symbolic_graph_t* _ccv_nnc_symbolic_graph_pos(const int idx) |
336 | 2 | { |
337 | 2 | if (idx < 0) |
338 | 2 | return 0; // This is nil. |
339 | 0 | return (ccv_nnc_symbolic_graph_t*)(((uintptr_t)idx << 1) + 1); |
340 | 2 | } |
341 | | |
342 | | static void _ccv_nnc_symbolic_graph_read(const int graph_idx, sqlite3_stmt* const graph_select_stmt, sqlite3_stmt* const tensor_symbol_select_stmt, sqlite3_stmt* const exec_symbol_select_stmt, ccv_nnc_symbolic_graph_t* const graph) |
343 | 1 | { |
344 | 1 | int i, j; |
345 | 1 | ccv_array_resize(graph->tensor_symbol_info, sqlite3_column_int(graph_select_stmt, 1)); |
346 | 1 | ccv_array_resize(graph->exec_symbol_info, sqlite3_column_int(graph_select_stmt, 2)); |
347 | 1 | if (sqlite3_column_blob(graph_select_stmt, 3)) |
348 | 1 | { |
349 | 1 | const int* const sources = sqlite3_column_blob(graph_select_stmt, 3); |
350 | 1 | const int count = sqlite3_column_bytes(graph_select_stmt, 3) / sizeof(int); |
351 | 1 | graph->sources = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), count, 0); |
352 | 3 | for (i = 0; i < count; i++2 ) |
353 | 2 | { |
354 | 2 | const ccv_nnc_graph_exec_symbol_t symbol = { |
355 | 2 | .graph = graph, |
356 | 2 | .d = sources[i] |
357 | 2 | }; |
358 | 2 | ccv_array_push(graph->sources, &symbol); |
359 | 2 | } |
360 | 1 | } |
361 | 1 | if (sqlite3_column_blob(graph_select_stmt, 4)) |
362 | 1 | { |
363 | 1 | const int* const destinations = sqlite3_column_blob(graph_select_stmt, 4); |
364 | 1 | const int count = sqlite3_column_bytes(graph_select_stmt, 4) / sizeof(int); |
365 | 1 | graph->destinations = ccv_array_new(sizeof(ccv_nnc_graph_exec_symbol_t), count, 0); |
366 | 2 | for (i = 0; i < count; i++1 ) |
367 | 1 | { |
368 | 1 | const ccv_nnc_graph_exec_symbol_t symbol = { |
369 | 1 | .graph = graph, |
370 | 1 | .d = destinations[i] |
371 | 1 | }; |
372 | 1 | ccv_array_push(graph->destinations, &symbol); |
373 | 1 | } |
374 | 1 | } |
375 | 1 | if (sqlite3_column_blob(graph_select_stmt, 5)) |
376 | 0 | { |
377 | 0 | const int* const sub_graphs = sqlite3_column_blob(graph_select_stmt, 5); |
378 | 0 | const int count = sqlite3_column_bytes(graph_select_stmt, 5) / sizeof(int); |
379 | 0 | graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), count, 0); |
380 | 0 | for (i = 0; i < count; i++) |
381 | 0 | { |
382 | 0 | const ccv_nnc_symbolic_graph_t* const sub_graph = _ccv_nnc_symbolic_graph_pos(sub_graphs[i]); |
383 | 0 | ccv_array_push(graph->sub_graphs, &sub_graph); |
384 | 0 | } |
385 | 0 | } |
386 | 1 | graph->pair = _ccv_nnc_symbolic_graph_pos(sqlite3_column_int(graph_select_stmt, 6)); |
387 | 1 | graph->p = _ccv_nnc_symbolic_graph_pos(sqlite3_column_int(graph_select_stmt, 7)); |
388 | 1 | graph->p_idx = sqlite3_column_int(graph_select_stmt, 8); |
389 | 1 | graph->exec_idx = sqlite3_column_int(graph_select_stmt, 9); |
390 | 1 | graph->breakpoint_size = sqlite3_column_int(graph_select_stmt, 10); |
391 | 1 | if (graph->breakpoint_size) |
392 | 0 | { |
393 | 0 | graph->breakpoints = (ccv_nnc_graph_exec_symbol_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_t) * graph->breakpoint_size); |
394 | 0 | assert(sizeof(int) * graph->breakpoint_size == sqlite3_column_bytes(graph_select_stmt, 11)); |
395 | 0 | const int* const breakpoints = sqlite3_column_blob(graph_select_stmt, 11); |
396 | 0 | for (i = 0; i < graph->breakpoint_size; i++) |
397 | 0 | graph->breakpoints[i] = (ccv_nnc_graph_exec_symbol_t){ |
398 | 0 | .d = breakpoints[i], |
399 | 0 | .graph = graph |
400 | 0 | }; |
401 | 0 | } |
402 | 1 | graph->backward.tensor_symbol_size = sqlite3_column_int(graph_select_stmt, 12); |
403 | 1 | graph->backward.exec_symbol_size = sqlite3_column_int(graph_select_stmt, 14); |
404 | 1 | if (graph->backward.tensor_symbol_size || graph->backward.exec_symbol_size) |
405 | 0 | graph->backward.tensor_symbol_idx = (int*)ccmalloc(sizeof(int) * (graph->backward.tensor_symbol_size + graph->backward.exec_symbol_size)); |
406 | 1 | if (graph->backward.tensor_symbol_size) |
407 | 0 | { |
408 | 0 | assert(sizeof(int) * graph->backward.tensor_symbol_size == sqlite3_column_bytes(graph_select_stmt, 13)); |
409 | 0 | const int* const backward_tensor_symbol_idx = sqlite3_column_blob(graph_select_stmt, 13); |
410 | 0 | memcpy(graph->backward.tensor_symbol_idx, backward_tensor_symbol_idx, sizeof(int) * graph->backward.tensor_symbol_size); |
411 | 0 | } |
412 | 1 | if (graph->backward.exec_symbol_size) |
413 | 0 | { |
414 | 0 | graph->backward.exec_symbol_idx = graph->backward.tensor_symbol_idx + graph->backward.tensor_symbol_size; |
415 | 0 | assert(sizeof(int) * graph->backward.exec_symbol_size == sqlite3_column_bytes(graph_select_stmt, 15)); |
416 | 0 | const int* const backward_exec_symbol_idx = sqlite3_column_blob(graph_select_stmt, 15); |
417 | 0 | memcpy(graph->backward.exec_symbol_idx, backward_exec_symbol_idx, sizeof(int) * graph->backward.exec_symbol_size); |
418 | 0 | } |
419 | 1 | graph->data_parallel.count = sqlite3_column_int(graph_select_stmt, 16); |
420 | 1 | graph->data_parallel.tensor_symbol_size = sqlite3_column_int(graph_select_stmt, 17); |
421 | 1 | if (graph->data_parallel.tensor_symbol_size) |
422 | 0 | { |
423 | 0 | graph->data_parallel.tensor_symbol_idx = (int*)ccmalloc(sizeof(int) * graph->data_parallel.tensor_symbol_size); |
424 | 0 | assert(sizeof(int) * graph->data_parallel.tensor_symbol_size == sqlite3_column_bytes(graph_select_stmt, 18)); |
425 | 0 | const int* const parallel_tensor_symbol_idx = sqlite3_column_blob(graph_select_stmt, 18); |
426 | 0 | memcpy(graph->data_parallel.tensor_symbol_idx, parallel_tensor_symbol_idx, sizeof(int) * graph->data_parallel.tensor_symbol_size); |
427 | 0 | } |
428 | 1 | graph->data_parallel.exec_symbol_size = sqlite3_column_int(graph_select_stmt, 19); |
429 | 1 | if (graph->data_parallel.exec_symbol_size) |
430 | 0 | { |
431 | 0 | graph->data_parallel.exec_symbol_idx = (int*)ccmalloc(sizeof(int) * graph->data_parallel.exec_symbol_size); |
432 | 0 | assert(sizeof(int) * graph->data_parallel.exec_symbol_size == sqlite3_column_bytes(graph_select_stmt, 20)); |
433 | 0 | const int* const parallel_exec_symbol_idx = sqlite3_column_blob(graph_select_stmt, 20); |
434 | 0 | memcpy(graph->data_parallel.exec_symbol_idx, parallel_exec_symbol_idx, sizeof(int) * graph->data_parallel.exec_symbol_size); |
435 | 0 | } |
436 | 1 | sqlite3_bind_int(tensor_symbol_select_stmt, 1, graph_idx); |
437 | 6 | for (i = 0; SQLITE_ROW == sqlite3_step(tensor_symbol_select_stmt); i++5 ) |
438 | 5 | { |
439 | 5 | assert(sqlite3_column_int(tensor_symbol_select_stmt, 0) == i); // id should match. |
440 | 5 | assert(i < graph->tensor_symbol_info->rnum); |
441 | 5 | assert(sqlite3_column_int(tensor_symbol_select_stmt, 0) == i); |
442 | 5 | ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, i); |
443 | 5 | symbol_info->assign_ref = sqlite3_column_int(tensor_symbol_select_stmt, 1); |
444 | 5 | symbol_info->r_assign_ref = sqlite3_column_int(tensor_symbol_select_stmt, 2); |
445 | 5 | symbol_info->bypass_ref = sqlite3_column_int(tensor_symbol_select_stmt, 3); |
446 | 5 | symbol_info->r_bypass_ref = sqlite3_column_int(tensor_symbol_select_stmt, 4); |
447 | 5 | symbol_info->p_ref = sqlite3_column_int(tensor_symbol_select_stmt, 5); |
448 | 5 | symbol_info->alias_ref = sqlite3_column_int(tensor_symbol_select_stmt, 6); |
449 | 5 | symbol_info->pair_ref = sqlite3_column_int(tensor_symbol_select_stmt, 7); |
450 | 5 | symbol_info->flags = sqlite3_column_int(tensor_symbol_select_stmt, 8); |
451 | 5 | memset(symbol_info->ofs, 0, sizeof(symbol_info->ofs)); |
452 | 5 | const int* const ofs = sqlite3_column_blob(tensor_symbol_select_stmt, 9); |
453 | 5 | if (ofs) |
454 | 5 | memcpy(symbol_info->ofs, ofs, ccv_min(sqlite3_column_bytes(tensor_symbol_select_stmt, 8), sizeof(symbol_info->ofs))); |
455 | 5 | memset(symbol_info->stride, 0, sizeof(symbol_info->stride)); |
456 | 5 | const int* const stride = sqlite3_column_blob(tensor_symbol_select_stmt, 10); |
457 | 5 | if (stride) |
458 | 5 | memcpy(symbol_info->stride, stride, ccv_min(sqlite3_column_bytes(tensor_symbol_select_stmt, 9), sizeof(symbol_info->stride))); |
459 | 5 | const int* const s_ref = sqlite3_column_blob(tensor_symbol_select_stmt, 11); |
460 | 5 | if (s_ref) |
461 | 0 | { |
462 | 0 | const int count = sqlite3_column_bytes(tensor_symbol_select_stmt, 11) / sizeof(int); |
463 | 0 | symbol_info->s_ref = ccv_array_new(sizeof(int), count, 0); |
464 | 0 | ccv_array_resize(symbol_info->s_ref, count); |
465 | 0 | memcpy(ccv_array_get(symbol_info->s_ref, 0), s_ref, sizeof(int) * count); |
466 | 0 | } else |
467 | 5 | symbol_info->s_ref = 0; |
468 | 5 | const char* const name = (char*)sqlite3_column_text(tensor_symbol_select_stmt, 12); |
469 | 5 | if (name) |
470 | 5 | { |
471 | 5 | const int count = sqlite3_column_bytes(tensor_symbol_select_stmt, 12); |
472 | 5 | symbol_info->name = (char*)ccmalloc(sizeof(char) * (count + 1)); |
473 | 5 | memcpy(symbol_info->name, name, count); |
474 | 5 | symbol_info->name[count] = 0; // null terminator |
475 | 5 | } else |
476 | 0 | symbol_info->name = 0; |
477 | 5 | symbol_info->info.type = sqlite3_column_int(tensor_symbol_select_stmt, 13); |
478 | 5 | symbol_info->info.format = sqlite3_column_int(tensor_symbol_select_stmt, 14); |
479 | 5 | symbol_info->info.datatype = sqlite3_column_int(tensor_symbol_select_stmt, 15); |
480 | 5 | memset(symbol_info->info.dim, 0, sizeof(symbol_info->info.dim)); |
481 | 5 | const int* const dim = sqlite3_column_blob(tensor_symbol_select_stmt, 16); |
482 | 5 | if (dim) |
483 | 5 | memcpy(symbol_info->info.dim, dim, ccv_min(sqlite3_column_bytes(tensor_symbol_select_stmt, 16), sizeof(symbol_info->info.dim))); |
484 | 5 | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(symbol_info->flags) && graph->reuse.tensor < 00 ) |
485 | 0 | graph->reuse.tensor = i; |
486 | 5 | } |
487 | 1 | sqlite3_reset(tensor_symbol_select_stmt); |
488 | 1 | sqlite3_clear_bindings(tensor_symbol_select_stmt); |
489 | 1 | sqlite3_bind_int(exec_symbol_select_stmt, 1, graph_idx); |
490 | 4 | for (i = 0; SQLITE_ROW == sqlite3_step(exec_symbol_select_stmt); i++3 ) |
491 | 3 | { |
492 | 3 | assert(sqlite3_column_int(exec_symbol_select_stmt, 0) == i); // id should match. |
493 | 3 | assert(i < graph->exec_symbol_info->rnum); |
494 | 3 | assert(sqlite3_column_int(exec_symbol_select_stmt, 0) == i); |
495 | 3 | ccv_nnc_graph_exec_symbol_info_t* const symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, i); |
496 | 3 | memset(symbol_info, 0, sizeof(ccv_nnc_graph_exec_symbol_info_t)); |
497 | 3 | symbol_info->input_size = sqlite3_column_int(exec_symbol_select_stmt, 1); |
498 | 3 | symbol_info->output_size = sqlite3_column_int(exec_symbol_select_stmt, 2); |
499 | 3 | symbol_info->graph_ref_size = sqlite3_column_int(exec_symbol_select_stmt, 3); |
500 | 3 | symbol_info->flags = sqlite3_column_int(exec_symbol_select_stmt, 4); |
501 | 3 | symbol_info->pair_ref = sqlite3_column_int(exec_symbol_select_stmt, 5); |
502 | 3 | if (symbol_info->input_size > 0 || symbol_info->output_size > 00 ) |
503 | 3 | { |
504 | 3 | symbol_info->inputs = (int*)ccmalloc(sizeof(int) * (symbol_info->input_size + symbol_info->output_size)); |
505 | 9 | for (j = 0; j < symbol_info->input_size; j++6 ) |
506 | 6 | symbol_info->inputs[j] = CCV_NNC_NO_TENSOR_SYMBOL; |
507 | 3 | symbol_info->outputs = symbol_info->inputs + symbol_info->input_size; |
508 | 6 | for (j = 0; j < symbol_info->output_size; j++3 ) |
509 | 3 | symbol_info->outputs[j] = CCV_NNC_NO_TENSOR_SYMBOL; |
510 | 3 | } |
511 | 3 | if (symbol_info->input_size) |
512 | 3 | { |
513 | 3 | const int* const inputs = sqlite3_column_blob(exec_symbol_select_stmt, 6); |
514 | 3 | if (inputs) |
515 | 3 | memcpy(symbol_info->inputs, inputs, ccv_min(sizeof(int) * symbol_info->input_size, sqlite3_column_bytes(exec_symbol_select_stmt, 6))); |
516 | 3 | } |
517 | 3 | if (symbol_info->output_size) |
518 | 3 | { |
519 | 3 | const int* const outputs = sqlite3_column_blob(exec_symbol_select_stmt, 7); |
520 | 3 | if (outputs) |
521 | 3 | memcpy(symbol_info->outputs, outputs, ccv_min(sizeof(int) * symbol_info->output_size, sqlite3_column_bytes(exec_symbol_select_stmt, 7))); |
522 | 3 | } |
523 | 3 | const int* const outgoings = sqlite3_column_blob(exec_symbol_select_stmt, 8); |
524 | 3 | if (outgoings) |
525 | 2 | { |
526 | 2 | const int count = sqlite3_column_bytes(exec_symbol_select_stmt, 8) / sizeof(int); |
527 | 2 | symbol_info->outgoings = ccv_array_new(sizeof(int), count, 0); |
528 | 2 | ccv_array_resize(symbol_info->outgoings, count); |
529 | 2 | memcpy(ccv_array_get(symbol_info->outgoings, 0), outgoings, sizeof(int) * count); |
530 | 2 | } |
531 | 3 | const char* const name = (char*)sqlite3_column_text(exec_symbol_select_stmt, 9); |
532 | 3 | if (name) |
533 | 3 | { |
534 | 3 | const int count = sqlite3_column_bytes(exec_symbol_select_stmt, 9); |
535 | 3 | symbol_info->name = (char*)ccmalloc(sizeof(char) * (count + 1)); |
536 | 3 | memcpy(symbol_info->name, name, count); |
537 | 3 | symbol_info->name[count] = 0; // null terminator |
538 | 3 | } |
539 | 3 | symbol_info->cmd.cmd = sqlite3_column_int(exec_symbol_select_stmt, 10); |
540 | 3 | symbol_info->cmd.backend = sqlite3_column_int(exec_symbol_select_stmt, 11); |
541 | 3 | symbol_info->cmd.algorithm = sqlite3_column_int(exec_symbol_select_stmt, 12); |
542 | 3 | const void* const cmd_info = sqlite3_column_blob(exec_symbol_select_stmt, 13); |
543 | 3 | if (cmd_info) |
544 | 3 | memcpy(&symbol_info->cmd.info, cmd_info, ccv_min(sizeof(symbol_info->cmd.info), sqlite3_column_bytes(exec_symbol_select_stmt, 13))); |
545 | 3 | const void* const hint = sqlite3_column_blob(exec_symbol_select_stmt, 14); |
546 | 3 | if (hint) |
547 | 3 | memcpy(&symbol_info->hint, hint, ccv_min(sizeof(symbol_info->hint), sqlite3_column_bytes(exec_symbol_select_stmt, 14))); |
548 | 3 | if (symbol_info->graph_ref_size) |
549 | 0 | { |
550 | 0 | const int* const graph_ref = sqlite3_column_blob(exec_symbol_select_stmt, 15); |
551 | 0 | if (symbol_info->graph_ref_size > sizeof(symbol_info->_inline_graph_ref) / sizeof(symbol_info->_inline_graph_ref[0])) |
552 | 0 | symbol_info->_heap_graph_ref = (int*)cccalloc(symbol_info->graph_ref_size, sizeof(int)); |
553 | 0 | if (graph_ref) |
554 | 0 | memcpy(CCV_NNC_GRAPH_REF(symbol_info), graph_ref, ccv_min(sizeof(int) * symbol_info->graph_ref_size, sqlite3_column_bytes(exec_symbol_select_stmt, 15))); |
555 | 0 | } |
556 | 3 | if (symbol_info->flags & CCV_NNC_GRAPH_EXEC_CASE_OF) |
557 | 0 | { |
558 | 0 | symbol_info->case_of.flags = sqlite3_column_int(exec_symbol_select_stmt, 17); |
559 | 0 | symbol_info->case_of.argument.offset = sqlite3_column_int(exec_symbol_select_stmt, 18); |
560 | 0 | symbol_info->case_of.argument.size = sqlite3_column_int(exec_symbol_select_stmt, 19); |
561 | 3 | } else if (symbol_info->flags & CCV_NNC_GRAPH_EXEC_P_WHILE) { |
562 | 0 | symbol_info->p_while.input_size = sqlite3_column_int(exec_symbol_select_stmt, 21); |
563 | 0 | if (symbol_info->p_while.input_size) |
564 | 0 | { |
565 | 0 | symbol_info->p_while.inputs = (int*)cccalloc(symbol_info->p_while.input_size, sizeof(int)); |
566 | 0 | const int* const inputs = sqlite3_column_blob(exec_symbol_select_stmt, 22); |
567 | 0 | if (inputs) |
568 | 0 | memcpy(symbol_info->p_while.inputs, inputs, ccv_min(sizeof(int) * symbol_info->p_while.input_size, sqlite3_column_bytes(exec_symbol_select_stmt, 22))); |
569 | 0 | } |
570 | 0 | } |
571 | 3 | if (CCV_NNC_GRAPH_EXEC_IS_DEAD(symbol_info->flags) && graph->reuse.exec < 00 ) |
572 | 0 | graph->reuse.exec = i; |
573 | 3 | } |
574 | 1 | sqlite3_reset(exec_symbol_select_stmt); |
575 | 1 | sqlite3_clear_bindings(exec_symbol_select_stmt); |
576 | 1 | } |
577 | | |
578 | | static void _ccv_nnc_symbolic_graph_rewire(const ccv_array_t* const repo, ccv_nnc_symbolic_graph_t* const graph) |
579 | 1 | { |
580 | 1 | if (graph->p && CCV_NNC_IS_SYMBOLIC_GRAPH_POS0 (graph->p)) |
581 | 0 | graph->p = _ccv_nnc_symbolic_graph_get(repo, graph->p); |
582 | 1 | if (graph->pair && CCV_NNC_IS_SYMBOLIC_GRAPH_POS0 (graph->pair)) |
583 | 0 | graph->pair = _ccv_nnc_symbolic_graph_get(repo, graph->pair); |
584 | 1 | int i; |
585 | 1 | if (graph->sub_graphs) |
586 | 0 | for (i = 0; i < graph->sub_graphs->rnum; i++) |
587 | 0 | { |
588 | 0 | ccv_nnc_symbolic_graph_t* const sub_graph = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i); |
589 | 0 | if (sub_graph && CCV_NNC_IS_SYMBOLIC_GRAPH_POS(sub_graph)) |
590 | 0 | *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, i) = _ccv_nnc_symbolic_graph_get(repo, sub_graph); |
591 | 0 | } |
592 | 1 | } |
593 | | |
594 | | void ccv_nnc_symbolic_graph_read(const char* const fn, ccv_nnc_symbolic_graph_t** const graph_ref, ccv_nnc_tensor_bind_t** const tensor_binds_ref, int* const tensor_bind_size_ref) |
595 | 1 | { |
596 | 1 | sqlite3* conn = 0; |
597 | 1 | if (SQLITE_OK != sqlite3_open(fn, &conn)) |
598 | 0 | return; |
599 | 1 | ccv_array_t* const repo = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0); |
600 | 1 | const char graph_select_qs[] = |
601 | 1 | "SELECT graph, tensor_symbol_size, exec_symbol_size, sources, destinations, sub_graphs, pair, p, p_idx, " |
602 | 1 | "exec_idx, breakpoint_size, breakpoints, backward_tensor_symbol_size, " |
603 | 1 | "backward_tensor_symbol_idx, backward_exec_symbol_size, backward_exec_symbol_idx, " |
604 | 1 | "parallel_count, parallel_tensor_symbol_size, parallel_tensor_symbol_idx, " |
605 | 1 | "parallel_exec_symbol_size, parallel_exec_symbol_idx FROM graph ORDER BY graph"; |
606 | 1 | sqlite3_stmt* graph_select_stmt = 0; |
607 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, graph_select_qs, sizeof(graph_select_qs), &graph_select_stmt, 0)); |
608 | 1 | sqlite3_stmt* tensor_symbol_select_stmt = 0; |
609 | 1 | const char tensor_symbol_select_qs[] = |
610 | 1 | "SELECT id, assign_ref, r_assign_ref, bypass_ref, r_bypass_ref, p_ref, alias_ref, pair_ref, flags, ofs, stride, " |
611 | 1 | "s_ref, name, type, format, datatype, dim FROM tensor_symbol WHERE graph=$graph ORDER BY id"; |
612 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, tensor_symbol_select_qs, sizeof(tensor_symbol_select_qs), &tensor_symbol_select_stmt, 0)); |
613 | 1 | const char exec_symbol_select_qs[] = |
614 | 1 | "SELECT id, input_size, output_size, graph_ref_size, flags, pair_ref, inputs, outputs, outgoings, " |
615 | 1 | "name, cmd_cmd, cmd_backend, cmd_algorithm, cmd_info, hint, graph_ref, case_of_expr, case_of_flags, " |
616 | 1 | "case_of_argument_offset, case_of_argument_size, p_while_expr, p_while_input_size, p_while_inputs " |
617 | 1 | "FROM graph_exec_symbol WHERE graph=$graph ORDER BY id"; |
618 | 1 | sqlite3_stmt* exec_symbol_select_stmt = 0; |
619 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, exec_symbol_select_qs, sizeof(exec_symbol_select_qs), &exec_symbol_select_stmt, 0)); |
620 | 2 | while (1 SQLITE_ROW == sqlite3_step(graph_select_stmt)) |
621 | 1 | { |
622 | 1 | ccv_nnc_symbolic_graph_t* const graph = ccv_nnc_symbolic_graph_new(); |
623 | 1 | const int graph_idx = sqlite3_column_int(graph_select_stmt, 0); |
624 | 1 | assert(graph_idx == repo->rnum); |
625 | 1 | ccv_array_push(repo, &graph); |
626 | 1 | _ccv_nnc_symbolic_graph_read(graph_idx, graph_select_stmt, tensor_symbol_select_stmt, exec_symbol_select_stmt, graph); |
627 | 1 | } |
628 | 1 | int i; |
629 | 2 | for (i = 0; i < repo->rnum; i++1 ) |
630 | 1 | _ccv_nnc_symbolic_graph_rewire(repo, *(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, i)); |
631 | 1 | *graph_ref = (repo->rnum > 0) ? *(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, 0) : 00 ; |
632 | 1 | assert((tensor_bind_size_ref && tensor_binds_ref) || (!tensor_bind_size_ref && !tensor_binds_ref)); |
633 | 1 | if (tensor_bind_size_ref && tensor_binds_ref) |
634 | 1 | { |
635 | 1 | const char tensor_bind_count_qs[] = |
636 | 1 | "SELECT COUNT(*) FROM tensor_bind"; |
637 | 1 | sqlite3_stmt* tensor_bind_count_stmt = 0; |
638 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, tensor_bind_count_qs, sizeof(tensor_bind_count_qs), &tensor_bind_count_stmt, 0)); |
639 | 1 | sqlite3_step(tensor_bind_count_stmt); |
640 | 1 | const int tensor_bind_size = *tensor_bind_size_ref = sqlite3_column_int(tensor_bind_count_stmt, 0); |
641 | 1 | sqlite3_finalize(tensor_bind_count_stmt); |
642 | | // Respect the insert order (rowid). |
643 | 1 | if (!tensor_bind_size) |
644 | 0 | *tensor_binds_ref = 0; |
645 | 1 | else { |
646 | 1 | const char tensor_bind_select_qs[] = |
647 | 1 | "SELECT id, graph, type, format, datatype, dim, data FROM tensor_bind"; |
648 | 1 | sqlite3_stmt* tensor_bind_select_stmt = 0; |
649 | 1 | ccv_nnc_tensor_bind_t* const tensor_binds = *tensor_binds_ref = (ccv_nnc_tensor_bind_t*)ccmalloc(sizeof(ccv_nnc_tensor_bind_t) * tensor_bind_size); |
650 | 1 | SQLITE_ENFORCE(SQLITE_OK == sqlite3_prepare_v2(conn, tensor_bind_select_qs, sizeof(tensor_bind_select_qs), &tensor_bind_select_stmt, 0)); |
651 | 4 | for (i = 0; 1 SQLITE_ROW == sqlite3_step(tensor_bind_select_stmt); i++3 ) |
652 | 3 | { |
653 | 3 | assert(i < tensor_bind_size); |
654 | 3 | tensor_binds[i].symbol.d = sqlite3_column_int(tensor_bind_select_stmt, 0); |
655 | 3 | const int graph_idx = sqlite3_column_int(tensor_bind_select_stmt, 1); |
656 | 3 | assert(graph_idx < repo->rnum); |
657 | 3 | tensor_binds[i].symbol.graph = (graph_idx >= 0) ? *(ccv_nnc_symbolic_graph_t**)ccv_array_get(repo, graph_idx) : 00 ; |
658 | 3 | ccv_nnc_tensor_param_t info; |
659 | 3 | info.type = sqlite3_column_int(tensor_bind_select_stmt, 2); |
660 | 3 | info.format = sqlite3_column_int(tensor_bind_select_stmt, 3); |
661 | 3 | info.datatype = sqlite3_column_int(tensor_bind_select_stmt, 4); |
662 | 3 | const int* const dim = sqlite3_column_blob(tensor_bind_select_stmt, 5); |
663 | 3 | memset(info.dim, 0, sizeof(info.dim)); |
664 | 3 | if (dim) |
665 | 3 | memcpy(info.dim, dim, ccv_min(sizeof(info.dim), sqlite3_column_bytes(tensor_bind_select_stmt, 5))); |
666 | 3 | const void* const data = sqlite3_column_blob(tensor_bind_select_stmt, 6); |
667 | 3 | if (!data) |
668 | 3 | tensor_binds[i].tensor = 0; |
669 | 0 | else { |
670 | 0 | tensor_binds[i].tensor = ccv_nnc_tensor_new(0, info, 0); |
671 | 0 | size_t data_size = ccv_nnc_tensor_data_size(info); |
672 | 0 | #ifdef HAVE_CUDA |
673 | 0 | if (CCV_TENSOR_GET_MEMORY(info.type) == CCV_TENSOR_GPU_MEMORY) |
674 | 0 | cumemcpy(tensor_binds[i].tensor->data.u8, info.type, data, CCV_TENSOR_CPU_MEMORY, ccv_min(data_size, sqlite3_column_bytes(tensor_bind_select_stmt, 6))); |
675 | 0 | else |
676 | 0 | memcpy(tensor_binds[i].tensor->data.u8, data, ccv_min(data_size, sqlite3_column_bytes(tensor_bind_select_stmt, 6))); |
677 | | #else |
678 | | memcpy(tensor_binds[i].tensor->data.u8, data, ccv_min(data_size, sqlite3_column_bytes(tensor_bind_select_stmt, 6))); |
679 | | #endif |
680 | 0 | } |
681 | 3 | } |
682 | 1 | for (; i < tensor_bind_size; i++0 ) |
683 | 0 | { |
684 | 0 | tensor_binds[i].symbol.d = CCV_NNC_NO_TENSOR_SYMBOL; |
685 | 0 | tensor_binds[i].symbol.graph = 0; |
686 | 0 | tensor_binds[i].tensor = 0; |
687 | 0 | } |
688 | 1 | sqlite3_finalize(tensor_bind_select_stmt); |
689 | 1 | } |
690 | 1 | } |
691 | 1 | ccv_array_free(repo); |
692 | 1 | sqlite3_finalize(graph_select_stmt); |
693 | 1 | sqlite3_finalize(tensor_symbol_select_stmt); |
694 | 1 | sqlite3_finalize(exec_symbol_select_stmt); |
695 | 1 | sqlite3_close(conn); |
696 | 1 | } |