Coverage Report

Created: 2024-08-19 11:27

/home/liu/actions-runner/_work/ccv/ccv/test/unit/nnc/gemm.tests.c
Line
Count
Source
1
#include "case.h"
2
#include "ccv_case.h"
3
#include "ccv_nnc_case.h"
4
#include <ccv.h>
5
#include <nnc/ccv_nnc.h>
6
#include <nnc/ccv_nnc_easy.h>
7
8
/* Test-suite setup hook: calls ccv_nnc_init() (presumably before any case runs -- confirm against case.h). */
TEST_SETUP()
{
  ccv_nnc_init();
}
12
13
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 8, 9], [10, 11, 12]]")
{
  // Forward GEMM of a 4x2 tensor by a 2x3 tensor into a 4x3 output.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
  };
  // Expected product, written out as row-by-column dot products.
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
    3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12,
    5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12,
    7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
41
42
TEST_CASE("[1, 2] * [[7, 8, 9], [10, 11, 12]]")
{
  // Forward GEMM where a is a 1-D tensor of length 2 and the output is 1-D of length 3.
  float a_data[] = {
    1, 2,
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
  };
  // Expected: the single row of a dotted with each column of b.
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
64
65
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 10], [8, 11], [9, 12]]^T")
{
  // Forward GEMM with b stored as 3x2 and flagged TRANSPOSE(0, 1) in the command.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
  };
  float b_data[] = {
    7, 10,
    8, 11,
    9, 12,
  };
  // Expected 4x3 result: same values as multiplying by [[7, 8, 9], [10, 11, 12]].
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
    3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12,
    5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12,
    7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
