/home/liu/actions-runner/_work/ccv/ccv/test/unit/nnc/gemm.tests.c
Line | Count | Source |
1 | | #include "case.h" |
2 | | #include "ccv_case.h" |
3 | | #include "ccv_nnc_case.h" |
4 | | #include <ccv.h> |
5 | | #include <nnc/ccv_nnc.h> |
6 | | #include <nnc/ccv_nnc_easy.h> |
7 | | |
8 | | TEST_SETUP() |
9 | | { |
10 | | ccv_nnc_init(); |
11 | | } |
12 | | |
13 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 8, 9], [10, 11, 12]]") |
14 | 1 | { |
15 | 1 | float ap[] = { |
16 | 1 | 1, 2, |
17 | 1 | 3, 4, |
18 | 1 | 5, 6, |
19 | 1 | 7, 8, |
20 | 1 | }; |
21 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
22 | 1 | float bp[] = { |
23 | 1 | 7, 8, 9, |
24 | 1 | 10, 11, 12, |
25 | 1 | }; |
26 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
27 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
28 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
29 | 1 | float ctp[] = { |
30 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
31 | 1 | 3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12, |
32 | 1 | 5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12, |
33 | 1 | 7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12, |
34 | 1 | }; |
35 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
36 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
37 | 1 | ccv_nnc_tensor_free(a); |
38 | 1 | ccv_nnc_tensor_free(b); |
39 | 1 | ccv_nnc_tensor_free(c); |
40 | 1 | } |
41 | | |
42 | | TEST_CASE("[1, 2] * [[7, 8, 9], [10, 11, 12]]") |
43 | 1 | { |
44 | 1 | float ap[] = { |
45 | 1 | 1, 2, |
46 | 1 | }; |
47 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2), 0); |
48 | 1 | float bp[] = { |
49 | 1 | 7, 8, 9, |
50 | 1 | 10, 11, 12, |
51 | 1 | }; |
52 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
53 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
54 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
55 | 1 | float ctp[] = { |
56 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
57 | 1 | }; |
58 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 3), 0); |
59 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
60 | 1 | ccv_nnc_tensor_free(a); |
61 | 1 | ccv_nnc_tensor_free(b); |
62 | 1 | ccv_nnc_tensor_free(c); |
63 | 1 | } |
64 | | |
65 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 10], [8, 11], [9, 12]]^T") |
66 | 1 | { |
67 | 1 | float ap[] = { |
68 | 1 | 1, 2, |
69 | 1 | 3, 4, |
70 | 1 | 5, 6, |
71 | 1 | 7, 8, |
72 | 1 | }; |
73 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
74 | 1 | float bp[] = { |
75 | 1 | 7, 10, |
76 | 1 | 8, 11, |
77 | 1 | 9, 12, |
78 | 1 | }; |
79 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
80 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
81 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
82 | 1 | float ctp[] = { |
83 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
84 | 1 | 3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12, |
85 | 1 | 5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12, |
86 | 1 | 7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12, |
87 | 1 | }; |
88 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
89 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
90 | 1 | ccv_nnc_tensor_free(a); |
91 | 1 | ccv_nnc_tensor_free(b); |
92 | 1 | ccv_nnc_tensor_free(c); |
93 | 1 | } |
94 | | |
95 | | TEST_CASE("[[1, 3, 5, 7], [2, 4, 6, 8]]^T * [[7, 10], [8, 11], [9, 12]]^T") |
96 | 1 | { |
97 | 1 | float ap[] = { |
98 | 1 | 1, 3, 5, 7, |
99 | 1 | 2, 4, 6, 8, |
100 | 1 | }; |
101 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 1, 2, 4), 0); |
102 | 1 | float bp[] = { |
103 | 1 | 7, 10, |
104 | 1 | 8, 11, |
105 | 1 | 9, 12, |
106 | 1 | }; |
107 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
108 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 4, 3), 0); |
109 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(TRANSPOSE(1, 2), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
110 | 1 | float ctp[] = { |
111 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
112 | 1 | 3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12, |
113 | 1 | 5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12, |
114 | 1 | 7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12, |
115 | 1 | }; |
116 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 1, 4, 3), 0); |
117 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
118 | 1 | ccv_nnc_tensor_free(a); |
119 | 1 | ccv_nnc_tensor_free(b); |
120 | 1 | ccv_nnc_tensor_free(c); |
121 | 1 | } |
122 | | |
123 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 8, 9], [10, 11, 12]] + [-1, 0, 1]") |
124 | 1 | { |
125 | 1 | float ap[] = { |
126 | 1 | 1, 2, |
127 | 1 | 3, 4, |
128 | 1 | 5, 6, |
129 | 1 | 7, 8, |
130 | 1 | }; |
131 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
132 | 1 | float bp[] = { |
133 | 1 | 7, 8, 9, |
134 | 1 | 10, 11, 12, |
135 | 1 | }; |
136 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
137 | 1 | float biasp[] = { |
138 | 1 | -1, 0, 1, |
139 | 1 | }; |
140 | 1 | ccv_nnc_tensor_t* const bias = ccv_nnc_tensor_new(biasp, CPU_TENSOR_NHWC(32F, 3), 0); |
141 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
142 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, bias), TENSOR_LIST(c), 0); |
143 | 1 | float ctp[] = { |
144 | 1 | 1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1, |
145 | 1 | 3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1, |
146 | 1 | 5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1, |
147 | 1 | 7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1, |
148 | 1 | }; |
149 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
150 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
151 | 1 | ccv_nnc_tensor_free(a); |
152 | 1 | ccv_nnc_tensor_free(b); |
153 | 1 | ccv_nnc_tensor_free(bias); |
154 | 1 | ccv_nnc_tensor_free(c); |
155 | 1 | } |
156 | | |
157 | | TEST_CASE("backward gemm with no transpose") |
158 | 1 | { |
159 | 1 | float gp[] = { |
160 | 1 | 1, 2, 3, |
161 | 1 | 4, 5, 6, |
162 | 1 | 7, 8, 9, |
163 | 1 | 10, 11, 12, |
164 | 1 | }; |
165 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
166 | 1 | float ap[] = { |
167 | 1 | 13, 14, |
168 | 1 | 15, 16, |
169 | 1 | 17, 18, |
170 | 1 | 19, 20, |
171 | 1 | }; |
172 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
173 | 1 | float bp[] = { |
174 | 1 | 21, 22, 23, |
175 | 1 | 24, 25, 26, |
176 | 1 | }; |
177 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
178 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
179 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
180 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
181 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
182 | 1 | float dbiastp[] = { |
183 | 1 | 22, 26, 30, |
184 | 1 | }; |
185 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
186 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
187 | 1 | float htp[] = { |
188 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26, |
189 | 1 | 4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26, |
190 | 1 | 7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26, |
191 | 1 | 10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26, |
192 | 1 | }; |
193 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
194 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
195 | 1 | float dbtp[] = { |
196 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, |
197 | 1 | 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
198 | 1 | }; |
199 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
200 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
201 | 1 | ccv_nnc_tensor_free(g); |
202 | 1 | ccv_nnc_tensor_free(a); |
203 | 1 | ccv_nnc_tensor_free(b); |
204 | 1 | ccv_nnc_tensor_free(h); |
205 | 1 | ccv_nnc_tensor_free(db); |
206 | 1 | ccv_nnc_tensor_free(dbias); |
207 | 1 | } |
208 | | |
209 | | TEST_CASE("backward gemm with transpose a") |
210 | 1 | { |
211 | 1 | float gp[] = { |
212 | 1 | 1, 2, 3, |
213 | 1 | 4, 5, 6, |
214 | 1 | 7, 8, 9, |
215 | 1 | 10, 11, 12, |
216 | 1 | }; |
217 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
218 | 1 | float ap[] = { |
219 | 1 | 13, 15, 17, 19, |
220 | 1 | 14, 16, 18, 20, |
221 | 1 | }; |
222 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
223 | 1 | float bp[] = { |
224 | 1 | 21, 22, 23, |
225 | 1 | 24, 25, 26, |
226 | 1 | }; |
227 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
228 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
229 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
230 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
231 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
232 | 1 | float dbiastp[] = { |
233 | 1 | 22, 26, 30, |
234 | 1 | }; |
235 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
236 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
237 | 1 | float htp[] = { |
238 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23, |
239 | 1 | 1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26, |
240 | 1 | }; |
241 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
242 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
243 | 1 | float dbtp[] = { |
244 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, |
245 | 1 | 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
246 | 1 | }; |
247 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
248 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
249 | 1 | ccv_nnc_tensor_free(g); |
250 | 1 | ccv_nnc_tensor_free(a); |
251 | 1 | ccv_nnc_tensor_free(b); |
252 | 1 | ccv_nnc_tensor_free(h); |
253 | 1 | ccv_nnc_tensor_free(db); |
254 | 1 | ccv_nnc_tensor_free(dbias); |
255 | 1 | } |
256 | | |
257 | | TEST_CASE("backward gemm with transpose b") |
258 | 1 | { |
259 | 1 | float gp[] = { |
260 | 1 | 1, 2, 3, |
261 | 1 | 4, 5, 6, |
262 | 1 | 7, 8, 9, |
263 | 1 | 10, 11, 12, |
264 | 1 | }; |
265 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
266 | 1 | float ap[] = { |
267 | 1 | 13, 14, |
268 | 1 | 15, 16, |
269 | 1 | 17, 18, |
270 | 1 | 19, 20, |
271 | 1 | }; |
272 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
273 | 1 | float bp[] = { |
274 | 1 | 21, 24, |
275 | 1 | 22, 25, |
276 | 1 | 23, 26, |
277 | 1 | }; |
278 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
279 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
280 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
281 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
282 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
283 | 1 | float dbiastp[] = { |
284 | 1 | 22, 26, 30, |
285 | 1 | }; |
286 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
287 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
288 | 1 | float htp[] = { |
289 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26, |
290 | 1 | 4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26, |
291 | 1 | 7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26, |
292 | 1 | 10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26, |
293 | 1 | }; |
294 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 4, 2), 0); |
295 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
296 | 1 | float dbtp[] = { |
297 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, |
298 | 1 | 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, |
299 | 1 | 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
300 | 1 | }; |
301 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
302 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
303 | 1 | ccv_nnc_tensor_free(g); |
304 | 1 | ccv_nnc_tensor_free(a); |
305 | 1 | ccv_nnc_tensor_free(b); |
306 | 1 | ccv_nnc_tensor_free(h); |
307 | 1 | ccv_nnc_tensor_free(db); |
308 | 1 | ccv_nnc_tensor_free(dbias); |
309 | 1 | } |
310 | | |
311 | | TEST_CASE("backward gemm with transpose a and b") |
312 | 1 | { |
313 | 1 | float gp[] = { |
314 | 1 | 1, 2, 3, |
315 | 1 | 4, 5, 6, |
316 | 1 | 7, 8, 9, |
317 | 1 | 10, 11, 12, |
318 | 1 | }; |
319 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 4, 3), 0); |
320 | 1 | float ap[] = { |
321 | 1 | 13, 15, 17, 19, |
322 | 1 | 14, 16, 18, 20, |
323 | 1 | }; |
324 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
325 | 1 | float bp[] = { |
326 | 1 | 21, 24, |
327 | 1 | 22, 25, |
328 | 1 | 23, 26, |
329 | 1 | }; |
330 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
331 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
332 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
333 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
334 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(0, 1), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
335 | 1 | float dbiastp[] = { |
336 | 1 | 22, 26, 30, |
337 | 1 | }; |
338 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
339 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
340 | 1 | float htp[] = { |
341 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23, |
342 | 1 | 1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26, |
343 | 1 | }; |
344 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4), 0); |
345 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
346 | 1 | float dbtp[] = { |
347 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, |
348 | 1 | 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, |
349 | 1 | 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
350 | 1 | }; |
351 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
352 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
353 | 1 | ccv_nnc_tensor_free(g); |
354 | 1 | ccv_nnc_tensor_free(a); |
355 | 1 | ccv_nnc_tensor_free(b); |
356 | 1 | ccv_nnc_tensor_free(h); |
357 | 1 | ccv_nnc_tensor_free(db); |
358 | 1 | ccv_nnc_tensor_free(dbias); |
359 | 1 | } |
360 | | |
361 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 8, 9], [10, 11, 12]]") |
362 | 1 | { |
363 | 1 | float ap[] = { |
364 | 1 | 1, 2, |
365 | 1 | 3, 4, |
366 | 1 | 5, 6, |
367 | 1 | 7, 8, |
368 | 1 | 2, 3, |
369 | 1 | 4, 5, |
370 | 1 | 6, 7, |
371 | 1 | 8, 9 |
372 | 1 | }; |
373 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
374 | 1 | float bp[] = { |
375 | 1 | 7, 8, 9, |
376 | 1 | 10, 11, 12, |
377 | 1 | }; |
378 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
379 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
380 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
381 | 1 | float ctp[] = { |
382 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
383 | 1 | 3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12, |
384 | 1 | 5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12, |
385 | 1 | 7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12, |
386 | 1 | 2 * 7 + 3 * 10, 2 * 8 + 3 * 11, 2 * 9 + 3 * 12, |
387 | 1 | 4 * 7 + 5 * 10, 4 * 8 + 5 * 11, 4 * 9 + 5 * 12, |
388 | 1 | 6 * 7 + 7 * 10, 6 * 8 + 7 * 11, 6 * 9 + 7 * 12, |
389 | 1 | 8 * 7 + 9 * 10, 8 * 8 + 9 * 11, 8 * 9 + 9 * 12, |
390 | 1 | }; |
391 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
392 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
393 | 1 | ccv_nnc_tensor_free(a); |
394 | 1 | ccv_nnc_tensor_free(b); |
395 | 1 | ccv_nnc_tensor_free(c); |
396 | 1 | } |
397 | | |
398 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 8, 9], [10, 11, 12]], [[80, 90, 10], [110, 120, 13]]") |
399 | 1 | { |
400 | 1 | float ap[] = { |
401 | 1 | 1, 2, |
402 | 1 | 3, 4, |
403 | 1 | 5, 6, |
404 | 1 | 7, 8, |
405 | 1 | 2, 3, |
406 | 1 | 4, 5, |
407 | 1 | 6, 7, |
408 | 1 | 8, 9 |
409 | 1 | }; |
410 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
411 | 1 | float bp[] = { |
412 | 1 | 7, 8, 9, |
413 | 1 | 10, 11, 12, |
414 | 1 | 80, 90, 10, |
415 | 1 | 110, 120, 13, |
416 | 1 | }; |
417 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0); |
418 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
419 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0); |
420 | 1 | float ctp[] = { |
421 | 1 | 1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12, |
422 | 1 | 3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12, |
423 | 1 | 5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12, |
424 | 1 | 7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12, |
425 | 1 | 2 * 80 + 3 * 110, 2 * 90 + 3 * 120, 2 * 10 + 3 * 13, |
426 | 1 | 4 * 80 + 5 * 110, 4 * 90 + 5 * 120, 4 * 10 + 5 * 13, |
427 | 1 | 6 * 80 + 7 * 110, 6 * 90 + 7 * 120, 6 * 10 + 7 * 13, |
428 | 1 | 8 * 80 + 9 * 110, 8 * 90 + 9 * 120, 8 * 10 + 9 * 13, |
429 | 1 | }; |
430 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
431 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
432 | 1 | ccv_nnc_tensor_free(a); |
433 | 1 | ccv_nnc_tensor_free(b); |
434 | 1 | ccv_nnc_tensor_free(c); |
435 | 1 | } |
436 | | |
437 | | TEST_CASE("[[1, 3, 5, 7], [2, 4, 6, 8]], [[2, 4, 6, 8], [3, 5, 7, 9]]^T * [[7, 8, 9], [10, 11, 12]] + [-1, 0, 1]") |
438 | 1 | { |
439 | 1 | float ap[] = { |
440 | 1 | 1, 3, 5, 7, |
441 | 1 | 2, 4, 6, 8, |
442 | 1 | 2, 4, 6, 8, |
443 | 1 | 3, 5, 7, 9, |
444 | 1 | }; |
445 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
446 | 1 | float bp[] = { |
447 | 1 | 7, 8, 9, |
448 | 1 | 10, 11, 12, |
449 | 1 | }; |
450 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
451 | 1 | float dp[] = { |
452 | 1 | -1, 0, 1, |
453 | 1 | }; |
454 | 1 | ccv_nnc_tensor_t* const d = ccv_nnc_tensor_new(dp, CPU_TENSOR_NHWC(32F, 3), 0); |
455 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
456 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, d), TENSOR_LIST(c), 0); |
457 | 1 | float ctp[] = { |
458 | 1 | 1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1, |
459 | 1 | 3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1, |
460 | 1 | 5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1, |
461 | 1 | 7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1, |
462 | 1 | 2 * 7 + 3 * 10 - 1, 2 * 8 + 3 * 11, 2 * 9 + 3 * 12 + 1, |
463 | 1 | 4 * 7 + 5 * 10 - 1, 4 * 8 + 5 * 11, 4 * 9 + 5 * 12 + 1, |
464 | 1 | 6 * 7 + 7 * 10 - 1, 6 * 8 + 7 * 11, 6 * 9 + 7 * 12 + 1, |
465 | 1 | 8 * 7 + 9 * 10 - 1, 8 * 8 + 9 * 11, 8 * 9 + 9 * 12 + 1, |
466 | 1 | }; |
467 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
468 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
469 | 1 | ccv_nnc_tensor_free(a); |
470 | 1 | ccv_nnc_tensor_free(b); |
471 | 1 | ccv_nnc_tensor_free(c); |
472 | 1 | ccv_nnc_tensor_free(d); |
473 | 1 | } |
474 | | |
475 | | TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 10], [8, 11], [9, 12]], [[80, 110], [90, 120], [10, 13]]^T + [-1, 0, 1], [2, 3, -4]") |
476 | 1 | { |
477 | 1 | float ap[] = { |
478 | 1 | 1, 2, |
479 | 1 | 3, 4, |
480 | 1 | 5, 6, |
481 | 1 | 7, 8, |
482 | 1 | 2, 3, |
483 | 1 | 4, 5, |
484 | 1 | 6, 7, |
485 | 1 | 8, 9 |
486 | 1 | }; |
487 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
488 | 1 | float bp[] = { |
489 | 1 | 7, 10, |
490 | 1 | 8, 11, |
491 | 1 | 9, 12, |
492 | 1 | 80, 110, |
493 | 1 | 90, 120, |
494 | 1 | 10, 13, |
495 | 1 | }; |
496 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0); |
497 | 1 | float dp[] = { |
498 | 1 | -1, 0, 1, |
499 | 1 | 2, 3, -4, |
500 | 1 | }; |
501 | 1 | ccv_nnc_tensor_t* const d = ccv_nnc_tensor_new(dp, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); |
502 | 1 | ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
503 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, d), TENSOR_LIST(c), 0); |
504 | 1 | float ctp[] = { |
505 | 1 | 1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1, |
506 | 1 | 3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1, |
507 | 1 | 5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1, |
508 | 1 | 7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1, |
509 | 1 | 2 * 80 + 3 * 110 + 2, 2 * 90 + 3 * 120 + 3, 2 * 10 + 3 * 13 - 4, |
510 | 1 | 4 * 80 + 5 * 110 + 2, 4 * 90 + 5 * 120 + 3, 4 * 10 + 5 * 13 - 4, |
511 | 1 | 6 * 80 + 7 * 110 + 2, 6 * 90 + 7 * 120 + 3, 6 * 10 + 7 * 13 - 4, |
512 | 1 | 8 * 80 + 9 * 110 + 2, 8 * 90 + 9 * 120 + 3, 8 * 10 + 9 * 13 - 4, |
513 | 1 | }; |
514 | 1 | ccv_nnc_tensor_t ct = ccv_nnc_tensor(ctp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
515 | 1 | REQUIRE_TENSOR_EQ(c, &ct, "result should be equal"); |
516 | 1 | ccv_nnc_tensor_free(a); |
517 | 1 | ccv_nnc_tensor_free(b); |
518 | 1 | ccv_nnc_tensor_free(c); |
519 | 1 | ccv_nnc_tensor_free(d); |
520 | 1 | } |
521 | | |
522 | | TEST_CASE("backward gemm with no transpose batch 2, same b") |
523 | 1 | { |
524 | 1 | float gp[] = { |
525 | 1 | 1, 2, 3, |
526 | 1 | 4, 5, 6, |
527 | 1 | 7, 8, 9, |
528 | 1 | 10, 11, 12, |
529 | 1 | 10, 20, 30, |
530 | 1 | 40, 50, 60, |
531 | 1 | 70, 80, 90, |
532 | 1 | 100, 110, 120, |
533 | 1 | }; |
534 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
535 | 1 | float ap[] = { |
536 | 1 | 13, 14, |
537 | 1 | 15, 16, |
538 | 1 | 17, 18, |
539 | 1 | 19, 20, |
540 | 1 | 131, 141, |
541 | 1 | 151, 161, |
542 | 1 | 171, 181, |
543 | 1 | 191, 201, |
544 | 1 | }; |
545 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
546 | 1 | float bp[] = { |
547 | 1 | 21, 22, 23, |
548 | 1 | 24, 25, 26, |
549 | 1 | }; |
550 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
551 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
552 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
553 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
554 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
555 | 1 | float dbiastp[] = { |
556 | 1 | 22 + 220, 26 + 260, 30 + 300, |
557 | 1 | }; |
558 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
559 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
560 | 1 | float htp[] = { |
561 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26, |
562 | 1 | 4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26, |
563 | 1 | 7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26, |
564 | 1 | 10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26, |
565 | 1 | 10 * 21 + 20 * 22 + 30 * 23, 10 * 24 + 20 * 25 + 30 * 26, |
566 | 1 | 40 * 21 + 50 * 22 + 60 * 23, 40 * 24 + 50 * 25 + 60 * 26, |
567 | 1 | 70 * 21 + 80 * 22 + 90 * 23, 70 * 24 + 80 * 25 + 90 * 26, |
568 | 1 | 100 * 21 + 110 * 22 + 120 * 23, 100 * 24 + 110 * 25 + 120 * 26, |
569 | 1 | }; |
570 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
571 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
572 | 1 | float dbtp[] = { |
573 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, |
574 | 1 | 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201, |
575 | 1 | }; |
576 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
577 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
578 | 1 | ccv_nnc_tensor_free(g); |
579 | 1 | ccv_nnc_tensor_free(a); |
580 | 1 | ccv_nnc_tensor_free(b); |
581 | 1 | ccv_nnc_tensor_free(h); |
582 | 1 | ccv_nnc_tensor_free(db); |
583 | 1 | ccv_nnc_tensor_free(dbias); |
584 | 1 | } |
585 | | |
586 | | TEST_CASE("backward gemm with no transpose batch 2, batched b") |
587 | 1 | { |
588 | 1 | float gp[] = { |
589 | 1 | 1, 2, 3, |
590 | 1 | 4, 5, 6, |
591 | 1 | 7, 8, 9, |
592 | 1 | 10, 11, 12, |
593 | 1 | 10, 20, 30, |
594 | 1 | 40, 50, 60, |
595 | 1 | 70, 80, 90, |
596 | 1 | 100, 110, 120, |
597 | 1 | }; |
598 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
599 | 1 | float ap[] = { |
600 | 1 | 13, 14, |
601 | 1 | 15, 16, |
602 | 1 | 17, 18, |
603 | 1 | 19, 20, |
604 | 1 | 131, 141, |
605 | 1 | 151, 161, |
606 | 1 | 171, 181, |
607 | 1 | 191, 201, |
608 | 1 | }; |
609 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
610 | 1 | float bp[] = { |
611 | 1 | 21, 22, 23, |
612 | 1 | 24, 25, 26, |
613 | 1 | 212, 222, 232, |
614 | 1 | 242, 252, 262, |
615 | 1 | }; |
616 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0); |
617 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
618 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0); |
619 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); |
620 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
621 | 1 | float dbiastp[] = { |
622 | 1 | 22, 26, 30, |
623 | 1 | 220, 260, 300, |
624 | 1 | }; |
625 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); |
626 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
627 | 1 | float htp[] = { |
628 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26, |
629 | 1 | 4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26, |
630 | 1 | 7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26, |
631 | 1 | 10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26, |
632 | 1 | 10 * 212 + 20 * 222 + 30 * 232, 10 * 242 + 20 * 252 + 30 * 262, |
633 | 1 | 40 * 212 + 50 * 222 + 60 * 232, 40 * 242 + 50 * 252 + 60 * 262, |
634 | 1 | 70 * 212 + 80 * 222 + 90 * 232, 70 * 242 + 80 * 252 + 90 * 262, |
635 | 1 | 100 * 212 + 110 * 222 + 120 * 232, 100 * 242 + 110 * 252 + 120 * 262, |
636 | 1 | }; |
637 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
638 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
639 | 1 | float dbtp[] = { |
640 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, |
641 | 1 | 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
642 | 1 | 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, |
643 | 1 | 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201, |
644 | 1 | }; |
645 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0); |
646 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
647 | 1 | ccv_nnc_tensor_free(g); |
648 | 1 | ccv_nnc_tensor_free(a); |
649 | 1 | ccv_nnc_tensor_free(b); |
650 | 1 | ccv_nnc_tensor_free(h); |
651 | 1 | ccv_nnc_tensor_free(db); |
652 | 1 | ccv_nnc_tensor_free(dbias); |
653 | 1 | } |
654 | | |
655 | | TEST_CASE("backward gemm with transpose a batch 2, same b") |
656 | 1 | { |
657 | 1 | float gp[] = { |
658 | 1 | 1, 2, 3, |
659 | 1 | 4, 5, 6, |
660 | 1 | 7, 8, 9, |
661 | 1 | 10, 11, 12, |
662 | 1 | 10, 20, 30, |
663 | 1 | 40, 50, 60, |
664 | 1 | 70, 80, 90, |
665 | 1 | 100, 110, 120, |
666 | 1 | }; |
667 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
668 | 1 | float ap[] = { |
669 | 1 | 13, 15, 17, 19, |
670 | 1 | 14, 16, 18, 20, |
671 | 1 | 131, 151, 171, 191, |
672 | 1 | 141, 161, 181, 201, |
673 | 1 | }; |
674 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
675 | 1 | float bp[] = { |
676 | 1 | 21, 22, 23, |
677 | 1 | 24, 25, 26, |
678 | 1 | }; |
679 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
680 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
681 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
682 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
683 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
684 | 1 | float dbiastp[] = { |
685 | 1 | 22 + 220, 26 + 260, 30 + 300, |
686 | 1 | }; |
687 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
688 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
689 | 1 | float htp[] = { |
690 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23, |
691 | 1 | 1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26, |
692 | 1 | 10 * 21 + 20 * 22 + 30 * 23, 40 * 21 + 50 * 22 + 60 * 23, 70 * 21 + 80 * 22 + 90 * 23, 100 * 21 + 110 * 22 + 120 * 23, |
693 | 1 | 10 * 24 + 20 * 25 + 30 * 26, 40 * 24 + 50 * 25 + 60 * 26, 70 * 24 + 80 * 25 + 90 * 26, 100 * 24 + 110 * 25 + 120 * 26, |
694 | 1 | }; |
695 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
696 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
697 | 1 | float dbtp[] = { |
698 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, |
699 | 1 | 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201, |
700 | 1 | }; |
701 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3), 0); |
702 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
703 | 1 | ccv_nnc_tensor_free(g); |
704 | 1 | ccv_nnc_tensor_free(a); |
705 | 1 | ccv_nnc_tensor_free(b); |
706 | 1 | ccv_nnc_tensor_free(h); |
707 | 1 | ccv_nnc_tensor_free(db); |
708 | 1 | ccv_nnc_tensor_free(dbias); |
709 | 1 | } |
710 | | |
711 | | TEST_CASE("backward gemm with transpose b batch 2, batched b") |
712 | 1 | { |
713 | 1 | float gp[] = { |
714 | 1 | 1, 2, 3, |
715 | 1 | 4, 5, 6, |
716 | 1 | 7, 8, 9, |
717 | 1 | 10, 11, 12, |
718 | 1 | 10, 20, 30, |
719 | 1 | 40, 50, 60, |
720 | 1 | 70, 80, 90, |
721 | 1 | 100, 110, 120, |
722 | 1 | }; |
723 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
724 | 1 | float ap[] = { |
725 | 1 | 13, 14, |
726 | 1 | 15, 16, |
727 | 1 | 17, 18, |
728 | 1 | 19, 20, |
729 | 1 | 131, 141, |
730 | 1 | 151, 161, |
731 | 1 | 171, 181, |
732 | 1 | 191, 201, |
733 | 1 | }; |
734 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
735 | 1 | float bp[] = { |
736 | 1 | 21, 24, |
737 | 1 | 22, 25, |
738 | 1 | 23, 26, |
739 | 1 | 212, 242, |
740 | 1 | 222, 252, |
741 | 1 | 232, 262, |
742 | 1 | }; |
743 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0); |
744 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
745 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0); |
746 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); |
747 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
748 | 1 | float dbiastp[] = { |
749 | 1 | 22, 26, 30, |
750 | 1 | 220, 260, 300, |
751 | 1 | }; |
752 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); |
753 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
754 | 1 | float htp[] = { |
755 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26, |
756 | 1 | 4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26, |
757 | 1 | 7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26, |
758 | 1 | 10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26, |
759 | 1 | 10 * 212 + 20 * 222 + 30 * 232, 10 * 242 + 20 * 252 + 30 * 262, |
760 | 1 | 40 * 212 + 50 * 222 + 60 * 232, 40 * 242 + 50 * 252 + 60 * 262, |
761 | 1 | 70 * 212 + 80 * 222 + 90 * 232, 70 * 242 + 80 * 252 + 90 * 262, |
762 | 1 | 100 * 212 + 110 * 222 + 120 * 232, 100 * 242 + 110 * 252 + 120 * 262, |
763 | 1 | }; |
764 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); |
765 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
766 | 1 | float dbtp[] = { |
767 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, |
768 | 1 | 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, |
769 | 1 | 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20, |
770 | 1 | 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, |
771 | 1 | 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, |
772 | 1 | 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201, |
773 | 1 | }; |
774 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0); |
775 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
776 | 1 | ccv_nnc_tensor_free(g); |
777 | 1 | ccv_nnc_tensor_free(a); |
778 | 1 | ccv_nnc_tensor_free(b); |
779 | 1 | ccv_nnc_tensor_free(h); |
780 | 1 | ccv_nnc_tensor_free(db); |
781 | 1 | ccv_nnc_tensor_free(dbias); |
782 | 1 | } |
783 | | |
784 | | TEST_CASE("backward gemm with transpose a and b batch 2, same b") |
785 | 1 | { |
786 | 1 | float gp[] = { |
787 | 1 | 1, 2, 3, |
788 | 1 | 4, 5, 6, |
789 | 1 | 7, 8, 9, |
790 | 1 | 10, 11, 12, |
791 | 1 | 10, 20, 30, |
792 | 1 | 40, 50, 60, |
793 | 1 | 70, 80, 90, |
794 | 1 | 100, 110, 120, |
795 | 1 | }; |
796 | 1 | ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); |
797 | 1 | float ap[] = { |
798 | 1 | 13, 15, 17, 19, |
799 | 1 | 14, 16, 18, 20, |
800 | 1 | 131, 151, 171, 191, |
801 | 1 | 141, 161, 181, 201, |
802 | 1 | }; |
803 | 1 | ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
804 | 1 | float bp[] = { |
805 | 1 | 21, 24, |
806 | 1 | 22, 25, |
807 | 1 | 23, 26, |
808 | 1 | }; |
809 | 1 | ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
810 | 1 | ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
811 | 1 | ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
812 | 1 | ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); |
813 | 1 | ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(1, 2), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); |
814 | 1 | float dbiastp[] = { |
815 | 1 | 22 + 220, 26 + 260, 30 + 300, |
816 | 1 | }; |
817 | 1 | ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0); |
818 | 1 | REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal"); |
819 | 1 | float htp[] = { |
820 | 1 | 1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23, |
821 | 1 | 1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26, |
822 | 1 | 10 * 21 + 20 * 22 + 30 * 23, 40 * 21 + 50 * 22 + 60 * 23, 70 * 21 + 80 * 22 + 90 * 23, 100 * 21 + 110 * 22 + 120 * 23, |
823 | 1 | 10 * 24 + 20 * 25 + 30 * 26, 40 * 24 + 50 * 25 + 60 * 26, 70 * 24 + 80 * 25 + 90 * 26, 100 * 24 + 110 * 25 + 120 * 26, |
824 | 1 | }; |
825 | 1 | ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); |
826 | 1 | REQUIRE_TENSOR_EQ(h, &ht, "h should be equal"); |
827 | 1 | float dbtp[] = { |
828 | 1 | 1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, |
829 | 1 | 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, |
830 | 1 | 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201, |
831 | 1 | }; |
832 | 1 | ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 3, 2), 0); |
833 | 1 | REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal"); |
834 | 1 | ccv_nnc_tensor_free(g); |
835 | 1 | ccv_nnc_tensor_free(a); |
836 | 1 | ccv_nnc_tensor_free(b); |
837 | 1 | ccv_nnc_tensor_free(h); |
838 | 1 | ccv_nnc_tensor_free(db); |
839 | 1 | ccv_nnc_tensor_free(dbias); |
840 | 1 | } |
841 | | |
842 | | #include "case_main.h" |