/home/liu/actions-runner/_work/ccv/ccv/test/unit/nnc/complex.tests.c
Line | Count | Source |
1 | | #include "case.h" |
2 | | #include "ccv_case.h" |
3 | | #include "ccv_nnc_case.h" |
4 | | #include <ccv.h> |
5 | | #include <nnc/ccv_nnc.h> |
6 | | #include <nnc/ccv_nnc_easy.h> |
7 | | #include "3rdparty/dsfmt/dSFMT.h" |
8 | | |
9 | | TEST_SETUP() |
10 | | { |
11 | | ccv_nnc_init(); |
12 | | } |
13 | | |
14 | | TEST_CASE("compare cmul with gemm computed result") |
15 | 1 | { |
16 | 1 | ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new(); |
17 | 1 | const ccv_nnc_tensor_symbol_t x = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 10), "x"); |
18 | 1 | const ccv_nnc_tensor_symbol_t y = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 10), "y"); |
19 | 1 | const ccv_nnc_tensor_symbol_t z = ccv_nnc_tensor_symbol_new(symbolic_graph, ccv_nnc_tensor_auto, "z"); |
20 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_CMUL_FORWARD(), TENSOR_SYMBOL_LIST(x, y), TENSOR_SYMBOL_LIST(z), "cmul"); |
21 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
22 | 1 | SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
23 | 1 | ccv_nnc_graph_t* graph = 0; |
24 | 1 | ccv_nnc_tensor_arena_t* tensor_arena = 0; |
25 | 1 | ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0; |
26 | 1 | ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, |
27 | 1 | 0, 0, TENSOR_SYMBOL_LIST(z), |
28 | 1 | SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), |
29 | 1 | &graph, &tensor_arena, &graph_exec_arena); |
30 | 1 | GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH); |
31 | 1 | ccv_nnc_tensor_t* const x_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, x); |
32 | 1 | dsfmt_t dsfmt; |
33 | 1 | int i; |
34 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
35 | 11 | for (i = 0; i < 10; i++10 ) |
36 | 10 | x_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
37 | 1 | ccv_nnc_tensor_t* gemm_x = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 5, 1, 2), 0); |
38 | 1 | memcpy(gemm_x->data.f32, x_tensor->data.f32, sizeof(float) * 10); |
39 | 1 | ccv_nnc_tensor_t* const y_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, y); |
40 | 11 | for (i = 0; i < 10; i++10 ) |
41 | 10 | y_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
42 | 1 | ccv_nnc_tensor_t* gemm_y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 5, 2, 2), 0); |
43 | 6 | for (i = 0; i < 5; i++5 ) |
44 | 5 | { |
45 | 5 | gemm_y->data.f32[i * 4] = y_tensor->data.f32[i * 2]; |
46 | 5 | gemm_y->data.f32[i * 4 + 1] = -y_tensor->data.f32[i * 2 + 1]; |
47 | 5 | gemm_y->data.f32[i * 4 + 2] = y_tensor->data.f32[i * 2 + 1]; |
48 | 5 | gemm_y->data.f32[i * 4 + 3] = y_tensor->data.f32[i * 2]; |
49 | 5 | } |
50 | 1 | ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0); |
51 | 1 | ccv_nnc_tensor_t* const z_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, z); |
52 | 1 | ccv_nnc_tensor_t* gemm_z = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 5, 1, 2), 0); |
53 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(gemm_x, gemm_y), TENSOR_LIST(gemm_z), 0); |
54 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, z_tensor->data.f32, gemm_z->data.f32, 10, 1e-5, "should match as if GEMM"); |
55 | 1 | ccv_nnc_symbolic_graph_free(symbolic_graph); |
56 | 1 | ccv_nnc_tensor_arena_free(tensor_arena); |
57 | 1 | ccv_nnc_graph_exec_arena_free(graph_exec_arena); |
58 | 1 | ccv_nnc_graph_free(graph); |
59 | 1 | ccv_nnc_tensor_free(gemm_x); |
60 | 1 | ccv_nnc_tensor_free(gemm_y); |
61 | 1 | ccv_nnc_tensor_free(gemm_z); |
62 | 1 | } |
63 | | |
64 | | TEST_CASE("compare cmul gradient with gemm computed result") |
65 | 1 | { |
66 | 1 | ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new(); |
67 | 1 | const ccv_nnc_tensor_symbol_t x = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 10), "x"); |
68 | 1 | const ccv_nnc_tensor_symbol_t y = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 10), "y"); |
69 | 1 | const ccv_nnc_tensor_symbol_t z = ccv_nnc_tensor_symbol_new(symbolic_graph, ccv_nnc_tensor_auto, "z"); |
70 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_CMUL_FORWARD(), TENSOR_SYMBOL_LIST(x, y), TENSOR_SYMBOL_LIST(z), "cmul"); |
71 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
72 | 1 | ccv_nnc_symbolic_graph_backward(symbolic_graph, TENSOR_SYMBOL_LIST(z), TENSOR_SYMBOL_LIST(x), SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph)); |
73 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
74 | 1 | const ccv_nnc_tensor_symbol_t dx = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, x); |
75 | 1 | SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
76 | 1 | ccv_nnc_graph_t* graph = 0; |
77 | 1 | ccv_nnc_tensor_arena_t* tensor_arena = 0; |
78 | 1 | ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0; |
79 | 1 | ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, |
80 | 1 | 0, 0, TENSOR_SYMBOL_LIST(z, dx), |
81 | 1 | SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), |
82 | 1 | &graph, &tensor_arena, &graph_exec_arena); |
83 | 1 | GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH); |
84 | 1 | ccv_nnc_tensor_t* const x_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, x); |
85 | 1 | dsfmt_t dsfmt; |
86 | 1 | int i; |
87 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
88 | 11 | for (i = 0; i < 10; i++10 ) |
89 | 10 | x_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
90 | 1 | const ccv_nnc_tensor_symbol_t dz = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, z); |
91 | 1 | ccv_nnc_tensor_t* const dz_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dz); |
92 | 11 | for (i = 0; i < 10; i++10 ) |
93 | 10 | dz_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
94 | 1 | ccv_nnc_symbolic_graph_t* const gemm_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
95 | 1 | const ccv_nnc_tensor_symbol_t gemm_x = ccv_nnc_tensor_symbol_new(gemm_symbolic_graph, CPU_TENSOR_NHWC(32F, 5, 1, 2), "x"); |
96 | 1 | const ccv_nnc_tensor_symbol_t gemm_y = ccv_nnc_tensor_symbol_new(gemm_symbolic_graph, CPU_TENSOR_NHWC(32F, 5, 2, 2), "y"); |
97 | 1 | const ccv_nnc_tensor_symbol_t gemm_z = ccv_nnc_tensor_symbol_new(gemm_symbolic_graph, ccv_nnc_tensor_auto, "z"); |
98 | 1 | ccv_nnc_graph_exec_symbol_new(gemm_symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), TENSOR_SYMBOL_LIST(gemm_x, gemm_y), TENSOR_SYMBOL_LIST(gemm_z), "gemm"); |
99 | 1 | ccv_nnc_graph_exec_symbol_autogen(gemm_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
100 | 1 | ccv_nnc_symbolic_graph_backward(gemm_symbolic_graph, TENSOR_SYMBOL_LIST(gemm_z), TENSOR_SYMBOL_LIST(gemm_x), SYMBOLIC_GRAPH_SOURCES(gemm_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(gemm_symbolic_graph)); |
101 | 1 | ccv_nnc_graph_exec_symbol_autogen(gemm_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
102 | 1 | const ccv_nnc_tensor_symbol_t dgemmx = ccv_nnc_tensor_symbol_for_backward(gemm_symbolic_graph, gemm_x); |
103 | 1 | ccv_nnc_graph_t* gemm_graph = 0; |
104 | 1 | ccv_nnc_tensor_arena_t* gemm_tensor_arena = 0; |
105 | 1 | ccv_nnc_graph_exec_arena_t* gemm_graph_exec_arena = 0; |
106 | 1 | ccv_nnc_symbolic_graph_compile(gemm_symbolic_graph, ccv_nnc_default_compile_params, |
107 | 1 | 0, 0, TENSOR_SYMBOL_LIST(gemm_z, dgemmx), |
108 | 1 | SYMBOLIC_GRAPH_SOURCES(gemm_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(gemm_symbolic_graph), |
109 | 1 | &gemm_graph, &gemm_tensor_arena, &gemm_graph_exec_arena); |
110 | 1 | ccv_nnc_tensor_t* gemm_x_tensor = ccv_nnc_tensor_from_symbol(gemm_tensor_arena, gemm_x); |
111 | 1 | memcpy(gemm_x_tensor->data.f32, x_tensor->data.f32, sizeof(float) * 10); |
112 | 1 | const ccv_nnc_tensor_symbol_t dgemmz = ccv_nnc_tensor_symbol_for_backward(gemm_symbolic_graph, gemm_z); |
113 | 1 | ccv_nnc_tensor_t* const dgemmz_tensor = ccv_nnc_tensor_from_symbol(gemm_tensor_arena, dgemmz); |
114 | 1 | memcpy(dgemmz_tensor->data.f32, dz_tensor->data.f32, sizeof(float) * 10); |
115 | 1 | ccv_nnc_tensor_t* const y_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, y); |
116 | 11 | for (i = 0; i < 10; i++10 ) |
117 | 10 | y_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
118 | 1 | ccv_nnc_tensor_t* gemm_y_tensor = ccv_nnc_tensor_from_symbol(gemm_tensor_arena, gemm_y); |
119 | 6 | for (i = 0; i < 5; i++5 ) |
120 | 5 | { |
121 | 5 | gemm_y_tensor->data.f32[i * 4] = y_tensor->data.f32[i * 2]; |
122 | 5 | gemm_y_tensor->data.f32[i * 4 + 1] = -y_tensor->data.f32[i * 2 + 1]; |
123 | 5 | gemm_y_tensor->data.f32[i * 4 + 2] = y_tensor->data.f32[i * 2 + 1]; |
124 | 5 | gemm_y_tensor->data.f32[i * 4 + 3] = y_tensor->data.f32[i * 2]; |
125 | 5 | } |
126 | 1 | ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0); |
127 | 1 | ccv_nnc_graph_run(gemm_graph, 0, TRAVERSE_FULL, 0, 0); |
128 | 1 | ccv_nnc_tensor_t* const z_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, z); |
129 | 1 | ccv_nnc_tensor_t* gemm_z_tensor = ccv_nnc_tensor_from_symbol(gemm_tensor_arena, gemm_z); |
130 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, z_tensor->data.f32, gemm_z_tensor->data.f32, 10, 1e-5, "should match as if GEMM"); |
131 | 1 | ccv_nnc_tensor_t* const dx_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dx); |
132 | 1 | ccv_nnc_tensor_t* dgemmx_tensor = ccv_nnc_tensor_from_symbol(gemm_tensor_arena, dgemmx); |
133 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dx_tensor->data.f32, dgemmx_tensor->data.f32, 10, 1e-5, "should match as if GEMM"); |
134 | 1 | ccv_nnc_symbolic_graph_free(symbolic_graph); |
135 | 1 | ccv_nnc_tensor_arena_free(tensor_arena); |
136 | 1 | ccv_nnc_graph_exec_arena_free(graph_exec_arena); |
137 | 1 | ccv_nnc_graph_free(graph); |
138 | 1 | ccv_nnc_symbolic_graph_free(gemm_symbolic_graph); |
139 | 1 | ccv_nnc_tensor_arena_free(gemm_tensor_arena); |
140 | 1 | ccv_nnc_graph_exec_arena_free(gemm_graph_exec_arena); |
141 | 1 | ccv_nnc_graph_free(gemm_graph); |
142 | 1 | } |
143 | | |
144 | | #include "case_main.h" |