94
95
TEST_CASE("[[1, 3, 5, 7], [2, 4, 6, 8]]^T * [[7, 10], [8, 11], [9, 12]]^T")
{
  // Forward GEMM with both operands transposed: a is 1x2x4 with TRANSPOSE(1, 2),
  // b is 3x2 with TRANSPOSE(0, 1); the output is 1x4x3.
  float a_data[] = {
    1, 3, 5, 7,
    2, 4, 6, 8,
  };
  float b_data[] = {
    7, 10,
    8, 11,
    9, 12,
  };
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
    3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12,
    5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12,
    7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 1, 2, 4), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(TRANSPOSE(1, 2), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 1, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
122
123
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]] * [[7, 8, 9], [10, 11, 12]] + [-1, 0, 1]")
{
  // Forward GEMM with a third input: a length-3 bias tensor.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
  };
  float bias_data[] = {
    -1, 0, 1,
  };
  // Expected: the plain product with -1 / 0 / +1 added to columns 0 / 1 / 2.
  float c_expected[] = {
    1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1,
    3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1,
    5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1,
    7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const bias = ccv_nnc_tensor_new(bias_data, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, bias), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(bias);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
156
157
TEST_CASE("backward gemm with no transpose")
{
  // Backward GEMM: inputs are g (4x3), a (4x2) and b (2x3);
  // outputs are h (4x2), db (2x3) and dbias (3).
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
  };
  float a_data[] = {
    13, 14,
    15, 16,
    17, 18,
    19, 20,
  };
  float b_data[] = {
    21, 22, 23,
    24, 25, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // These values equal the column sums of g_data.
  float dbias_expected[] = {
    22, 26, 30,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  // Each entry is a row of g dotted with a row of b.
  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26,
    4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26,
    7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26,
    10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  // Each entry is a column of a dotted with a column of g.
  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19,
    1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
208
209
TEST_CASE("backward gemm with transpose a")
{
  // Backward GEMM where a is stored as 2x4 and the command carries TRANSPOSE(0, 1)
  // for it; h therefore has the same 2x4 layout as a.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
  };
  float a_data[] = {
    13, 15, 17, 19,
    14, 16, 18, 20,
  };
  float b_data[] = {
    21, 22, 23,
    24, 25, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // These values equal the column sums of g_data.
  float dbias_expected[] = {
    22, 26, 30,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  // Same dot products as the non-transposed case, laid out as 2x4.
  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23,
    1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19,
    1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
256
257
TEST_CASE("backward gemm with transpose b")
{
  // Backward GEMM where b is stored as 3x2 and the command carries TRANSPOSE(0, 1)
  // for it; db therefore has the same 3x2 layout as b.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
  };
  float a_data[] = {
    13, 14,
    15, 16,
    17, 18,
    19, 20,
  };
  float b_data[] = {
    21, 24,
    22, 25,
    23, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // These values equal the column sums of g_data.
  float dbias_expected[] = {
    22, 26, 30,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26,
    4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26,
    7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26,
    10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 4, 2), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  // Same dot products as the non-transposed case, laid out as 3x2.
  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20,
    2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20,
    3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
310
311
TEST_CASE("backward gemm with transpose a and b")
{
  // Backward GEMM with both a (2x4) and b (3x2) flagged TRANSPOSE(0, 1);
  // h and db use the same layouts as a and b respectively.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
  };
  float a_data[] = {
    13, 15, 17, 19,
    14, 16, 18, 20,
  };
  float b_data[] = {
    21, 24,
    22, 25,
    23, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(0, 1), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // These values equal the column sums of g_data.
  float dbias_expected[] = {
    22, 26, 30,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23,
    1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20,
    2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20,
    3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 3, 2), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
360
361
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 8, 9], [10, 11, 12]]")
{
  // Forward GEMM with a batched a (2x4x2) and a single 2x3 b; the output is 2x4x3.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
    2, 3,
    4, 5,
    6, 7,
    8, 9
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
  };
  // Both batches of a are multiplied by the same b values.
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
    3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12,
    5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12,
    7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12,
    2 * 7 + 3 * 10, 2 * 8 + 3 * 11, 2 * 9 + 3 * 12,
    4 * 7 + 5 * 10, 4 * 8 + 5 * 11, 4 * 9 + 5 * 12,
    6 * 7 + 7 * 10, 6 * 8 + 7 * 11, 6 * 9 + 7 * 12,
    8 * 7 + 9 * 10, 8 * 8 + 9 * 11, 8 * 9 + 9 * 12,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
397
398
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 8, 9], [10, 11, 12]], [[80, 90, 10], [110, 120, 13]]")
{
  // Forward GEMM with both operands batched: a is 2x4x2 and b is 2x2x3.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
    2, 3,
    4, 5,
    6, 7,
    8, 9
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
    80, 90, 10,
    110, 120, 13,
  };
  // Batch 0 of a pairs with batch 0 of b, batch 1 with batch 1.
  float c_expected[] = {
    1 * 7 + 2 * 10, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12,
    3 * 7 + 4 * 10, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12,
    5 * 7 + 6 * 10, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12,
    7 * 7 + 8 * 10, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12,
    2 * 80 + 3 * 110, 2 * 90 + 3 * 120, 2 * 10 + 3 * 13,
    4 * 80 + 5 * 110, 4 * 90 + 5 * 120, 4 * 10 + 5 * 13,
    6 * 80 + 7 * 110, 6 * 90 + 7 * 120, 6 * 10 + 7 * 13,
    8 * 80 + 9 * 110, 8 * 90 + 9 * 120, 8 * 10 + 9 * 13,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
436
437
TEST_CASE("[[1, 3, 5, 7], [2, 4, 6, 8]], [[2, 4, 6, 8], [3, 5, 7, 9]]^T * [[7, 8, 9], [10, 11, 12]] + [-1, 0, 1]")
{
  // Forward GEMM with a batched, transposed a (2x2x4, TRANSPOSE(1, 2)), a single
  // 2x3 b and a length-3 bias d; the output is 2x4x3.
  float a_data[] = {
    1, 3, 5, 7,
    2, 4, 6, 8,
    2, 4, 6, 8,
    3, 5, 7, 9,
  };
  float b_data[] = {
    7, 8, 9,
    10, 11, 12,
  };
  float d_data[] = {
    -1, 0, 1,
  };
  // The same bias is added to every row of both batches.
  float c_expected[] = {
    1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1,
    3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1,
    5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1,
    7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1,
    2 * 7 + 3 * 10 - 1, 2 * 8 + 3 * 11, 2 * 9 + 3 * 12 + 1,
    4 * 7 + 5 * 10 - 1, 4 * 8 + 5 * 11, 4 * 9 + 5 * 12 + 1,
    6 * 7 + 7 * 10 - 1, 6 * 8 + 7 * 11, 6 * 9 + 7 * 12 + 1,
    8 * 7 + 9 * 10 - 1, 8 * 8 + 9 * 11, 8 * 9 + 9 * 12 + 1,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const d = ccv_nnc_tensor_new(d_data, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, d), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(d);
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
474
475
TEST_CASE("[[1, 2], [3, 4], [5, 6], [7, 8]], [[2, 3], [4, 5], [6, 7], [8, 9]] * [[7, 10], [8, 11], [9, 12]], [[80, 110], [90, 120], [10, 13]]^T + [-1, 0, 1], [2, 3, -4]")
{
  // Forward GEMM with batched a (2x4x2), batched transposed b (2x3x2, TRANSPOSE(1, 2))
  // and a per-batch bias d (2x1x3); the output is 2x4x3.
  float a_data[] = {
    1, 2,
    3, 4,
    5, 6,
    7, 8,
    2, 3,
    4, 5,
    6, 7,
    8, 9
  };
  float b_data[] = {
    7, 10,
    8, 11,
    9, 12,
    80, 110,
    90, 120,
    10, 13,
  };
  float d_data[] = {
    -1, 0, 1,
    2, 3, -4,
  };
  // Batch 0 rows get the bias [-1, 0, 1]; batch 1 rows get [2, 3, -4].
  float c_expected[] = {
    1 * 7 + 2 * 10 - 1, 1 * 8 + 2 * 11, 1 * 9 + 2 * 12 + 1,
    3 * 7 + 4 * 10 - 1, 3 * 8 + 4 * 11, 3 * 9 + 4 * 12 + 1,
    5 * 7 + 6 * 10 - 1, 5 * 8 + 6 * 11, 5 * 9 + 6 * 12 + 1,
    7 * 7 + 8 * 10 - 1, 7 * 8 + 8 * 11, 7 * 9 + 8 * 12 + 1,
    2 * 80 + 3 * 110 + 2, 2 * 90 + 3 * 120 + 3, 2 * 10 + 3 * 13 - 4,
    4 * 80 + 5 * 110 + 2, 4 * 90 + 5 * 120 + 3, 4 * 10 + 5 * 13 - 4,
    6 * 80 + 7 * 110 + 2, 6 * 90 + 7 * 120 + 3, 6 * 10 + 7 * 13 - 4,
    8 * 80 + 9 * 110 + 2, 8 * 90 + 9 * 120 + 3, 8 * 10 + 9 * 13 - 4,
  };
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0);
  ccv_nnc_tensor_t *const d = ccv_nnc_tensor_new(d_data, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0);
  ccv_nnc_tensor_t *const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b, d), TENSOR_LIST(c), 0);
  ccv_nnc_tensor_t ct = ccv_nnc_tensor(c_expected, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  REQUIRE_TENSOR_EQ(c, &ct, "result should be equal");
  ccv_nnc_tensor_free(d);
  ccv_nnc_tensor_free(c);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
}
521
522
TEST_CASE("backward gemm with no transpose batch 2, same b")
{
  // Backward GEMM with batched g (2x4x3) and a (2x4x2) but a single 2x3 b;
  // db (2x3) and dbias (3) are not batched, h (2x4x2) is.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
    10, 20, 30,
    40, 50, 60,
    70, 80, 90,
    100, 110, 120,
  };
  float a_data[] = {
    13, 14,
    15, 16,
    17, 18,
    19, 20,
    131, 141,
    151, 161,
    171, 181,
    191, 201,
  };
  float b_data[] = {
    21, 22, 23,
    24, 25, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // Column sums of g_data, with the two batches added together.
  float dbias_expected[] = {
    22 + 220, 26 + 260, 30 + 300,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  // Both batches of g are dotted against the same b rows.
  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26,
    4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26,
    7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26,
    10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26,
    10 * 21 + 20 * 22 + 30 * 23, 10 * 24 + 20 * 25 + 30 * 26,
    40 * 21 + 50 * 22 + 60 * 23, 40 * 24 + 50 * 25 + 60 * 26,
    70 * 21 + 80 * 22 + 90 * 23, 70 * 24 + 80 * 25 + 90 * 26,
    100 * 21 + 110 * 22 + 120 * 23, 100 * 24 + 110 * 25 + 120 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  // Each entry accumulates the contributions from both batches.
  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191,
    1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
585
586
TEST_CASE("backward gemm with no transpose batch 2, batched b")
{
  // Backward GEMM with g (2x4x3), a (2x4x2) and b (2x2x3) all batched;
  // db (2x2x3) and dbias (2x1x3) are per-batch as well.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
    10, 20, 30,
    40, 50, 60,
    70, 80, 90,
    100, 110, 120,
  };
  float a_data[] = {
    13, 14,
    15, 16,
    17, 18,
    19, 20,
    131, 141,
    151, 161,
    171, 181,
    191, 201,
  };
  float b_data[] = {
    21, 22, 23,
    24, 25, 26,
    212, 222, 232,
    242, 252, 262,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // Column sums of g_data, kept separate per batch.
  float dbias_expected[] = {
    22, 26, 30,
    220, 260, 300,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  // Batch 0 of g is dotted with batch 0 of b, batch 1 with batch 1.
  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26,
    4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26,
    7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26,
    10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26,
    10 * 212 + 20 * 222 + 30 * 232, 10 * 242 + 20 * 252 + 30 * 262,
    40 * 212 + 50 * 222 + 60 * 232, 40 * 242 + 50 * 252 + 60 * 262,
    70 * 212 + 80 * 222 + 90 * 232, 70 * 242 + 80 * 252 + 90 * 262,
    100 * 212 + 110 * 222 + 120 * 232, 100 * 242 + 110 * 252 + 120 * 262,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  // Per-batch results: no accumulation across batches.
  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19,
    1 * 14 + 4 * 16 + 7 * 18 + 10 * 20, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
    10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191,
    10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 2, 2, 3), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
654
655
TEST_CASE("backward gemm with transpose a batch 2, same b")
{
  // Backward GEMM with batched g (2x4x3), batched transposed a (2x2x4, TRANSPOSE(1, 2))
  // and a single 2x3 b; h follows a's 2x2x4 layout, db and dbias are not batched.
  float g_data[] = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
    10, 11, 12,
    10, 20, 30,
    40, 50, 60,
    70, 80, 90,
    100, 110, 120,
  };
  float a_data[] = {
    13, 15, 17, 19,
    14, 16, 18, 20,
    131, 151, 171, 191,
    141, 161, 181, 201,
  };
  float b_data[] = {
    21, 22, 23,
    24, 25, 26,
  };
  ccv_nnc_tensor_t *const g = ccv_nnc_tensor_new(g_data, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0);
  ccv_nnc_tensor_t *const a = ccv_nnc_tensor_new(a_data, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
  ccv_nnc_tensor_t *const b = ccv_nnc_tensor_new(b_data, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
  ccv_nnc_tensor_t *const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_tensor_t *const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0);

  // Column sums of g_data, with the two batches added together.
  float dbias_expected[] = {
    22 + 220, 26 + 260, 30 + 300,
  };
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbias_expected, CPU_TENSOR_NHWC(32F, 3), 0);
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");

  // Per batch: 2x4 layout, each row pairing one row of b with every row of g.
  float h_expected[] = {
    1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23,
    1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26,
    10 * 21 + 20 * 22 + 30 * 23, 40 * 21 + 50 * 22 + 60 * 23, 70 * 21 + 80 * 22 + 90 * 23, 100 * 21 + 110 * 22 + 120 * 23,
    10 * 24 + 20 * 25 + 30 * 26, 40 * 24 + 50 * 25 + 60 * 26, 70 * 24 + 80 * 25 + 90 * 26, 100 * 24 + 110 * 25 + 120 * 26,
  };
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(h_expected, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");

  // Each entry accumulates the contributions from both batches.
  float db_expected[] = {
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191,
    1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201,
  };
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(db_expected, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");

  ccv_nnc_tensor_free(dbias);
  ccv_nnc_tensor_free(db);
  ccv_nnc_tensor_free(h);
  ccv_nnc_tensor_free(b);
  ccv_nnc_tensor_free(a);
  ccv_nnc_tensor_free(g);
}
710
711
TEST_CASE("backward gemm with transpose b batch 2, batched b") // Runs CMD_GEMM_BACKWARD on two batches where b has its own batch dimension, then compares h, db and dbias with hand-expanded expected values.
712
1
{
713
1
  float gp[] = { // 2 batches of 4x3 values; the second batch is 10x the first
714
1
    1, 2, 3,
715
1
    4, 5, 6,
716
1
    7, 8, 9,
717
1
    10, 11, 12,
718
1
    10, 20, 30,
719
1
    40, 50, 60,
720
1
    70, 80, 90,
721
1
    100, 110, 120,
722
1
  };
723
1
  ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); // first input of the backward command (presumably the gradient of the forward output -- TODO confirm)
724
1
  float ap[] = { // 2 batches of 4x2 values
725
1
    13, 14,
726
1
    15, 16,
727
1
    17, 18,
728
1
    19, 20,
729
1
    131, 141,
730
1
    151, 161,
731
1
    171, 181,
732
1
    191, 201,
733
1
  };
734
1
  ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
735
1
  float bp[] = { // 2 batches of 3x2 values; b is the operand flagged TRANSPOSE(1, 2) in the command below
736
1
    21, 24,
737
1
    22, 25,
738
1
    23, 26,
739
1
    212, 242,
740
1
    222, 252,
741
1
    232, 262,
742
1
  };
743
1
  ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0);
744
1
  ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0); // output with the same dims as a; created with 0 as the data argument
745
1
  ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0); // output with the same dims as b
746
1
  ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0); // output holding one 1x3 row per batch
747
1
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(NO_TRANSPOSE, TRANSPOSE(1, 2)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); // inputs (g, a, b) -> outputs (h, db, dbias)
748
1
  float dbiastp[] = { // per batch, the column sums of gp: 1+4+7+10, 2+5+8+11, 3+6+9+12, then the same for the 10x batch
749
1
    22, 26, 30,
750
1
    220, 260, 300,
751
1
  };
752
1
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 2, 1, 3), 0);
753
1
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");
754
1
  float htp[] = { // within each batch, entry (i, j) is row i of gp dotted with column j of that batch's bp
755
1
    1 * 21 + 2 * 22 + 3 * 23, 1 * 24 + 2 * 25 + 3 * 26,
756
1
    4 * 21 + 5 * 22 + 6 * 23, 4 * 24 + 5 * 25 + 6 * 26,
757
1
    7 * 21 + 8 * 22 + 9 * 23, 7 * 24 + 8 * 25 + 9 * 26,
758
1
    10 * 21 + 11 * 22 + 12 * 23, 10 * 24 + 11 * 25 + 12 * 26,
759
1
    10 * 212 + 20 * 222 + 30 * 232, 10 * 242 + 20 * 252 + 30 * 262,
760
1
    40 * 212 + 50 * 222 + 60 * 232, 40 * 242 + 50 * 252 + 60 * 262,
761
1
    70 * 212 + 80 * 222 + 90 * 232, 70 * 242 + 80 * 252 + 90 * 262,
762
1
    100 * 212 + 110 * 222 + 120 * 232, 100 * 242 + 110 * 252 + 120 * 262,
763
1
  };
764
1
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 4, 2), 0);
765
1
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");
766
1
  float dbtp[] = { // within each batch, entry (k, j) is column k of gp dotted with column j of that batch's ap; batches are kept separate
767
1
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20,
768
1
    2 * 13 + 5 * 15 + 8 * 17 + 11 * 19, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20,
769
1
    3 * 13 + 6 * 15 + 9 * 17 + 12 * 19, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20,
770
1
    10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201,
771
1
    20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201,
772
1
    30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201,
773
1
  };
