/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/blas/ccv_nnc_gemm_cpu_ref.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #ifdef USE_OPENMP |
7 | | #include <omp.h> |
8 | | #endif |
9 | | #ifdef USE_DISPATCH |
10 | | #include <dispatch/dispatch.h> |
11 | | #endif |
12 | | |
13 | | static inline void _ccv_nnc_bmm_and_bias(const float* const a, const float* const w, const float* const bias, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc) |
14 | 11.8k | { |
15 | 11.8k | int n, i; |
16 | 23.7k | for (n = 0; n < b_batch_size; n++) |
17 | 11.9k | { |
18 | 11.9k | const float* const ap = a + n * a_batch_inc; |
19 | 11.9k | const float* const wp = w + n * w_batch_inc; |
20 | 11.9k | const float* const biasp = bias + n * bias_batch_inc; |
21 | 11.9k | float* const bp = b + n * b_batch_inc; |
22 | 34.6k | for (i = 0; i < b_rows; i++) |
23 | 22.7k | { |
24 | 22.7k | const float* const api = ap + i * a_rows_inc; |
25 | 22.7k | const float* const biaspi = biasp + i * bias_rows_inc; |
26 | 22.7k | float* const bpi = bp + i * b_rows_inc; |
27 | 3.28M | parallel_for(j, b_cols) { |
28 | 3.28M | float v = biaspi[j * bias_cols_inc]; |
29 | 3.28M | const float* const wpj = wp + j * w_cols_inc; |
30 | 3.28M | int k; |
31 | 3.10G | for (k = 0; k < a_cols; k++) |
32 | 3.09G | v += wpj[k * w_rows_inc] * api[k * a_cols_inc]; |
33 | 3.28M | bpi[j * b_cols_inc] = v; |
34 | 3.28M | } parallel_endfor |
35 | 22.7k | } |
36 | 11.9k | } |
37 | 11.8k | } |
38 | | |
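The kernel above computes b[n][i][j] = bias[n][i][j] + sum_k a[n][i][k] * w[n][k][j], addressing every operand through explicit batch/row/column increments so that transposed or broadcast operands need no copying; the loop over output columns is parallelized by ccv's parallel_for macro (OpenMP or libdispatch, per the #ifdef block at the top of the file). A minimal self-contained sketch of the same arithmetic for one batch of plain row-major operands (the M/K/N names and the naive triple loop are illustrative, not ccv API):

#include <stdio.h>

/* Sketch: b[i][j] = bias[j] + sum_k a[i][k] * w[k][j]. In the kernel's terms
 * this is b_batch_size = 1, a_rows_inc = K, a_cols_inc = 1, w_rows_inc = N,
 * w_cols_inc = 1, b_rows_inc = N, b_cols_inc = 1, and bias_rows_inc = 0
 * (the single bias row is broadcast across output rows). */
int main(void)
{
	enum { M = 2, K = 3, N = 2 };
	const float a[M][K] = { { 1, 2, 3 }, { 4, 5, 6 } };
	const float w[K][N] = { { 1, 0 }, { 0, 1 }, { 1, 1 } };
	const float bias[N] = { 0.5f, -0.5f };
	float b[M][N];
	int i, j, k;
	for (i = 0; i < M; i++)
		for (j = 0; j < N; j++)
		{
			float v = bias[j];
			for (k = 0; k < K; k++)
				v += a[i][k] * w[k][j];
			b[i][j] = v;
		}
	for (i = 0; i < M; i++)
		printf("%g %g\n", b[i][0], b[i][1]); /* prints 4.5 4.5, then 10.5 10.5 */
	return 0;
}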
39 | | static inline void _ccv_nnc_gbmm_and_bias(const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, const float* const bias, const int bias_nd, const int* const biasdim, const int* const biasstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc) |
40 | 11.8k | { |
41 | 11.8k | if (b_nd <= 3) |
42 | 11.8k | { |
43 | 11.8k | _ccv_nnc_bmm_and_bias(a, w, bias, b, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc); |
44 | 11.8k | return; |
45 | 11.8k | } |
46 | 6 | const int dim = bdim[0]; |
47 | 6 | if (a_nd > 3) |
48 | 6 | { assert(adim[0] == 1 || dim == adim[0]); } |
49 | 6 | if (w_nd > 3) |
50 | 3 | { assert(wdim[0] == 1 || dim == wdim[0]); } |
51 | 6 | if (bias_nd > 3) |
52 | 0 | { assert(biasdim[0] == 1 || dim == biasdim[0]); } |
53 | 6 | int i; |
54 | 18 | for (i = 0; i < dim; i++) |
55 | 12 | { |
56 | 12 | _ccv_nnc_gbmm_and_bias( |
57 | 12 | a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride, |
58 | 12 | w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride, |
59 | 12 | bias_nd > 3 ? bias + i * biasstride[0] : bias, bias_nd > 3 ? bias_nd - 1 : bias_nd, bias_nd > 3 ? biasdim + 1 : biasdim, bias_nd > 3 ? biasstride + 1 : biasstride, |
60 | 12 | b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc); |
61 | 12 | } |
62 | 6 | } |
63 | | |
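_ccv_nnc_gbmm_and_bias (and the other _ccv_nnc_gbmm_* wrappers below, which repeat the identical pattern) peels one leading dimension per recursion level until b is at most 3-D, then hands off to the flat batched kernel. An operand that lacks the extra leading dimension is passed through unchanged on every iteration, which is how broadcasting over leading dimensions falls out of the pointer arithmetic. A compact sketch of the recursion shape, with a hypothetical single-input gwalk standing in for the real multi-operand signature:

/* Recursive walk over leading dimensions, in the shape of _ccv_nnc_gbmm_and_bias. */
static void gwalk(const float* x, int x_nd, const int* xdim, const int* xstride,
	float* y, int y_nd, const int* ydim, const int* ystride)
{
	if (y_nd <= 3)
		return; /* base case: the flat batched kernel would run here */
	int i;
	for (i = 0; i < ydim[0]; i++)
		gwalk(
			/* advance x only if it actually carries this leading dimension */
			x_nd > 3 ? x + i * xstride[0] : x,
			x_nd > 3 ? x_nd - 1 : x_nd,
			x_nd > 3 ? xdim + 1 : xdim,
			x_nd > 3 ? xstride + 1 : xstride,
			/* the output always advances: it is never broadcast */
			y + i * ystride[0], y_nd - 1, ydim + 1, ystride + 1);
}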
64 | | static inline void _ccv_nnc_bmm(const float* const a, const float* const w, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc) |
65 | 3.96k | { |
66 | 3.96k | int n, i; |
67 | 9.32k | for (n = 0; n < b_batch_size; n++) |
68 | 5.36k | { |
69 | 5.36k | const float* const ap = a + n * a_batch_inc; |
70 | 5.36k | const float* const wp = w + n * w_batch_inc; |
71 | 5.36k | float* const bp = b + n * b_batch_inc; |
72 | 206k | for (i = 0; i < b_rows; i++) |
73 | 201k | { |
74 | 201k | const float* const api = ap + i * a_rows_inc; |
75 | 201k | float* const bpi = bp + i * b_rows_inc; |
76 | 22.0M | parallel_for(j, b_cols) { |
77 | 22.0M | float v = 0; |
78 | 22.0M | const float* const wpj = wp + j * w_cols_inc; |
79 | 22.0M | int k; |
80 | 2.04G | for (k = 0; k < a_cols; k++) |
81 | 2.01G | v += wpj[k * w_rows_inc] * api[k * a_cols_inc]; |
82 | 22.0M | bpi[j * b_cols_inc] = v; |
83 | 22.0M | } parallel_endfor |
84 | 201k | } |
85 | 5.36k | } |
86 | 3.96k | } |
87 | | |
88 | | static inline void _ccv_nnc_gbmm(const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc) |
89 | 3.97k | { |
90 | 3.97k | if (b_nd <= 3) |
91 | 3.96k | { |
92 | 3.96k | _ccv_nnc_bmm(a, w, b, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc); |
93 | 3.96k | return; |
94 | 3.96k | } |
95 | 12 | const int dim = bdim[0]; |
96 | 12 | if (a_nd > 3) |
97 | 12 | { assert(adim[0] == 1 || dim == adim[0]); } |
98 | 12 | if (w_nd > 3) |
99 | 9 | { assert(wdim[0] == 1 || dim == wdim[0]); } |
100 | 12 | int i; |
101 | 216 | for (i = 0; i < dim; i++) |
102 | 204 | { |
103 | 204 | _ccv_nnc_gbmm( |
104 | 204 | a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride, |
105 | 204 | w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride, |
106 | 204 | b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc); |
107 | 204 | } |
108 | 12 | } |
109 | | |
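Both kernel pairs rely on zero increments for batch broadcasting: when a size-1 batch dimension meets a larger output batch, the drivers below set that operand's batch increment to 0, so the per-batch pointer arithmetic lands on the same matrix every time. A tiny illustration of the trick (hypothetical batch_ptr helper, not ccv API):

/* With w_batch_inc == 0, every batch iteration reads the same weights. */
static const float* batch_ptr(const float* w, int n, int w_batch_inc)
{
	return w + n * w_batch_inc; /* n * 0 == 0: always w */
}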
110 | | static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
111 | 15.5k | { |
112 | 15.5k | assert(input_size >= 2); |
113 | 15.5k | const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[0]; |
114 | 15.5k | const ccv_nnc_tensor_view_t* w = (const ccv_nnc_tensor_view_t*)inputs[1]; |
115 | 15.5k | const ccv_nnc_tensor_view_t* bias = input_size > 2 ? (const ccv_nnc_tensor_view_t*)inputs[2] : 0; |
116 | | // Copy most of the parameters, but reshape the dimensions of a into a vector. |
117 | 15.5k | assert(output_size == 1); |
118 | 15.5k | ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[0]; |
119 | 15.5k | assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array |
120 | 15.5k | int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc; |
121 | 15.5k | int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc; |
122 | 15.5k | int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc; |
123 | 15.5k | const static int no_transpose[2] = {}; |
124 | 15.5k | ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride : 0, a->info.dim, cmd.info.blas.transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc); |
125 | 15.5k | ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, cmd.info.blas.transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc); |
126 | 15.5k | ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->stride : 0, b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc); |
127 | 15.5k | assert(ccv_max(a_batch_size, w_batch_size) == b_batch_size); |
128 | 15.5k | assert(a_batch_size == b_batch_size || a_batch_size == 1); |
129 | 15.5k | if (a_batch_size == 1 && b_batch_size > 1) |
130 | 0 | a_batch_inc = 0; |
131 | 15.5k | assert(w_batch_size == b_batch_size || w_batch_size == 1); |
132 | 15.5k | if (w_batch_size == 1 && b_batch_size > 1) |
133 | 11 | w_batch_inc = 0; |
134 | 15.5k | assert(a_rows == b_rows); |
135 | 15.5k | assert(a_cols == w_rows); |
136 | 15.5k | assert(w_cols == b_cols); |
137 | 15.5k | int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
138 | 15.5k | int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
139 | 15.5k | int bstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
140 | 15.5k | const int* astride; |
141 | 15.5k | if (CCV_IS_TENSOR_VIEW(a)) |
142 | 14 | astride = a->stride; |
143 | 15.5k | else { |
144 | 15.5k | ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim); |
145 | 15.5k | astride = astride_from_dim; |
146 | 15.5k | } |
147 | 15.5k | const int* wstride; |
148 | 15.5k | if (CCV_IS_TENSOR_VIEW(w)) |
149 | 4 | wstride = w->stride; |
150 | 15.5k | else { |
151 | 15.5k | ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim); |
152 | 15.5k | wstride = wstride_from_dim; |
153 | 15.5k | } |
154 | 15.5k | const int* bstride; |
155 | 15.5k | if (CCV_IS_TENSOR_VIEW(b)) |
156 | 26 | bstride = b->stride; |
157 | 15.5k | else { |
158 | 15.5k | ccv_nnc_tensor_get_stride(b->info.dim, bstride_from_dim); |
159 | 15.5k | bstride = bstride_from_dim; |
160 | 15.5k | } |
161 | 15.5k | if (bias) |
162 | 11.8k | { |
163 | 11.8k | int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc; |
164 | 11.8k | ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc); |
165 | 11.8k | assert(bias_batch_size == b_batch_size || bias_batch_size == 1); |
166 | 11.8k | if (bias_batch_size == 1 && b_batch_size > 1) |
167 | 10 | bias_batch_inc = 0; |
168 | 11.8k | if (bias_rows == 1 && b_rows > 1) |
169 | 6.12k | bias_rows_inc = 0; |
170 | 11.8k | assert(bias_cols == b_cols); |
171 | 11.8k | const int* biasstride; |
172 | 11.8k | int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
173 | 11.8k | if (CCV_IS_TENSOR_VIEW(bias)) |
174 | 0 | biasstride = bias->stride; |
175 | 11.8k | else { |
176 | 11.8k | ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim); |
177 | 11.8k | biasstride = biasstride_from_dim; |
178 | 11.8k | } |
179 | 11.8k | _ccv_nnc_gbmm_and_bias(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc); |
180 | 11.8k | } else { |
181 | 3.76k | _ccv_nnc_gbmm(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc); |
182 | 3.76k | } |
183 | 15.5k | return CCV_NNC_EXEC_SUCCESS; |
184 | 15.5k | } |
185 | | |
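_ccv_nnc_gemm_forw itself is pure bookkeeping: ccv_nnc_tensor_get_matrix_params folds each tensor's dims, optional view strides, and transpose flag into a (batch, rows, cols) shape plus three increments, so a transpose never moves data, it merely swaps which increment steps rows and which steps columns. A standalone sketch of that idea for a row-major 2x3 matrix (the increment convention is inferred from how the kernels above index, so treat it as an assumption):

#include <stdio.h>

int main(void)
{
	enum { R = 2, C = 3 };
	const float m[R * C] = { 1, 2, 3, 4, 5, 6 };
	const int rows_inc = C, cols_inc = 1;     /* m[i][j] = m[i * rows_inc + j * cols_inc] */
	const int t_rows_inc = 1, t_cols_inc = C; /* transposed view: mT[i][j] = m[j][i] */
	printf("m[1][2]  = %g\n", m[1 * rows_inc + 2 * cols_inc]);     /* 6 */
	printf("mT[2][1] = %g\n", m[2 * t_rows_inc + 1 * t_cols_inc]); /* 6, same element */
	return 0;
}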
186 | | static inline void _ccv_nnc_bmm_dbias(const float* const g, float* const dbias, const int g_batch_size, const int g_batch_inc, const int dbias_batch_inc, const int g_rows, const int g_cols, const int g_cols_inc, const int dbias_cols_inc, const int g_rows_inc, const int dbias_rows_inc) |
187 | 3.49k | { |
188 | 3.49k | int n, i, j; |
189 | 7.00k | for (n = 0; n < g_batch_size; n++) |
190 | 3.51k | { |
191 | 3.51k | const float* const gp = g + n * g_batch_inc; |
192 | 3.51k | float* const bp = dbias + n * dbias_batch_inc; |
193 | 7.34k | for (i = 0; i < g_rows; i++) |
194 | 3.83k | { |
195 | 3.83k | const float* const gpi = gp + i * g_rows_inc; |
196 | 3.83k | float* const bpi = bp + i * dbias_rows_inc; |
197 | 22.1k | for (j = 0; j < g_cols; j++) |
198 | 18.3k | bpi[j * dbias_cols_inc] += gpi[j * g_cols_inc]; |
199 | 3.83k | } |
200 | 3.51k | } |
201 | 3.49k | } |
202 | | |
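_ccv_nnc_bmm_dbias reduces the incoming gradient over the batch and row axes: dbias[j] += sum over n, i of g[n][i][j]. It only accumulates, which is why _ccv_nnc_gemm_back zeroes the bias tensor first unless CCV_NNC_ACCUMULATE_OUTPUT is set. A single-batch row-major reference (hypothetical dbias_ref name):

/* dbias[j] += sum_i g[i][j]; the caller is responsible for zeroing,
 * matching the kernel's accumulate-only contract. */
static void dbias_ref(const float* g, float* dbias, int rows, int cols)
{
	int i, j;
	for (i = 0; i < rows; i++)
		for (j = 0; j < cols; j++)
			dbias[j] += g[i * cols + j];
}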
203 | | static inline void _ccv_nnc_gbmm_dbias(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, float* const dbias, const int dbias_nd, const int* const dbiasdim, const int* const dbiasstride, const int g_batch_size, const int g_batch_inc, const int dbias_batch_inc, const int g_rows, const int g_cols, const int g_cols_inc, const int dbias_cols_inc, const int g_rows_inc, const int dbias_rows_inc) |
204 | 3.49k | { |
205 | 3.49k | if (g_nd <= 3) |
206 | 3.49k | { |
207 | 3.49k | _ccv_nnc_bmm_dbias(g, dbias, g_batch_size, g_batch_inc, dbias_batch_inc, g_rows, g_cols, g_cols_inc, dbias_cols_inc, g_rows_inc, dbias_rows_inc); |
208 | 3.49k | return; |
209 | 3.49k | } |
210 | 2 | const int dim = gdim[0]; |
211 | 2 | if (dbias_nd > 3) |
212 | 0 | { assert(dbiasdim[0] == 1 || dim == dbiasdim[0]); } |
213 | 2 | int i; |
214 | 6 | for (i = 0; i < dim; i++) |
215 | 4 | { |
216 | 4 | _ccv_nnc_gbmm_dbias( |
217 | 4 | g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1, |
218 | 4 | dbias_nd > 3 ? dbias + i * dbiasstride[0] : dbias, dbias_nd > 3 ? dbias_nd - 1 : dbias_nd, dbias_nd > 3 ? dbiasdim + 1 : dbiasdim, dbias_nd > 3 ? dbiasstride + 1 : dbiasstride, |
219 | 4 | g_batch_size, g_batch_inc, dbias_batch_inc, g_rows, g_cols, g_cols_inc, dbias_cols_inc, g_rows_inc, dbias_rows_inc); |
220 | 4 | } |
221 | 2 | } |
222 | | |
223 | | static inline void _ccv_nnc_bmm_dw(const float* const g, const float* const a, float* const dw, const int g_batch_size, const int g_batch_inc, const int a_batch_inc, const int dw_batch_inc, const int a_rows, const int a_cols, const int g_cols, const int g_cols_inc, const int a_cols_inc, const int dw_cols_inc, const int g_rows_inc, const int a_rows_inc, const int dw_rows_inc) |
224 | 10.7k | { |
225 | 10.7k | int n, i; |
226 | 21.9k | for (n = 0; n < g_batch_size; n++) |
227 | 11.1k | { |
228 | 11.1k | const float* const gp = g + n * g_batch_inc; |
229 | 11.1k | const float* const ap = a + n * a_batch_inc; |
230 | 11.1k | float* const dwp = dw + n * dw_batch_inc; |
231 | 93.9k | for (i = 0; i < a_rows; i++) |
232 | 82.7k | { |
233 | 82.7k | const float* const gpi = gp + i * g_rows_inc; |
234 | 82.7k | const float* const api = ap + i * a_rows_inc; |
235 | 7.39M | parallel_for(j, g_cols) { |
236 | 7.39M | const float v = gpi[j * g_cols_inc]; |
237 | 7.39M | float* dwpj = dwp + j * dw_cols_inc; |
238 | 7.39M | int k; |
239 | 683M | for (k = 0; k < a_cols; k++) |
240 | 675M | dwpj[k * dw_rows_inc] += api[k * a_cols_inc] * v; |
241 | 7.39M | } parallel_endfor |
242 | 82.7k | } |
243 | 11.1k | } |
244 | 10.7k | } |
245 | | |
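_ccv_nnc_bmm_dw is the weight gradient of b = a * w: dw[k][j] += sum over i of a[i][k] * g[i][j], i.e. dw = a^T g accumulated across the batch. Like the bias kernel it only accumulates, relying on the caller to zero dw first. A single-batch row-major reference (hypothetical dw_ref name):

/* dw = a^T g for one batch: a is rows x a_cols, g is rows x g_cols,
 * dw is a_cols x g_cols and must be pre-zeroed by the caller. */
static void dw_ref(const float* g, const float* a, float* dw,
	int rows, int a_cols, int g_cols)
{
	int i, j, k;
	for (i = 0; i < rows; i++)
		for (j = 0; j < g_cols; j++)
			for (k = 0; k < a_cols; k++)
				dw[k * g_cols + j] += a[i * a_cols + k] * g[i * g_cols + j];
}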
246 | | static inline void _ccv_nnc_gbmm_dw(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, const float* const a, const int a_nd, const int* const adim, const int* const astride, float* const dw, const int dw_nd, const int* const dwdim, const int* const dwstride, const int g_batch_size, const int g_batch_inc, const int a_batch_inc, const int dw_batch_inc, const int a_rows, const int a_cols, const int g_cols, const int g_cols_inc, const int a_cols_inc, const int dw_cols_inc, const int g_rows_inc, const int a_rows_inc, const int dw_rows_inc) |
247 | 10.7k | { |
248 | 10.7k | if (g_nd <= 3) |
249 | 10.7k | { |
250 | 10.7k | _ccv_nnc_bmm_dw(g, a, dw, g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc); |
251 | 10.7k | return; |
252 | 10.7k | } |
253 | 6 | const int dim = gdim[0]; |
254 | 6 | if (a_nd > 3) |
255 | 6 | { assert(adim[0] == 1 || dim == adim[0]); } |
256 | 6 | if (dw_nd > 3) |
257 | 4 | { assert(dwdim[0] == 1 || dim == dwdim[0]); } |
258 | 6 | int i; |
259 | 78 | for (i = 0; i < dim; i++) |
260 | 72 | { |
261 | 72 | _ccv_nnc_gbmm_dw( |
262 | 72 | g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1, |
263 | 72 | a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride, |
264 | 72 | dw_nd > 3 ? dw + i * dwstride[0] : dw, dw_nd > 3 ? dw_nd - 1 : dw_nd, dw_nd > 3 ? dwdim + 1 : dwdim, dw_nd > 3 ? dwstride + 1 : dwstride, |
265 | 72 | g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc); |
266 | 72 | } |
267 | 6 | } |
268 | | |
269 | | static inline void _ccv_nnc_bmm_h(const float* const g, const float* const w, float* const h, const int zero_h, const int g_batch_size, const int g_batch_inc, const int w_batch_inc, const int h_batch_inc, const int h_rows, const int h_cols, const int g_cols, const int g_cols_inc, const int w_cols_inc, const int h_cols_inc, const int g_rows_inc, const int w_rows_inc, const int h_rows_inc) |
270 | 2.46k | { |
271 | 2.46k | int n, i; |
272 | 5.41k | for (n = 0; n < g_batch_size; n++) |
273 | 2.95k | { |
274 | 2.95k | const float* const gp = g + n * g_batch_inc; |
275 | 2.95k | const float* const wp = w + n * w_batch_inc; |
276 | 2.95k | float* const hp = h + n * h_batch_inc; |
277 | 71.4k | for (i = 0; i < h_rows; i++) |
278 | 68.4k | { |
279 | 68.4k | const float* const gpi = gp + i * g_rows_inc; |
280 | 68.4k | float* const hpi = hp + i * h_rows_inc; |
281 | 6.35M | parallel_for(j, h_cols) { |
282 | 6.35M | const float* const wpj = wp + j * w_rows_inc; |
283 | 6.35M | float v = zero_h ? 0 : hpi[j * h_cols_inc]; |
284 | 6.35M | int k; |
285 | 681M | for (k = 0; k < g_cols; k++) |
286 | 675M | v += wpj[k * w_cols_inc] * gpi[k * g_cols_inc]; |
287 | 6.35M | hpi[j * h_cols_inc] = v; |
288 | 6.35M | } parallel_endfor |
289 | 68.4k | } |
290 | 2.95k | } |
291 | 2.46k | } |
292 | | |
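_ccv_nnc_bmm_h produces the input gradient h = g * w^T: h[i][j] = sum over k of g[i][k] * w[j][k]. Unlike the other two backward kernels it can either overwrite or accumulate, chosen by zero_h, because the recursive wrapper below may visit the same broadcast h region more than once. A single-batch row-major reference (hypothetical h_ref name):

/* h = g w^T for one batch: g is rows x g_cols, w is h_cols x g_cols,
 * h is rows x h_cols; zero_h picks overwrite vs. accumulate. */
static void h_ref(const float* g, const float* w, float* h, int zero_h,
	int rows, int h_cols, int g_cols)
{
	int i, j, k;
	for (i = 0; i < rows; i++)
		for (j = 0; j < h_cols; j++)
		{
			float v = zero_h ? 0 : h[i * h_cols + j];
			for (k = 0; k < g_cols; k++)
				v += g[i * g_cols + k] * w[j * g_cols + k];
			h[i * h_cols + j] = v;
		}
}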
293 | | static inline void _ccv_nnc_gbmm_h(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const h, const int zero_h, const int h_nd, const int* const hdim, const int* const hstride, const int g_batch_size, const int g_batch_inc, const int w_batch_inc, const int h_batch_inc, const int h_rows, const int h_cols, const int g_cols, const int g_cols_inc, const int w_cols_inc, const int h_cols_inc, const int g_rows_inc, const int w_rows_inc, const int h_rows_inc) |
294 | 2.47k | { |
295 | 2.47k | if (g_nd <= 3) |
296 | 2.46k | { |
297 | 2.46k | _ccv_nnc_bmm_h(g, w, h, zero_h, g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc); |
298 | 2.46k | return; |
299 | 2.46k | } |
300 | 6 | const int dim = gdim[0]; |
301 | 6 | if (w_nd > 3) |
302 | 4 | { assert(wdim[0] == 1 || dim == wdim[0]); } |
303 | 6 | if (h_nd > 3) |
304 | 6 | { assert(hdim[0] == 1 || dim == hdim[0]); } |
305 | 6 | int i; |
306 | 78 | for (i = 0; i < dim; i++) |
307 | 72 | { |
308 | | // Only zero h the first time a region is visited; when h is broadcast across i, later iterations must accumulate instead. |
309 | 72 | const int zero_h_override = (i == 0 || (i * hstride[0] > 0 && h_nd > 3)) ? zero_h : 0; |
310 | 72 | _ccv_nnc_gbmm_h( |
311 | 72 | g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1, |
312 | 72 | w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride, |
313 | 72 | h_nd > 3 ? h + i * hstride[0] : h, zero_h_override, h_nd > 3 ? h_nd - 1 : h_nd, h_nd > 3 ? hdim + 1 : hdim, h_nd > 3 ? hstride + 1 : hstride, |
314 | 72 | g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc); |
315 | 72 | } |
316 | 6 | } |
317 | | |
318 | | static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
319 | 10.6k | { |
320 | | // inputs: gradient, forward prop input, [w] |
321 | | // outputs: [output gradient], weight updates, bias updates |
322 | 10.6k | assert(input_size >= 2 && output_size >= 1); |
323 | 10.6k | const ccv_nnc_tensor_view_t* g = (const ccv_nnc_tensor_view_t*)inputs[0]; |
324 | 10.6k | ccv_nnc_tensor_view_t* dw = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0; |
325 | 10.6k | ccv_nnc_tensor_view_t* bias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0; |
326 | 10.6k | assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 2-d or 3-d array. |
327 | 10.6k | int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc; |
328 | 10.6k | const static int no_transpose[2] = {}; |
329 | 10.6k | ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->stride : 0, g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc); |
330 | 10.6k | if (bias) |
331 | 3.49k | { |
332 | 3.49k | if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0 |
333 | 3.49k | ccv_nnc_tensor_zero(bias); |
334 | 3.49k | int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc; |
335 | 3.49k | ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc); |
336 | 3.49k | assert(bias_cols == g_cols); |
337 | 3.49k | assert(bias_batch_size == 1 || bias_batch_size == g_batch_size); |
338 | 3.49k | if (bias_batch_size == 1 && g_batch_size > 1) |
339 | 5 | bias_batch_inc = 0; |
340 | 3.49k | if (bias_rows == 1 && g_rows > 1) |
341 | 124 | bias_rows_inc = 0; |
342 | 3.49k | int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
343 | 3.49k | const int* gstride; |
344 | 3.49k | if (CCV_IS_TENSOR_VIEW(g)) |
345 | 0 | gstride = g->stride; |
346 | 3.49k | else { |
347 | 3.49k | ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim); |
348 | 3.49k | gstride = gstride_from_dim; |
349 | 3.49k | } |
350 | 3.49k | int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
351 | 3.49k | const int* biasstride; |
352 | 3.49k | if (CCV_IS_TENSOR_VIEW(bias)) |
353 | 0 | biasstride = bias->stride; |
354 | 3.49k | else { |
355 | 3.49k | ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim); |
356 | 3.49k | biasstride = biasstride_from_dim; |
357 | 3.49k | } |
358 | 3.49k | _ccv_nnc_gbmm_dbias(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, g_batch_size, g_batch_inc, bias_batch_inc, g_rows, g_cols, g_cols_inc, bias_cols_inc, g_rows_inc, bias_rows_inc); |
359 | 3.49k | } |
360 | 10.6k | if (dw) |
361 | 10.6k | { |
362 | 10.6k | if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0 |
363 | 10.6k | ccv_nnc_tensor_zero(dw); |
364 | 10.6k | const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[1]; |
365 | 10.6k | int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc; |
366 | 10.6k | int dw_batch_size, dw_rows, dw_cols, dw_batch_inc, dw_rows_inc, dw_cols_inc; |
367 | 10.6k | ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride : 0, a->info.dim, cmd.info.blas.transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc); |
368 | 10.6k | ccv_nnc_tensor_get_matrix_params(dw->info, CCV_IS_TENSOR_VIEW(dw) ? dw->stride : 0, dw->info.dim, cmd.info.blas.transpose_b, &dw_batch_size, &dw_rows, &dw_cols, &dw_batch_inc, &dw_rows_inc, &dw_cols_inc); |
369 | 10.6k | assert(a_rows == g_rows); |
370 | 10.6k | assert(a_cols == dw_rows); |
371 | 10.6k | assert(dw_cols == g_cols); |
372 | 10.6k | assert(a_batch_size == g_batch_size || a_batch_size == 1); |
373 | 10.6k | if (a_batch_size == 1 && g_batch_size > 1) |
374 | 0 | a_batch_inc = 0; |
375 | 10.6k | assert(dw_batch_size == g_batch_size || dw_batch_size == 1); |
376 | 10.6k | if (dw_batch_size == 1 && g_batch_size > 1) |
377 | 5 | dw_batch_inc = 0; |
378 | 10.6k | int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
379 | 10.6k | int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
380 | 10.6k | int dwstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
381 | 10.6k | const int* gstride; |
382 | 10.6k | if (CCV_IS_TENSOR_VIEW(g)) |
383 | 2 | gstride = g->stride; |
384 | 10.6k | else { |
385 | 10.6k | ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim); |
386 | 10.6k | gstride = gstride_from_dim; |
387 | 10.6k | } |
388 | 10.6k | const int* astride; |
389 | 10.6k | if (CCV_IS_TENSOR_VIEW(a)) |
390 | 2 | astride = a->stride; |
391 | 10.6k | else { |
392 | 10.6k | ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim); |
393 | 10.6k | astride = astride_from_dim; |
394 | 10.6k | } |
395 | 10.6k | const int* dwstride; |
396 | 10.6k | if (CCV_IS_TENSOR_VIEW(dw)) |
397 | 0 | dwstride = dw->stride; |
398 | 10.6k | else { |
399 | 10.6k | ccv_nnc_tensor_get_stride(dw->info.dim, dwstride_from_dim); |
400 | 10.6k | dwstride = dwstride_from_dim; |
401 | 10.6k | } |
402 | 10.6k | _ccv_nnc_gbmm_dw(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, dw->data.f32, ccv_nnc_tensor_nd(dw->info.dim), dw->info.dim, dwstride, g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc); |
403 | 10.6k | } |
404 | 10.6k | ccv_nnc_tensor_view_t* h = (ccv_nnc_tensor_view_t*)outputs[0]; |
405 | 10.6k | if (h) |
406 | 2.40k | { |
407 | 2.40k | const int zero_h = !(flags & CCV_NNC_ACCUMULATE_OUTPUT); // reset the gradients to 0 |
408 | 2.40k | const ccv_nnc_tensor_view_t* w = (const ccv_nnc_tensor_view_t*)inputs[2]; |
409 | 2.40k | int h_batch_size, h_rows, h_cols, h_batch_inc, h_rows_inc, h_cols_inc; |
410 | 2.40k | int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc; |
411 | 2.40k | ccv_nnc_tensor_get_matrix_params(h->info, CCV_IS_TENSOR_VIEW(h) ? h->stride : 0, h->info.dim, cmd.info.blas.transpose_a, &h_batch_size, &h_rows, &h_cols, &h_batch_inc, &h_rows_inc, &h_cols_inc); |
412 | 2.40k | ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, cmd.info.blas.transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc); |
413 | 2.40k | assert(h_cols == w_rows); |
414 | 2.40k | assert(w_cols == g_cols); |
415 | 2.40k | assert(h_batch_size == g_batch_size || h_batch_size == 1); |
416 | 2.40k | if (h_batch_size == 1 && g_batch_size > 1) |
417 | 0 | h_batch_inc = 0; |
418 | 2.40k | assert(w_batch_size == g_batch_size || w_batch_size == 1); |
419 | 2.40k | if (w_batch_size == 1 && g_batch_size > 1) |
420 | 5 | w_batch_inc = 0; |
421 | 2.40k | int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
422 | 2.40k | int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
423 | 2.40k | int hstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
424 | 2.40k | const int* gstride; |
425 | 2.40k | if (CCV_IS_TENSOR_VIEW(g)) |
426 | 4 | gstride = g->stride; |
427 | 2.39k | else { |
428 | 2.39k | ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim); |
429 | 2.39k | gstride = gstride_from_dim; |
430 | 2.39k | } |
431 | 2.40k | const int* wstride; |
432 | 2.40k | if (CCV_IS_TENSOR_VIEW(w)) |
433 | 0 | wstride = w->stride; |
434 | 2.40k | else { |
435 | 2.40k | ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim); |
436 | 2.40k | wstride = wstride_from_dim; |
437 | 2.40k | } |
438 | 2.40k | const int* hstride; |
439 | 2.40k | if (CCV_IS_TENSOR_VIEW(h)) |
440 | 4 | hstride = h->stride; |
441 | 2.39k | else { |
442 | 2.39k | ccv_nnc_tensor_get_stride(h->info.dim, hstride_from_dim); |
443 | 2.39k | hstride = hstride_from_dim; |
444 | 2.39k | } |
445 | 2.40k | _ccv_nnc_gbmm_h(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, h->data.f32, zero_h, ccv_nnc_tensor_nd(h->info.dim), h->info.dim, hstride, g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc); |
446 | 2.40k | } |
447 | 10.6k | return CCV_NNC_EXEC_SUCCESS; |
448 | 10.6k | } |
449 | | |
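Summing up the backward pass for b = a * w + bias (untransposed case; transposes only reinterpret increments), with g = dL/db:

	dbias = sum over batch and rows of g
	dw    = a^T g
	h     = g w^T

Any of h, dw, or dbias may be absent (a 0 output) and is then skipped, and CCV_NNC_ACCUMULATE_OUTPUT switches all three from overwrite to accumulate semantics.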
450 | | REGISTER_COMMAND_BACKEND(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
451 | 1 | { |
452 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW; |
453 | 1 | registry->tensor_datatypes = CCV_32F; |
454 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
455 | 1 | registry->algorithms = 1; |
456 | 1 | registry->exec = _ccv_nnc_gemm_forw; |
457 | 1 | } |
458 | | |
459 | | REGISTER_COMMAND_BACKEND(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
460 | 1 | { |
461 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW; |
462 | 1 | registry->tensor_datatypes = CCV_32F; |
463 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
464 | 1 | registry->algorithms = 1; |
465 | 1 | registry->exec = _ccv_nnc_gemm_back; |
466 | 1 | } |
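The two REGISTER_COMMAND_BACKEND blocks wire these drivers into nnc's backend registry as the CPU reference implementation for 32-bit float NHWC/NCHW tensors in CPU memory, each advertising a single algorithm variant. A hedged sketch of how the forward path is typically reached through the public API; the exact arguments of the CMD_GEMM_FORWARD convenience macro are an assumption here, so consult ccv_nnc_easy.h for the real form:

/* Sketch, assuming 32F CPU tensors a (MxK), w (KxN), bias (N), b (MxN).
 * The transpose arguments to CMD_GEMM_FORWARD are elided. */
ccv_nnc_cmd_exec(CMD_GEMM_FORWARD(), ccv_nnc_no_hint, 0,
	TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0);
/* nnc dispatches this to _ccv_nnc_gemm_forw above when the
 * CCV_NNC_BACKEND_CPU_REF backend is selected. */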