/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/blas/cpu_opt/_ccv_nnc_gemm_cpu_opt.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #if defined(HAVE_SSE2) |
7 | | #include <xmmintrin.h> |
8 | | #elif defined(HAVE_NEON) |
9 | | #include <arm_neon.h> |
10 | | #endif |
11 | | #ifdef USE_OPENMP |
12 | | #include <omp.h> |
13 | | #endif |
14 | | #ifdef USE_DISPATCH |
15 | | #include <dispatch/dispatch.h> |
16 | | #endif |
17 | | #include "../_ccv_nnc_gemm_cpu_opt.h" |
18 | | |
19 | | #ifdef HAVE_SSE2 |
20 | | static int _ccv_nnc_gemm_forw_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b) |
21 | 70 | { |
22 | 70 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
23 | 70 | const int* adim = (a_nd == 1) ? a->info.dim : a->info.dim + 1; |
24 | 70 | const int b_nd = ccv_nnc_tensor_nd(b->info.dim); |
25 | 70 | const int* bdim = (b_nd == 1) ? b->info.dim : b->info.dim + 1; |
26 | 70 | assert(!bias || bdim[0] == bias->info.dim[0]); |
27 | 70 | assert(bdim[0] == w->info.dim[0]); |
28 | 70 | assert(adim[0] == w->info.dim[1]); |
29 | 70 | const int batch_size = a_nd == 1 ? 1 : ccv_max(1, a->info.dim[0]); |
30 | 70 | assert(batch_size == (b_nd == 1) ? 1 : ccv_max(1, b->info.dim[0])); |
31 | 70 | const int a_batch_inc = CCV_IS_TENSOR_VIEW(a) ? (a_nd == 1 ? adim[0] : a->stride[0]) : adim[0]; |
32 | 70 | const int b_batch_inc = CCV_IS_TENSOR_VIEW(b) ? (b_nd == 1 ? bdim[0] : b->stride[0]) : bdim[0]; |
33 | 70 | const int wstride = CCV_IS_TENSOR_VIEW(w) ? w->stride[0] : w->info.dim[1]; |
34 | 70 | int i; |
35 | 70 | if (bias) |
36 | 70 | { |
37 | 140 | for (i = 0; i < batch_size; i++) |
38 | 70 | { |
39 | 70 | const float* const ap = a->data.f32 + i * a_batch_inc; |
40 | 70 | float* const bp = b->data.f32 + i * b_batch_inc; |
41 | 68.9k | parallel_for(j, bdim[0]) { |
42 | 68.9k | const float* const wp = w->data.f32 + j * wstride; |
43 | 68.9k | int k; |
44 | 68.9k | __m128 v40 = _mm_set_ss(bias->data.f32[j]); |
45 | 68.9k | __m128 v41 = _mm_setzero_ps(); |
46 | 84.2M | for (k = 0; k < adim[0] - 7; k += 8) |
47 | 84.2M | { |
48 | 84.2M | __m128 ap40 = _mm_load_ps(ap + k); |
49 | 84.2M | __m128 ap41 = _mm_load_ps(ap + k + 4); |
50 | 84.2M | __m128 w40 = _mm_load_ps(wp + k); |
51 | 84.2M | __m128 w41 = _mm_load_ps(wp + k + 4); |
52 | 84.2M | v40 =_mm_add_ps(_mm_mul_ps(w40, ap40), v40); |
53 | 84.2M | v41 =_mm_add_ps(_mm_mul_ps(w41, ap41), v41); |
54 | 84.2M | } |
55 | 68.9k | v40 = _mm_add_ps(v40, v41); |
56 | 68.9k | v41 = _mm_add_ps(v40, _mm_movehl_ps(v40, v40)); |
57 | 68.9k | v40 = _mm_add_ss(v41, _mm_shuffle_ps(v41, v41, 1)); |
58 | 68.9k | _mm_store_ss(bp + j, v40); |
59 | 68.9k | } parallel_endfor |
60 | 70 | } |
61 | 70 | } else { |
62 | 0 | for (i = 0; i < batch_size; i++) |
63 | 0 | { |
64 | 0 | const float* const ap = a->data.f32 + i * a_batch_inc; |
65 | 0 | float* const bp = b->data.f32 + i * b_batch_inc; |
66 | 0 | parallel_for(j, bdim[0]) { |
67 | 0 | const float* const wp = w->data.f32 + j * wstride; |
68 | 0 | int k; |
69 | 0 | __m128 v40 = _mm_setzero_ps(); |
70 | 0 | __m128 v41 = _mm_setzero_ps(); |
71 | 0 | for (k = 0; k < adim[0] - 7; k += 8) |
72 | 0 | { |
73 | 0 | __m128 ap40 = _mm_load_ps(ap + k); |
74 | 0 | __m128 ap41 = _mm_load_ps(ap + k + 4); |
75 | 0 | __m128 w40 = _mm_load_ps(wp + k); |
76 | 0 | __m128 w41 = _mm_load_ps(wp + k + 4); |
77 | 0 | v40 =_mm_add_ps(_mm_mul_ps(w40, ap40), v40); |
78 | 0 | v41 =_mm_add_ps(_mm_mul_ps(w41, ap41), v41); |
79 | 0 | } |
80 | 0 | v40 = _mm_add_ps(v40, v41); |
81 | 0 | v41 = _mm_add_ps(v40, _mm_movehl_ps(v40, v40)); |
82 | 0 | v40 = _mm_add_ss(v41, _mm_shuffle_ps(v41, v41, 1)); |
83 | 0 | _mm_store_ss(bp + j, v40); |
84 | 0 | } parallel_endfor |
85 | 0 | } |
86 | 0 | } |
87 | 70 | return CCV_NNC_EXEC_SUCCESS; |
88 | 70 | } |
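The last four statements of the parallel_for body above are a standard SSE horizontal reduction: the two accumulators are folded together, the high pair is moved onto the low pair with _mm_movehl_ps, and the remaining two lanes are added into lane 0 before the scalar store. A minimal, self-contained sketch of that reduction (the helper name hsum_ps is hypothetical, not part of this file):

#include <xmmintrin.h>
#include <stdio.h>

/* Horizontal sum of a __m128: fold the high pair onto the low pair, then add
   the two surviving lanes into lane 0. Mirrors the reduction used above. */
static float hsum_ps(__m128 v)
{
	__m128 folded = _mm_add_ps(v, _mm_movehl_ps(v, v)); /* lanes {0+2, 1+3, ., .} */
	__m128 total = _mm_add_ss(folded, _mm_shuffle_ps(folded, folded, 1)); /* lane 0 + lane 1 */
	return _mm_cvtss_f32(total);
}

int main(void)
{
	printf("%f\n", hsum_ps(_mm_setr_ps(1.f, 2.f, 3.f, 4.f))); /* prints 10.000000 */
	return 0;
}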
89 | | |
90 | | static int _ccv_nnc_gemm_back_sse2(const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags) |
91 | 9 | { |
92 | 9 | const int dwstride = CCV_IS_TENSOR_VIEW(dw) ? dw->stride[0] : dw->info.dim[1]; |
93 | 9 | if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0 |
94 | 9 | { |
95 | 9 | memset(dw->data.u8, 0, sizeof(float) * dwstride * dw->info.dim[0]); |
96 | 9 | if (bias) |
97 | 9 | memset(bias->data.u8, 0, sizeof(float) * bias->info.dim[0]); |
98 | 9 | } |
99 | 9 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
100 | 9 | const int* adim = (a_nd == 1) ? a->info.dim : a->info.dim + 1; |
101 | 9 | const int g_nd = ccv_nnc_tensor_nd(g->info.dim); |
102 | 9 | const int* gdim = (g_nd == 1) ? g->info.dim : g->info.dim + 1; |
103 | 9 | const int batch_size = a_nd == 1 ? 1 : ccv_max(1, a->info.dim[0]); |
104 | 9 | int i, j; |
105 | 9 | float* gp = g->data.f32; |
106 | 9 | const int g_batch_inc = CCV_IS_TENSOR_VIEW(g) ? ((g_nd == 1) ? gdim[0] : g->stride[0]) : gdim[0]; |
107 | 9 | if (bias) |
108 | 9 | { |
109 | 9 | float* bp = bias->data.f32; |
110 | 9 | assert(bias->info.dim[0] == gdim[0]); |
111 | 18 | for (i = 0; i < batch_size; i++) |
112 | 9 | { |
113 | 585 | for (j = 0; j < gdim[0] - 3; j += 4) |
114 | 576 | { |
115 | 576 | __m128 g4 = _mm_load_ps(gp + j); |
116 | 576 | __m128 b4 = _mm_load_ps(bp + j); |
117 | 576 | _mm_stream_ps(bp + j, _mm_add_ps(b4, g4)); |
118 | 576 | } |
119 | 9 | gp += g_batch_inc; |
120 | 9 | } |
121 | 9 | } |
122 | 9 | assert(gdim[0] == dw->info.dim[0]); |
123 | 9 | assert(adim[0] == dw->info.dim[1]); |
124 | 9 | const int a_batch_inc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == 1) ? adim[0] : a->stride[0]) : adim[0]; |
125 | 18 | for (i = 0; i < batch_size; i++) |
126 | 9 | { |
127 | 9 | const float* const gp = g->data.f32 + i * g_batch_inc; |
128 | 9 | const float* const ap = a->data.f32 + i * a_batch_inc; |
129 | 2.30k | parallel_for(j, gdim[0]) { |
130 | 2.30k | float* const dwp = dw->data.f32 + j * dwstride; |
131 | 2.30k | __m128 g4 = _mm_set1_ps(gp[j]); |
132 | 2.30k | int k; |
133 | 334k | for (k = 0; k < adim[0] - 3; k+= 4) |
134 | 331k | { |
135 | 331k | __m128 a4 = _mm_load_ps(ap + k); |
136 | 331k | __m128 dw4 = _mm_load_ps(dwp + k); |
137 | 331k | _mm_stream_ps(dwp + k, _mm_add_ps(dw4, _mm_mul_ps(a4, g4))); |
138 | 331k | } |
139 | 2.30k | } parallel_endfor |
140 | 9 | } |
141 | 9 | if (h && w) |
142 | 9 | { |
143 | 9 | const int h_nd = ccv_nnc_tensor_nd(h->info.dim); |
144 | 9 | const int* hdim = (h_nd == 1) ? h->info.dim : h->info.dim + 1; |
145 | 9 | assert(hdim[0] == adim[0]); |
146 | 9 | const int h_batch_inc = CCV_IS_TENSOR_VIEW(h) ? ((h_nd == 1) ? hdim[0] : h->stride[0]) : hdim[0]; |
147 | 9 | const int wstride = CCV_IS_TENSOR_VIEW(w) ? w->stride[0] : w->info.dim[1]; |
148 | 18 | for (i = 0; i < batch_size; i++) |
149 | 9 | { |
150 | 9 | const float* const gp = g->data.f32 + i * g_batch_inc; |
151 | 9 | float* const hp = h->data.f32 + i * h_batch_inc; |
152 | 1.29k | parallel_for(y, hdim[0] / 4) { |
153 | 1.29k | const int j = y * 4; |
154 | 1.29k | const float* const wp = w->data.f32 + j; |
155 | 1.29k | __m128 v40 = _mm_setzero_ps(); |
156 | 1.29k | __m128 v41 = _mm_setzero_ps(); |
157 | 1.29k | __m128 v42 = _mm_setzero_ps(); |
158 | 1.29k | __m128 v43 = _mm_setzero_ps(); |
159 | 1.29k | int k; |
160 | 84.2k | for (k = 0; k < gdim[0]; k += 4) |
161 | 82.9k | { |
162 | 82.9k | __m128 g4 = _mm_load_ps(gp + k); |
163 | 82.9k | __m128 w40 = _mm_load_ps(wp + k * wstride); |
164 | 82.9k | __m128 w41 = _mm_load_ps(wp + (k + 1) * wstride); |
165 | 82.9k | __m128 w42 = _mm_load_ps(wp + (k + 2) * wstride); |
166 | 82.9k | __m128 w43 = _mm_load_ps(wp + (k + 3) * wstride); |
167 | 82.9k | __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00); |
168 | 82.9k | __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55); |
169 | 82.9k | __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA); |
170 | 82.9k | __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF); |
171 | 82.9k | v40 = _mm_add_ps(_mm_mul_ps(g40, w40), v40); |
172 | 82.9k | v41 = _mm_add_ps(_mm_mul_ps(g41, w41), v41); |
173 | 82.9k | v42 = _mm_add_ps(_mm_mul_ps(g42, w42), v42); |
174 | 82.9k | v43 = _mm_add_ps(_mm_mul_ps(g43, w43), v43); |
175 | 82.9k | } |
176 | 1.29k | v40 = _mm_add_ps(v40, v41); |
177 | 1.29k | v42 = _mm_add_ps(v42, v43); |
178 | 1.29k | _mm_stream_ps(hp + j, _mm_add_ps(v40, v42)); |
179 | 1.29k | } parallel_endfor |
180 | 9 | } |
181 | 9 | } |
182 | 9 | return CCV_NNC_EXEC_SUCCESS; |
183 | 9 | } |
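For reference, the gradients that _ccv_nnc_gemm_back_sse2 accumulates can be written as a plain scalar loop nest. This is only a sketch under simplifying assumptions: dense row-major buffers (no tensor-view strides), gradients already zeroed as in the !CCV_NNC_ACCUMULATE_OUTPUT branch above, and hypothetical names (gemm_back_ref and its parameters are not part of this file).

/* Scalar sketch of the backward pass: g is batch x outdim, a is batch x indim,
   w and dw are outdim x indim, h is batch x indim, dbias is outdim. */
static void gemm_back_ref(const float* g, const float* a, const float* w,
	float* dw, float* dbias, float* h, int batch, int outdim, int indim)
{
	int i, j, k;
	for (i = 0; i < batch; i++)
		for (j = 0; j < outdim; j++)
		{
			if (dbias)
				dbias[j] += g[i * outdim + j]; /* db[j] = sum_i g[i][j] */
			for (k = 0; k < indim; k++)
				dw[j * indim + k] += g[i * outdim + j] * a[i * indim + k]; /* dw = g^T * a */
		}
	if (h && w)
		for (i = 0; i < batch; i++)
			for (k = 0; k < indim; k++)
			{
				float s = 0;
				for (j = 0; j < outdim; j++)
					s += g[i * outdim + j] * w[j * indim + k]; /* h = g * w */
				h[i * indim + k] = s;
			}
}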
184 | | #endif |
185 | | |
186 | | #ifdef HAVE_NEON |
187 | | static int _ccv_nnc_gemm_forw_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b) |
188 | | { |
189 | | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
190 | | const int* adim = (a_nd == 1) ? a->info.dim : a->info.dim + 1; |
191 | | const int b_nd = ccv_nnc_tensor_nd(b->info.dim); |
192 | | const int* bdim = (b_nd == 1) ? b->info.dim : b->info.dim + 1; |
193 | | const int batch_size = a_nd == 1 ? 1 : ccv_max(1, a->info.dim[0]); |
194 | | assert(batch_size == (b_nd == 1) ? 1 : ccv_max(1, b->info.dim[0])); |
195 | | const int a_batch_inc = CCV_IS_TENSOR_VIEW(a) ? (a_nd == 1 ? adim[0] : a->stride[0]) : adim[0]; |
196 | | const int b_batch_inc = CCV_IS_TENSOR_VIEW(b) ? (b_nd == 1 ? bdim[0] : b->stride[0]) : bdim[0]; |
197 | | const int wstride = CCV_IS_TENSOR_VIEW(w) ? w->stride[0] : w->info.dim[1]; |
198 | | int i; |
199 | | if (bias) |
200 | | { |
201 | | for (i = 0; i < batch_size; i++) |
202 | | { |
203 | | const float* const ap = a->data.f32 + i * a_batch_inc; |
204 | | float* const bp = b->data.f32 + i * b_batch_inc; |
205 | | parallel_for(j, bdim[0]) { |
206 | | const float* const wp = w->data.f32 + j * wstride; |
207 | | int k; |
208 | | float32x4_t v41 = vmovq_n_f32(0); |
209 | | float32x4_t v40 = vld1q_lane_f32(bias->data.f32 + j, v41, 0); |
210 | | for (k = 0; k < adim[0] - 7; k += 8) |
211 | | { |
212 | | float32x4_t ap40 = vld1q_f32(ap + k); |
213 | | float32x4_t ap41 = vld1q_f32(ap + k + 4); |
214 | | float32x4_t w40 = vld1q_f32(wp + k); |
215 | | float32x4_t w41 = vld1q_f32(wp + k + 4); |
216 | | v40 = vmlaq_f32(v40, w40, ap40); |
217 | | v41 = vmlaq_f32(v41, w41, ap41); |
218 | | } |
219 | | v40 = vaddq_f32(v40, v41); |
220 | | float32x2_t v2 = vpadd_f32(vget_high_f32(v40), vget_low_f32(v40)); |
221 | | bp[j] = vget_lane_f32(vpadd_f32(v2, v2), 0); |
222 | | } parallel_endfor |
223 | | } |
224 | | } else { |
225 | | for (i = 0; i < batch_size; i++) |
226 | | { |
227 | | const float* const ap = a->data.f32 + i * a_batch_inc; |
228 | | float* const bp = b->data.f32 + i * b_batch_inc; |
229 | | parallel_for(j, bdim[0]) { |
230 | | const float* const wp = w->data.f32 + j * wstride; |
231 | | int k; |
232 | | float32x4_t v41 = vmovq_n_f32(0); |
233 | | float32x4_t v40 = vmovq_n_f32(0); |
234 | | for (k = 0; k < adim[0] - 7; k += 8) |
235 | | { |
236 | | float32x4_t ap40 = vld1q_f32(ap + k); |
237 | | float32x4_t ap41 = vld1q_f32(ap + k + 4); |
238 | | float32x4_t w40 = vld1q_f32(wp + k); |
239 | | float32x4_t w41 = vld1q_f32(wp + k + 4); |
240 | | v40 = vmlaq_f32(v40, w40, ap40); |
241 | | v41 = vmlaq_f32(v41, w41, ap41); |
242 | | } |
243 | | v40 = vaddq_f32(v40, v41); |
244 | | float32x2_t v2 = vpadd_f32(vget_high_f32(v40), vget_low_f32(v40)); |
245 | | bp[j] = vget_lane_f32(vpadd_f32(v2, v2), 0); |
246 | | } parallel_endfor |
247 | | } |
248 | | } |
249 | | return CCV_NNC_EXEC_SUCCESS; |
250 | | } |
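Both the SSE2 and the NEON forward kernels compute the same thing: for each batch row i and output j, b[i][j] = bias[j] + dot(w[j], a[i]), with the inner dot product unrolled by eight. A scalar sketch under the same simplifying assumptions as before (dense row-major buffers; the name gemm_forw_ref is hypothetical):

/* Scalar sketch of the forward kernel: a is batch x indim, w is outdim x indim,
   bias is outdim (or NULL), b is batch x outdim. */
static void gemm_forw_ref(const float* a, const float* w, const float* bias,
	float* b, int batch, int outdim, int indim)
{
	int i, j, k;
	for (i = 0; i < batch; i++)
		for (j = 0; j < outdim; j++)
		{
			float s = bias ? bias[j] : 0;
			for (k = 0; k < indim; k++)
				s += w[j * indim + k] * a[i * indim + k];
			b[i * outdim + j] = s;
		}
}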
251 | | |
252 | | static int _ccv_nnc_gemm_back_neon(const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags) |
253 | | { |
254 | | const int dwstride = CCV_IS_TENSOR_VIEW(dw) ? dw->stride[0] : dw->info.dim[1]; |
255 | | if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0 |
256 | | { |
257 | | memset(dw->data.u8, 0, sizeof(float) * dwstride * dw->info.dim[0]); |
258 | | if (bias) |
259 | | memset(bias->data.u8, 0, sizeof(float) * bias->info.dim[0]); |
260 | | } |
261 | | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
262 | | const int* adim = (a_nd == 1) ? a->info.dim : a->info.dim + 1; |
263 | | const int g_nd = ccv_nnc_tensor_nd(g->info.dim); |
264 | | const int* gdim = (g_nd == 1) ? g->info.dim : g->info.dim + 1; |
265 | | const int batch_size = a_nd == 1 ? 1 : ccv_max(1, a->info.dim[0]); |
266 | | int i, j; |
267 | | float* gp = g->data.f32; |
268 | | const int g_batch_inc = CCV_IS_TENSOR_VIEW(g) ? ((g_nd == 1) ? gdim[0] : g->stride[0]) : gdim[0]; |
269 | | if (bias) |
270 | | { |
271 | | float* bp = bias->data.f32; |
272 | | for (i = 0; i < batch_size; i++) |
273 | | { |
274 | | for (j = 0; j < gdim[0] - 3; j += 4) |
275 | | { |
276 | | float32x4_t g4 = vld1q_f32(gp + j); |
277 | | float32x4_t b4 = vld1q_f32(bp + j); |
278 | | vst1q_f32(bp + j, vaddq_f32(b4, g4)); |
279 | | } |
280 | | gp += g_batch_inc; |
281 | | } |
282 | | } |
283 | | const int a_batch_inc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == 1) ? adim[0] : a->stride[0]) : adim[0]; |
284 | | for (i = 0; i < batch_size; i++) |
285 | | { |
286 | | const float* const gp = g->data.f32 + i * g_batch_inc; |
287 | | const float* const ap = a->data.f32 + i * a_batch_inc; |
288 | | parallel_for(j, gdim[0]) { |
289 | | float* const dwp = dw->data.f32 + j * dwstride; |
290 | | float32x4_t g4 = vld1q_dup_f32(gp + j); |
291 | | int k; |
292 | | for (k = 0; k < adim[0] - 3; k+= 4) |
293 | | { |
294 | | float32x4_t a4 = vld1q_f32(ap + k); |
295 | | float32x4_t dw4 = vld1q_f32(dwp + k); |
296 | | vst1q_f32(dwp + k, vmlaq_f32(dw4, a4, g4)); |
297 | | } |
298 | | } parallel_endfor |
299 | | } |
300 | | if (h && w) |
301 | | { |
302 | | const int h_nd = ccv_nnc_tensor_nd(h->info.dim); |
303 | | const int* hdim = (h_nd == 1) ? h->info.dim : h->info.dim + 1; |
304 | | const int h_batch_inc = CCV_IS_TENSOR_VIEW(h) ? ((h_nd == 1) ? hdim[0] : h->stride[0]) : hdim[0]; |
305 | | const int wstride = CCV_IS_TENSOR_VIEW(w) ? w->stride[0] : w->info.dim[1]; |
306 | | for (i = 0; i < batch_size; i++) |
307 | | { |
308 | | const float* const gp = g->data.f32 + i * g_batch_inc; |
309 | | float* const hp = h->data.f32 + i * h_batch_inc; |
310 | | parallel_for(y, hdim[0] / 4) { |
311 | | const int j = y * 4; |
312 | | const float* const wp = w->data.f32 + j; |
313 | | float32x4_t v40 = vmovq_n_f32(0); |
314 | | float32x4_t v41 = vmovq_n_f32(0); |
315 | | float32x4_t v42 = vmovq_n_f32(0); |
316 | | float32x4_t v43 = vmovq_n_f32(0); |
317 | | int k; |
318 | | for (k = 0; k < gdim[0]; k += 4) |
319 | | { |
320 | | float32x2x2_t g4 = vld2_f32(gp + k); |
321 | | float32x4_t w40 = vld1q_f32(wp + k * wstride); |
322 | | float32x4_t w41 = vld1q_f32(wp + (k + 1) * wstride); |
323 | | float32x4_t w42 = vld1q_f32(wp + (k + 2) * wstride); |
324 | | float32x4_t w43 = vld1q_f32(wp + (k + 3) * wstride); |
325 | | float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0); |
326 | | float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0); |
327 | | float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1); |
328 | | float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1); |
329 | | v40 = vmlaq_f32(v40, g40, w40); |
330 | | v41 = vmlaq_f32(v41, g41, w41); |
331 | | v42 = vmlaq_f32(v42, g42, w42); |
332 | | v43 = vmlaq_f32(v43, g43, w43); |
333 | | } |
334 | | v40 = vaddq_f32(v40, v41); |
335 | | v42 = vaddq_f32(v42, v43); |
336 | | vst1q_f32(hp + j, vaddq_f32(v40, v42)); |
337 | | } parallel_endfor |
338 | | } |
339 | | } |
340 | | return CCV_NNC_EXEC_SUCCESS; |
341 | | } |
342 | | #endif |
343 | | |
344 | | int _ccv_nnc_gemm_forw_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b) |
345 | 286 | { |
346 | 286 | #if defined(HAVE_SSE2) || defined(HAVE_NEON) |
347 | 286 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
348 | 286 | const int adim = (a_nd == 1) ? a->info.dim[0] : a->info.dim[1]; |
349 | 286 | #endif |
350 | 286 | #if defined(HAVE_SSE2) |
351 | 286 | if (adim % 8 == 0) |
352 | 70 | return _ccv_nnc_gemm_forw_sse2(a, w, bias, b); |
353 | | #elif defined(HAVE_NEON) |
354 | | if (adim % 8 == 0) |
355 | | return _ccv_nnc_gemm_forw_neon(a, w, bias, b); |
356 | | #endif |
357 | 216 | return CCV_NNC_EXEC_INVALID; |
358 | 286 | } |
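The forward dispatcher only takes the SIMD path when the inner dimension is a multiple of 8, because the vector loops above read aligned groups of eight floats with no scalar tail; otherwise it returns CCV_NNC_EXEC_INVALID so the caller can pick another implementation. A hypothetical caller-side sketch (gemm_forw_generic is an assumed fallback, not part of this file):

/* Assumed generic fallback; only the optimized entry point below is real. */
extern int gemm_forw_generic(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b);

static int gemm_forw_dispatch(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b)
{
	const int status = _ccv_nnc_gemm_forw_cpu_opt(a, w, bias, b);
	/* CCV_NNC_EXEC_INVALID here means "shape not supported", not an error. */
	return status == CCV_NNC_EXEC_INVALID ? gemm_forw_generic(a, w, bias, b) : status;
}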
359 | | |
360 | | int _ccv_nnc_gemm_back_cpu_opt(const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags) |
361 | 93 | { |
362 | 93 | #if defined(HAVE_SSE2) || defined(HAVE_NEON) |
363 | 93 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
364 | 93 | const int adim = (a_nd == 1) ? a->info.dim[0] : a->info.dim[1]; |
365 | 93 | const int g_nd = ccv_nnc_tensor_nd(g->info.dim); |
366 | 93 | const int gdim = (g_nd == 1) ? g->info.dim[0] : g->info.dim[1]; |
367 | 93 | const int h_nd = h ? ccv_nnc_tensor_nd(h->info.dim) : 0; |
368 | 93 | const int hdim = h ? ((h_nd == 1) ? h->info.dim[0] : h->info.dim[1]) : 0; |
369 | 93 | #endif |
370 | 93 | #if defined(HAVE_SSE2) |
371 | 93 | if (gdim % 4 == 0 && adim % 4 == 0 && (!h || hdim % 4 == 0)) |
372 | 9 | return _ccv_nnc_gemm_back_sse2(g, a, w, dw, bias, h, flags); |
373 | | #elif defined(HAVE_NEON) |
374 | | if (gdim % 4 == 0 && adim % 4 == 0 && (!h || hdim % 4 == 0)) |
375 | | return _ccv_nnc_gemm_back_neon(g, a, w, dw, bias, h, flags); |
376 | | #endif |
377 | 84 | return CCV_NNC_EXEC_INVALID; |
378 | 93 | } |