774
1
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 2, 3, 2), 0);
775
1
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");
776
1
  ccv_nnc_tensor_free(g); // free the six tensors obtained from ccv_nnc_tensor_new; the local expected tensors (dbiast, ht, dbt) are not freed
777
1
  ccv_nnc_tensor_free(a);
778
1
  ccv_nnc_tensor_free(b);
779
1
  ccv_nnc_tensor_free(h);
780
1
  ccv_nnc_tensor_free(db);
781
1
  ccv_nnc_tensor_free(dbias);
782
1
}
783
784
TEST_CASE("backward gemm with transpose a and b batch 2, same b") // Runs CMD_GEMM_BACKWARD with both operands flagged as transposed and a single b (no batch dimension) used with a 2-batch g and a, then checks h, db and dbias.
785
1
{
786
1
  float gp[] = { // 2 batches of 4x3 values; the second batch is 10x the first
787
1
    1, 2, 3,
788
1
    4, 5, 6,
789
1
    7, 8, 9,
790
1
    10, 11, 12,
791
1
    10, 20, 30,
792
1
    40, 50, 60,
793
1
    70, 80, 90,
794
1
    100, 110, 120,
795
1
  };
796
1
  ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(gp, CPU_TENSOR_NHWC(32F, 2, 4, 3), 0); // first input of the backward command (presumably the gradient of the forward output -- TODO confirm)
797
1
  float ap[] = { // 2 batches of 2x4 values; a is the operand flagged TRANSPOSE(1, 2) in the command below
798
1
    13, 15, 17, 19,
799
1
    14, 16, 18, 20,
800
1
    131, 151, 171, 191,
801
1
    141, 161, 181, 201,
802
1
  };
803
1
  ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(ap, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
804
1
  float bp[] = { // a single 3x2 block with no batch dimension; b is the operand flagged TRANSPOSE(0, 1) in the command below
805
1
    21, 24,
806
1
    22, 25,
807
1
    23, 26,
808
1
  };
809
1
  ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(bp, CPU_TENSOR_NHWC(32F, 3, 2), 0);
810
1
  ccv_nnc_tensor_t* const h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0); // output with the same dims as a; created with 0 as the data argument
811
1
  ccv_nnc_tensor_t* const db = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 2), 0); // output with the same dims as b (no batch dimension)
812
1
  ccv_nnc_tensor_t* const dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0); // output with 3 elements and no batch dimension
813
1
  ccv_nnc_cmd_exec(CMD_GEMM_BACKWARD(TRANSPOSE(1, 2), TRANSPOSE(0, 1)), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, b), TENSOR_LIST(h, db, dbias), 0); // inputs (g, a, b) -> outputs (h, db, dbias)
814
1
  float dbiastp[] = { // column sums of gp, added across both batches (22, 26, 30 from the first batch plus 220, 260, 300 from the second)
815
1
    22 + 220, 26 + 260, 30 + 300,
816
1
  };
817
1
  ccv_nnc_tensor_t dbiast = ccv_nnc_tensor(dbiastp, CPU_TENSOR_NHWC(32F, 3), 0);
818
1
  REQUIRE_TENSOR_EQ(dbias, &dbiast, "bias should be equal");
819
1
  float htp[] = { // within each batch, entry (j, i) is row i of gp dotted with column j of bp, laid out as 2x4 to match a
820
1
    1 * 21 + 2 * 22 + 3 * 23, 4 * 21 + 5 * 22 + 6 * 23, 7 * 21 + 8 * 22 + 9 * 23, 10 * 21 + 11 * 22 + 12 * 23,
821
1
    1 * 24 + 2 * 25 + 3 * 26, 4 * 24 + 5 * 25 + 6 * 26, 7 * 24 + 8 * 25 + 9 * 26, 10 * 24 + 11 * 25 + 12 * 26,
822
1
    10 * 21 + 20 * 22 + 30 * 23, 40 * 21 + 50 * 22 + 60 * 23, 70 * 21 + 80 * 22 + 90 * 23, 100 * 21 + 110 * 22 + 120 * 23,
823
1
    10 * 24 + 20 * 25 + 30 * 26, 40 * 24 + 50 * 25 + 60 * 26, 70 * 24 + 80 * 25 + 90 * 26, 100 * 24 + 110 * 25 + 120 * 26,
824
1
  };
825
1
  ccv_nnc_tensor_t ht = ccv_nnc_tensor(htp, CPU_TENSOR_NHWC(32F, 2, 2, 4), 0);
826
1
  REQUIRE_TENSOR_EQ(h, &ht, "h should be equal");
827
1
  float dbtp[] = { // entry (k, j) is column k of gp dotted with row j of ap, summed over both batches because b is shared
828
1
    1 * 13 + 4 * 15 + 7 * 17 + 10 * 19 + 10 * 131 + 40 * 151 + 70 * 171 + 100 * 191, 1 * 14 + 4 * 16 + 7 * 18 + 10 * 20 + 10 * 141 + 40 * 161 + 70 * 181 + 100 * 201,
829
1
    2 * 13 + 5 * 15 + 8 * 17 + 11 * 19 + 20 * 131 + 50 * 151 + 80 * 171 + 110 * 191, 2 * 14 + 5 * 16 + 8 * 18 + 11 * 20 + 20 * 141 + 50 * 161 + 80 * 181 + 110 * 201,
830
1
    3 * 13 + 6 * 15 + 9 * 17 + 12 * 19 + 30 * 131 + 60 * 151 + 90 * 171 + 120 * 191, 3 * 14 + 6 * 16 + 9 * 18 + 12 * 20 + 30 * 141 + 60 * 161 + 90 * 181 + 120 * 201,
831
1
  };
832
1
  ccv_nnc_tensor_t dbt = ccv_nnc_tensor(dbtp, CPU_TENSOR_NHWC(32F, 3, 2), 0);
833
1
  REQUIRE_TENSOR_EQ(db, &dbt, "db should be equal");
834
1
  ccv_nnc_tensor_free(g); // free the six tensors obtained from ccv_nnc_tensor_new; the local expected tensors (dbiast, ht, dbt) are not freed
835
1
  ccv_nnc_tensor_free(a);
836
1
  ccv_nnc_tensor_free(b);
837
1
  ccv_nnc_tensor_free(h);
838
1
  ccv_nnc_tensor_free(db);
839
1
  ccv_nnc_tensor_free(dbias);
840
1
}
841
842
#include "case_main.h"