/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #if defined(HAVE_SSE2) |
7 | | #include <xmmintrin.h> |
8 | | #elif defined(HAVE_NEON) |
9 | | #include <arm_neon.h> |
10 | | #endif |
11 | | #ifdef USE_OPENMP |
12 | | #include <omp.h> |
13 | | #endif |
14 | | #ifdef USE_DISPATCH |
15 | | #include <dispatch/dispatch.h> |
16 | | #endif |
17 | | #include "../_ccv_nnc_conv_cpu_opt.h" |
18 | | |
19 | | #define set_n_m_dim(i, x, wd, ad) \ |
20 | 66.0k | do { \ |
21 | 66.0k | n[x] = ccv_max((i) * hint.stride.dim[x] - hint.border.begin[x], 0) - ((i) * hint.stride.dim[x] - hint.border.begin[x]); \ |
22 | 66.0k | m[x] = wd[x + 1] - n[x] - ((i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1] - ccv_min(ad[x], (i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1])); \ |
23 | 66.0k | } while (0) |
24 | | |
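The set_n_m_dim macro above clips one axis of a 6-wide input tile against the image border: n[x] is how many leading elements of the tile fall in the padding (and are therefore skipped), and m[x] is how many elements are actually read from the input before the trailing border cuts the tile off. A minimal plain-C restatement of the same arithmetic, with hypothetical names (start = i * stride - border, extent = wd[x + 1], input_dim = ad[x]):

static void tile_clip(const int start, const int extent, const int input_dim, int* const n, int* const m)
{
	/* Elements clipped off at the leading (top/left) border. */
	*n = start < 0 ? -start : 0;
	/* Elements clipped off at the trailing (bottom/right) border. */
	const int over = start + extent > input_dim ? start + extent - input_dim : 0;
	/* Valid elements actually copied from the input. */
	*m = extent - *n - over;
}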
25 | | inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_ref(const float* const w, const int c, float* gwtg) |
26 | 0 | { |
27 | 0 | int i; |
28 | 0 | for (i = 0; i < c; i++) |
29 | 0 | { |
30 | 0 | float g[18]; |
31 | | /* |
32 | | * a0, b1, c2 |
33 | | * d3, e4, f5 |
34 | | * g6, h7, i8 |
35 | | * {{a/4, b/4, c/4}, |
36 | | * {1/6 (-a - d - g), 1/6 (-b - e - h), 1/6 (-c - f - i)}, |
37 | | * {1/6 (-a + d - g), 1/6 (-b + e - h), 1/6 (-c + f - i)}, |
38 | | * {1/24 (a + 2 d + 4 g), 1/24 (b + 2 e + 4 h), 1/24 (c + 2 f + 4 i)}, |
39 | | * {1/24 (a - 2 d + 4 g), 1/24 (b - 2 e + 4 h), 1/24 (c - 2 f + 4 i)}, |
40 | | * {g, h, i}} |
41 | | */ |
42 | | /* row 1 */ |
43 | 0 | g[0] = w[i] / 4; |
44 | 0 | g[1] = w[c + i] / 4; |
45 | 0 | g[2] = w[2 * c + i] / 4; |
46 | | /* row 2 */ |
47 | 0 | g[3] = -(w[i] + w[3 * c + i] + w[6 * c + i]) / 6; |
48 | 0 | g[4] = -(w[c + i] + w[4 * c + i] + w[7 * c + i]) / 6; |
49 | 0 | g[5] = -(w[2 * c + i] + w[5 * c + i] + w[8 * c + i]) / 6; |
50 | | /* row 3 */ |
51 | 0 | g[6] = (-w[i] + w[3 * c + i] - w[6 * c + i]) / 6; |
52 | 0 | g[7] = (-w[c + i] + w[4 * c + i] - w[7 * c + i]) / 6; |
53 | 0 | g[8] = (-w[2 * c + i] + w[5 * c + i] - w[8 * c + i]) / 6; |
54 | | /* row 4 */ |
55 | 0 | g[9] = (w[i] + 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24; |
56 | 0 | g[10] = (w[c + i] + 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24; |
57 | 0 | g[11] = (w[2 * c + i] + 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24; |
58 | | /* row 5 */ |
59 | 0 | g[12] = (w[i] - 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24; |
60 | 0 | g[13] = (w[c + i] - 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24; |
61 | 0 | g[14] = (w[2 * c + i] - 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24; |
62 | | /* row 6 */ |
63 | 0 | g[15] = w[6 * c + i]; |
64 | 0 | g[16] = w[7 * c + i]; |
65 | 0 | g[17] = w[8 * c + i]; |
66 | | /* |
67 | | * a0, b1, c2 |
68 | | * d3, e4, f5 |
69 | | * g6, h7, i8 |
70 | | * j9, k10,l11 |
71 | | * m12,n13,o14 |
72 | | * p15,q16,r17 |
73 | | * {{a/4, 1/6 (-a - b - c), 1/6 (-a + b - c), 1/24 (a + 2 b + 4 c), 1/24 (a - 2 b + 4 c), c}, |
74 | | * {d/4, 1/6 (-d - e - f), 1/6 (-d + e - f), 1/24 (d + 2 e + 4 f), 1/24 (d - 2 e + 4 f), f}, |
75 | | * {g/4, 1/6 (-g - h - i), 1/6 (-g + h - i), 1/24 (g + 2 h + 4 i), 1/24 (g - 2 h + 4 i), i}, |
76 | | * {j/4, 1/6 (-j - k - l), 1/6 (-j + k - l), 1/24 (j + 2 k + 4 l), 1/24 (j - 2 k + 4 l), l}, |
77 | | * {m/4, 1/6 (-m - n - o), 1/6 (-m + n - o), 1/24 (m + 2 n + 4 o), 1/24 (m - 2 n + 4 o), o}, |
78 | | * {p/4, 1/6 (-p - q - r), 1/6 (-p + q - r), 1/24 (p + 2 q + 4 r), 1/24 (p - 2 q + 4 r), r}} |
79 | | */ |
80 | | /* row 1 */ |
81 | 0 | gwtg[0] = g[0] / 4; |
82 | 0 | gwtg[c] = -(g[0] + g[1] + g[2]) / 6; |
83 | 0 | gwtg[2 * c] = (-g[0] + g[1] - g[2]) / 6; |
84 | 0 | gwtg[3 * c] = (g[0] + 2 * g[1] + 4 * g[2]) / 24; |
85 | 0 | gwtg[4 * c] = (g[0] - 2 * g[1] + 4 * g[2]) / 24; |
86 | 0 | gwtg[5 * c] = g[2]; |
87 | | /* row 2 */ |
88 | 0 | gwtg[6 * c] = g[3] / 4; |
89 | 0 | gwtg[7 * c] = -(g[3] + g[4] + g[5]) / 6; |
90 | 0 | gwtg[8 * c] = (-g[3] + g[4] - g[5]) / 6; |
91 | 0 | gwtg[9 * c] = (g[3] + 2 * g[4] + 4 * g[5]) / 24; |
92 | 0 | gwtg[10 * c] = (g[3] - 2 * g[4] + 4 * g[5]) / 24; |
93 | 0 | gwtg[11 * c] = g[5]; |
94 | | /* row 3 */ |
95 | 0 | gwtg[12 * c] = g[6] / 4; |
96 | 0 | gwtg[13 * c] = -(g[6] + g[7] + g[8]) / 6; |
97 | 0 | gwtg[14 * c] = (-g[6] + g[7] - g[8]) / 6; |
98 | 0 | gwtg[15 * c] = (g[6] + 2 * g[7] + 4 * g[8]) / 24; |
99 | 0 | gwtg[16 * c] = (g[6] - 2 * g[7] + 4 * g[8]) / 24; |
100 | 0 | gwtg[17 * c] = g[8]; |
101 | | /* row 4 */ |
102 | 0 | gwtg[18 * c] = g[9] / 4; |
103 | 0 | gwtg[19 * c] = -(g[9] + g[10] + g[11]) / 6; |
104 | 0 | gwtg[20 * c] = (-g[9] + g[10] - g[11]) / 6; |
105 | 0 | gwtg[21 * c] = (g[9] + 2 * g[10] + 4 * g[11]) / 24; |
106 | 0 | gwtg[22 * c] = (g[9] - 2 * g[10] + 4 * g[11]) / 24; |
107 | 0 | gwtg[23 * c] = g[11]; |
108 | | /* row 5 */ |
109 | 0 | gwtg[24 * c] = g[12] / 4; |
110 | 0 | gwtg[25 * c] = -(g[12] + g[13] + g[14]) / 6; |
111 | 0 | gwtg[26 * c] = (-g[12] + g[13] - g[14]) / 6; |
112 | 0 | gwtg[27 * c] = (g[12] + 2 * g[13] + 4 * g[14]) / 24; |
113 | 0 | gwtg[28 * c] = (g[12] - 2 * g[13] + 4 * g[14]) / 24; |
114 | 0 | gwtg[29 * c] = g[14]; |
115 | | /* row 6 */ |
116 | 0 | gwtg[30 * c] = g[15] / 4; |
117 | 0 | gwtg[31 * c] = -(g[15] + g[16] + g[17]) / 6; |
118 | 0 | gwtg[32 * c] = (-g[15] + g[16] - g[17]) / 6; |
119 | 0 | gwtg[33 * c] = (g[15] + 2 * g[16] + 4 * g[17]) / 24; |
120 | 0 | gwtg[34 * c] = (g[15] - 2 * g[16] + 4 * g[17]) / 24; |
121 | 0 | gwtg[35 * c] = g[17]; |
122 | 0 | ++gwtg; |
123 | 0 | } |
124 | 0 | } |
125 | | |
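_ccv_nnc_winograd_4x4_3x3_gwtg_ref above expands G.w.T(G) by hand, one input channel at a time, writing the 36 transformed values with a stride of c floats so that all channels of one tile position stay contiguous. For orientation, a naive single-filter version of the same transform is sketched below; the 6x3 matrix G is the one implied by the fractions (1/4, -1/6, 1/24, ...) in the comments above, and the helper names are illustrative, not part of the source.

/* Filter transform G for F(4x4, 3x3), as implied by the arithmetic above. */
static const float G[6][3] = {
	{  1.f / 4,          0,        0 },
	{ -1.f / 6,  -1.f / 6, -1.f / 6 },
	{ -1.f / 6,   1.f / 6, -1.f / 6 },
	{ 1.f / 24,  1.f / 12,  1.f / 6 },
	{ 1.f / 24, -1.f / 12,  1.f / 6 },
	{        0,          0,        1 },
};

/* Naive G.w.T(G) for a single 3x3 filter, producing the 6x6 transformed tile. */
static void gwtg_naive(const float w[3][3], float out[6][6])
{
	int i, j, k;
	float gw[6][3];
	for (i = 0; i < 6; i++)
		for (j = 0; j < 3; j++)
		{
			float s = 0;
			for (k = 0; k < 3; k++)
				s += G[i][k] * w[k][j];
			gw[i][j] = s;
		}
	for (i = 0; i < 6; i++)
		for (j = 0; j < 6; j++)
		{
			float s = 0;
			for (k = 0; k < 3; k++)
				s += gw[i][k] * G[j][k]; /* right-multiply by T(G) */
			out[i][j] = s;
		}
}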
126 | | static int _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context) |
127 | 0 | { |
128 | 0 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
129 | 0 | assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2); |
130 | 0 | const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1; |
131 | 0 | const int b_nd = ccv_nnc_tensor_nd(b->info.dim); |
132 | 0 | assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2); |
133 | 0 | const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1; |
134 | 0 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
135 | 0 | ccv_nnc_tensor_view_get_stride(a, astride); |
136 | 0 | int bstride[CCV_NNC_MAX_DIM_ALLOC]; |
137 | 0 | ccv_nnc_tensor_view_get_stride(b, bstride); |
138 | 0 | assert(hint.border.begin[0] <= 1); |
139 | 0 | assert(hint.border.begin[1] <= 1); |
140 | 0 | assert(w->info.dim[1] == 3); |
141 | 0 | assert(w->info.dim[2] == 3); |
142 | 0 | const int jump_dim = (bdim[0] + 3) / 4; |
143 | 0 | float* workmem; |
144 | | // allocating workspace memory for kernel reshaping and input reshaping. |
145 | | #if FOR_IS_PARALLEL |
146 | | // If we do parallel for, we need to allocate an input reshaping buffer for each block. |
147 | | workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] * jump_dim + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY); |
148 | | #else |
149 | | // Otherwise, just one block. |
150 | 0 | workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY); |
151 | 0 | #endif |
152 | 0 | if (!workmem) |
153 | 0 | return CCV_NNC_EXEC_OOM; |
154 | | // Convert w to a 6x6 matrix by computing G.w.T(G), where T(G) is the transpose of G. |
155 | 0 | float* const gwtg = workmem; |
156 | 0 | float* const btdb = workmem + 36 * w->info.dim[0] * w->info.dim[3]; |
157 | 0 | parallel_for(k, w->info.dim[0]) { |
158 | 0 | _ccv_nnc_winograd_4x4_3x3_gwtg_ref(w->data.f32 + k * w->info.dim[3] * w->info.dim[2] * w->info.dim[1], w->info.dim[3], gwtg + k * 36 * w->info.dim[3]); |
159 | 0 | } parallel_endfor |
160 | | // kernel weight for one dim. |
161 | | // Workaround issues of dispatch_apply (cannot reference an on-stack array) |
162 | 0 | const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = { |
163 | 0 | w->info.dim[0], 6, 6, w->info.dim[3] |
164 | 0 | }; |
165 | 0 | const int* const tile_dim = tile_dim_s; |
166 | | // This block will be executed in each for-loop; therefore, you can use it to generate some temporary variables. |
167 | 0 | if (bias) |
168 | 0 | { |
169 | 0 | const float* const biasval = bias->data.f32; |
170 | 0 | parallel_for(i, jump_dim) { |
171 | 0 | const int y = i * 4; // i is unsigned. |
172 | 0 | int j, x, k, c; |
173 | 0 | int n[CCV_NNC_MAX_DIM]; |
174 | 0 | int m[CCV_NNC_MAX_DIM]; |
175 | 0 | int z[CCV_NNC_MAX_DIM]; |
176 | 0 | set_n_m_dim(y, 0, tile_dim, adim); |
177 | 0 | z[0] = ccv_min(y + 4, bdim[0]) - y; |
178 | 0 | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
179 | 0 | float* bp = b->data.f32 + y * bstride[1]; |
180 | 0 | for (x = 0; x < bdim[1]; x += 4) |
181 | 0 | { |
182 | 0 | set_n_m_dim(x, 1, tile_dim, adim); |
183 | 0 | z[1] = ccv_min(x + 4, bdim[1]) - x; |
184 | | #if FOR_IS_PARALLEL |
185 | | float* g = btdb + i * 36 * adim[2]; |
186 | | #else |
187 | 0 | float* g = btdb; |
188 | 0 | #endif |
189 | | // zero g such that we can have zero-padding. |
190 | 0 | memset(g, 0, sizeof(float) * 36 * adim[2]); |
191 | 0 | int dx, dy; |
192 | 0 | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
193 | 0 | float* gz = g + (n[0] * 6 + n[1]) * adim[2]; |
194 | 0 | unroll_for(dy, m[0], 6) { |
195 | 0 | unroll_for(dx, m[1], 6) { |
196 | 0 | float* const gzu = gz + (dy * 6 + dx) * adim[2]; |
197 | 0 | for (c = 0; c < adim[2]; c++) |
198 | 0 | gzu[c] = apz[dx * astride[2] + c]; |
199 | 0 | } unroll_endfor |
200 | 0 | apz += astride[1]; |
201 | 0 | } unroll_endfor |
202 | 0 | for (c = 0; c < adim[2]; c++) |
203 | 0 | { |
204 | | /* |
205 | | * a0, a1, a2, a3, a4, a5, |
206 | | * b6, b7, b8, b9, b10,b11, |
207 | | * c12,c13,c14,c15,c16,c17, |
208 | | * d18,d19,d20,d21,d22,d23, |
209 | | * e24,e25,e26,e27,e28,e29, |
210 | | * f30,f31,f32,f33,f34,f35 |
211 | | * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29}, |
212 | | * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29}, |
213 | | * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29}, |
214 | | * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29}, |
215 | | * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29}, |
216 | | * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}} |
217 | | */ |
218 | 0 | float d[36]; |
219 | | /* BT.d */ |
220 | 0 | unroll_for(j, 6) { |
221 | 0 | float g0 = g[j * adim[2]]; |
222 | 0 | float g12 = g[(12 + j) * adim[2]]; |
223 | 0 | float g24 = g[(24 + j) * adim[2]]; |
224 | | /* row 1 */ |
225 | 0 | d[j] = 4 * g0 - 5 * g12 + g24; |
226 | 0 | float g6 = g[(6 + j) * adim[2]]; |
227 | 0 | float g18 = g[(18 + j) * adim[2]]; |
228 | | /* row 2 */ |
229 | 0 | d[6 + j] = -4 * (g6 + g12) + g18 + g24; |
230 | | /* row 3 */ |
231 | 0 | d[12 + j] = 4 * (g6 - g12) - g18 + g24; |
232 | | /* row 4 */ |
233 | 0 | d[18 + j] = 2 * (g18 - g6) - g12 + g24; |
234 | | /* row 5 */ |
235 | 0 | d[24 + j] = 2 * (g6 - g18) - g12 + g24; |
236 | 0 | float g30 = g[(30 + j) * adim[2]]; |
237 | | /* row 6 */ |
238 | 0 | d[30 + j] = 4 * g6 - 5 * g18 + g30; |
239 | 0 | } unroll_endfor |
240 | | /* |
241 | | * a0, a1, a2, a3, a4, a5, |
242 | | * b6, b7, b8, b9, b10,b11, |
243 | | * c12,c13,c14,c15,c16,c17, |
244 | | * d18,d19,d20,d21,d22,d23, |
245 | | * e24,e25,e26,e27,e28,e29, |
246 | | * f30,f31,f32,f33,f34,f35 |
247 | | * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5}, |
248 | | * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9}, |
249 | | * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17}, |
250 | | * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23}, |
251 | | * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29}, |
252 | | * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}} |
253 | | */ |
254 | | /* BT.d.B */ |
255 | 0 | unroll_for(j, 6) { |
256 | | /* row 1 - 6 */ |
257 | 0 | float* const gz = g + j * 6 * adim[2]; |
258 | 0 | float* const dz = d + j * 6; |
259 | 0 | gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4]; |
260 | 0 | gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4]; |
261 | 0 | gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4]; |
262 | 0 | gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4]; |
263 | 0 | gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4]; |
264 | 0 | gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5]; |
265 | 0 | } unroll_endfor |
266 | | // move to the next channel |
267 | 0 | ++g; |
268 | 0 | } |
269 | 0 | const float* wpz = gwtg; |
270 | 0 | for (k = 0; k < w->info.dim[0]; k++) |
271 | 0 | { |
272 | 0 | float q[36]; |
273 | | #if FOR_IS_PARALLEL |
274 | | g = btdb + i * 36 * adim[2]; |
275 | | #else |
276 | 0 | g = btdb; |
277 | 0 | #endif |
278 | 0 | for (j = 0; j < 36; j++) |
279 | 0 | { |
280 | 0 | float b = 0; |
281 | 0 | for (c = 0; c < adim[2]; c++) |
282 | 0 | b += g[c] * wpz[c]; |
283 | 0 | q[j] = b; |
284 | 0 | g += adim[2]; |
285 | 0 | wpz += adim[2]; |
286 | 0 | } |
287 | | /* |
288 | | * a0, a1, a2, a3, a4, a5, |
289 | | * b6, b7, b8, b9, b10,b11, |
290 | | * c12,c13,c14,c15,c16,c17, |
291 | | * d18,d19,d20,d21,d22,d23, |
292 | | * e24,e25,e26,e27,e28,e29, |
293 | | * f30,f31,f32,f33,f34,f35 |
294 | | * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29}, |
295 | | * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29}, |
296 | | * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)}, |
297 | | * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}} |
298 | | */ |
299 | 0 | float d[24]; |
300 | | /* row 1 */ |
301 | 0 | d[0] = q[0] + q[6] + q[12] + q[18] + q[24]; |
302 | 0 | d[1] = q[1] + q[7] + q[13] + q[19] + q[25]; |
303 | 0 | d[2] = q[2] + q[8] + q[14] + q[20] + q[26]; |
304 | 0 | d[3] = q[3] + q[9] + q[15] + q[21] + q[27]; |
305 | 0 | d[4] = q[4] + q[10] + q[16] + q[22] + q[28]; |
306 | 0 | d[5] = q[5] + q[11] + q[17] + q[23] + q[29]; |
307 | | /* row 2 */ |
308 | 0 | d[6] = q[6] - q[12] + 2 * (q[18] - q[24]); |
309 | 0 | d[7] = q[7] - q[13] + 2 * (q[19] - q[25]); |
310 | 0 | d[8] = q[8] - q[14] + 2 * (q[20] - q[26]); |
311 | 0 | d[9] = q[9] - q[15] + 2 * (q[21] - q[27]); |
312 | 0 | d[10] = q[10] - q[16] + 2 * (q[22] - q[28]); |
313 | 0 | d[11] = q[11] - q[17] + 2 * (q[23] - q[29]); |
314 | | /* row 3 */ |
315 | 0 | d[12] = q[6] + q[12] + 4 * (q[18] + q[24]); |
316 | 0 | d[13] = q[7] + q[13] + 4 * (q[19] + q[25]); |
317 | 0 | d[14] = q[8] + q[14] + 4 * (q[20] + q[26]); |
318 | 0 | d[15] = q[9] + q[15] + 4 * (q[21] + q[27]); |
319 | 0 | d[16] = q[10] + q[16] + 4 * (q[22] + q[28]); |
320 | 0 | d[17] = q[11] + q[17] + 4 * (q[23] + q[29]); |
321 | | /* row 4 */ |
322 | 0 | d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30]; |
323 | 0 | d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31]; |
324 | 0 | d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32]; |
325 | 0 | d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33]; |
326 | 0 | d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34]; |
327 | 0 | d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35]; |
328 | | /* |
329 | | * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5}, |
330 | | * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9}, |
331 | | * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17}, |
332 | | * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}} |
333 | | */ |
334 | 0 | float* bpz = bp + x * bstride[2] + k; |
335 | 0 | unroll_for(dy, z[0], 4) { |
336 | 0 | float r[] = { |
337 | 0 | d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4] + biasval[k], |
338 | 0 | d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + biasval[k], |
339 | 0 | d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]) + biasval[k], |
340 | 0 | d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5] + biasval[k], |
341 | 0 | }; |
342 | 0 | unroll_for(dx, z[1], 4) { |
343 | 0 | bpz[dx * bstride[2]] = r[dx]; |
344 | 0 | } unroll_endfor |
345 | 0 | bpz += bstride[1]; |
346 | 0 | } unroll_endfor |
347 | 0 | } |
348 | 0 | } |
349 | 0 | } parallel_endfor |
350 | 0 | } else { |
351 | 0 | parallel_for(i, jump_dim) { |
352 | 0 | const int y = i * 4; // i is unsigned. |
353 | 0 | int j, x, k, c; |
354 | 0 | int n[CCV_NNC_MAX_DIM]; |
355 | 0 | int m[CCV_NNC_MAX_DIM]; |
356 | 0 | int z[CCV_NNC_MAX_DIM]; |
357 | 0 | set_n_m_dim(y, 0, tile_dim, adim); |
358 | 0 | z[0] = ccv_min(y + 4, bdim[0]) - y; |
359 | 0 | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
360 | 0 | float* bp = b->data.f32 + y * bstride[1]; |
361 | 0 | for (x = 0; x < bdim[1]; x += 4) |
362 | 0 | { |
363 | 0 | set_n_m_dim(x, 1, tile_dim, adim); |
364 | 0 | z[1] = ccv_min(x + 4, bdim[1]) - x; |
365 | | #if FOR_IS_PARALLEL |
366 | | float* g = btdb + i * 36 * adim[2]; |
367 | | #else |
368 | 0 | float* g = btdb; |
369 | 0 | #endif |
370 | | // zero g such that we can have zero-padding. |
371 | 0 | memset(g, 0, sizeof(float) * 36 * adim[2]); |
372 | 0 | int dx, dy; |
373 | 0 | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
374 | 0 | float* gz = g + (n[0] * 6 + n[1]) * adim[2]; |
375 | 0 | unroll_for(dy, m[0], 6) { |
376 | 0 | unroll_for(dx, m[1], 6) { |
377 | 0 | float* const gzu = gz + (dy * 6 + dx) * adim[2]; |
378 | 0 | for (c = 0; c < adim[2]; c++) |
379 | 0 | gzu[c] = apz[dx * astride[2] + c]; |
380 | 0 | } unroll_endfor |
381 | 0 | apz += astride[1]; |
382 | 0 | } unroll_endfor |
383 | 0 | for (c = 0; c < adim[2]; c++) |
384 | 0 | { |
385 | | /* |
386 | | * a0, a1, a2, a3, a4, a5, |
387 | | * b6, b7, b8, b9, b10,b11, |
388 | | * c12,c13,c14,c15,c16,c17, |
389 | | * d18,d19,d20,d21,d22,d23, |
390 | | * e24,e25,e26,e27,e28,e29, |
391 | | * f30,f31,f32,f33,f34,f35 |
392 | | * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29}, |
393 | | * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29}, |
394 | | * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29}, |
395 | | * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29}, |
396 | | * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29}, |
397 | | * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}} |
398 | | */ |
399 | 0 | float d[36]; |
400 | | /* BT.d */ |
401 | 0 | unroll_for(j, 6) { |
402 | 0 | float g0 = g[j * adim[2]]; |
403 | 0 | float g12 = g[(12 + j) * adim[2]]; |
404 | 0 | float g24 = g[(24 + j) * adim[2]]; |
405 | | /* row 1 */ |
406 | 0 | d[j] = 4 * g0 - 5 * g12 + g24; |
407 | 0 | float g6 = g[(6 + j) * adim[2]]; |
408 | 0 | float g18 = g[(18 + j) * adim[2]]; |
409 | | /* row 2 */ |
410 | 0 | d[6 + j] = -4 * (g6 + g12) + g18 + g24; |
411 | | /* row 3 */ |
412 | 0 | d[12 + j] = 4 * (g6 - g12) - g18 + g24; |
413 | | /* row 4 */ |
414 | 0 | d[18 + j] = 2 * (g18 - g6) - g12 + g24; |
415 | | /* row 5 */ |
416 | 0 | d[24 + j] = 2 * (g6 - g18) - g12 + g24; |
417 | 0 | float g30 = g[(30 + j) * adim[2]]; |
418 | | /* row 6 */ |
419 | 0 | d[30 + j] = 4 * g6 - 5 * g18 + g30; |
420 | 0 | } unroll_endfor |
421 | | /* |
422 | | * a0, a1, a2, a3, a4, a5, |
423 | | * b6, b7, b8, b9, b10,b11, |
424 | | * c12,c13,c14,c15,c16,c17, |
425 | | * d18,d19,d20,d21,d22,d23, |
426 | | * e24,e25,e26,e27,e28,e29, |
427 | | * f30,f31,f32,f33,f34,f35 |
428 | | * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5}, |
429 | | * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9}, |
430 | | * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17}, |
431 | | * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23}, |
432 | | * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29}, |
433 | | * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}} |
434 | | */ |
435 | | /* BT.d.B */ |
436 | 0 | unroll_for(j, 6) { |
437 | | /* row 1 - 6 */ |
438 | 0 | float* const gz = g + j * 6 * adim[2]; |
439 | 0 | float* const dz = d + j * 6; |
440 | 0 | gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4]; |
441 | 0 | gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4]; |
442 | 0 | gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4]; |
443 | 0 | gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4]; |
444 | 0 | gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4]; |
445 | 0 | gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5]; |
446 | 0 | } unroll_endfor |
447 | | // move to the next channel |
448 | 0 | ++g; |
449 | 0 | } |
450 | 0 | const float* wpz = gwtg; |
451 | 0 | for (k = 0; k < w->info.dim[0]; k++) |
452 | 0 | { |
453 | 0 | float q[36]; |
454 | | #if FOR_IS_PARALLEL |
455 | | g = btdb + i * 36 * adim[2]; |
456 | | #else |
457 | 0 | g = btdb; |
458 | 0 | #endif |
459 | 0 | for (j = 0; j < 36; j++) |
460 | 0 | { |
461 | 0 | float b = 0; |
462 | 0 | for (c = 0; c < adim[2]; c++) |
463 | 0 | b += g[c] * wpz[c]; |
464 | 0 | q[j] = b; |
465 | 0 | g += adim[2]; |
466 | 0 | wpz += adim[2]; |
467 | 0 | } |
468 | | /* |
469 | | * a0, a1, a2, a3, a4, a5, |
470 | | * b6, b7, b8, b9, b10,b11, |
471 | | * c12,c13,c14,c15,c16,c17, |
472 | | * d18,d19,d20,d21,d22,d23, |
473 | | * e24,e25,e26,e27,e28,e29, |
474 | | * f30,f31,f32,f33,f34,f35 |
475 | | * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29}, |
476 | | * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29}, |
477 | | * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)}, |
478 | | * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}} |
479 | | */ |
480 | 0 | float d[24]; |
481 | | /* row 1 */ |
482 | 0 | d[0] = q[0] + q[6] + q[12] + q[18] + q[24]; |
483 | 0 | d[1] = q[1] + q[7] + q[13] + q[19] + q[25]; |
484 | 0 | d[2] = q[2] + q[8] + q[14] + q[20] + q[26]; |
485 | 0 | d[3] = q[3] + q[9] + q[15] + q[21] + q[27]; |
486 | 0 | d[4] = q[4] + q[10] + q[16] + q[22] + q[28]; |
487 | 0 | d[5] = q[5] + q[11] + q[17] + q[23] + q[29]; |
488 | | /* row 2 */ |
489 | 0 | d[6] = q[6] - q[12] + 2 * (q[18] - q[24]); |
490 | 0 | d[7] = q[7] - q[13] + 2 * (q[19] - q[25]); |
491 | 0 | d[8] = q[8] - q[14] + 2 * (q[20] - q[26]); |
492 | 0 | d[9] = q[9] - q[15] + 2 * (q[21] - q[27]); |
493 | 0 | d[10] = q[10] - q[16] + 2 * (q[22] - q[28]); |
494 | 0 | d[11] = q[11] - q[17] + 2 * (q[23] - q[29]); |
495 | | /* row 3 */ |
496 | 0 | d[12] = q[6] + q[12] + 4 * (q[18] + q[24]); |
497 | 0 | d[13] = q[7] + q[13] + 4 * (q[19] + q[25]); |
498 | 0 | d[14] = q[8] + q[14] + 4 * (q[20] + q[26]); |
499 | 0 | d[15] = q[9] + q[15] + 4 * (q[21] + q[27]); |
500 | 0 | d[16] = q[10] + q[16] + 4 * (q[22] + q[28]); |
501 | 0 | d[17] = q[11] + q[17] + 4 * (q[23] + q[29]); |
502 | | /* row 4 */ |
503 | 0 | d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30]; |
504 | 0 | d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31]; |
505 | 0 | d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32]; |
506 | 0 | d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33]; |
507 | 0 | d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34]; |
508 | 0 | d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35]; |
509 | | /* |
510 | | * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5}, |
511 | | * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9}, |
512 | | * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17}, |
513 | | * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}} |
514 | | */ |
515 | 0 | float* bpz = bp + x * bstride[2] + k; |
516 | 0 | unroll_for(dy, z[0], 4) { |
517 | 0 | float r[] = { |
518 | 0 | d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4], |
519 | 0 | d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]), |
520 | 0 | d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]), |
521 | 0 | d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5], |
522 | 0 | }; |
523 | 0 | unroll_for(dx, z[1], 4) { |
524 | 0 | bpz[dx * bstride[2]] = r[dx]; |
525 | 0 | } unroll_endfor |
526 | 0 | bpz += bstride[1]; |
527 | 0 | } unroll_endfor |
528 | 0 | } |
529 | 0 | } |
530 | 0 | } parallel_endfor |
531 | 0 | } |
532 | 0 | return CCV_NNC_EXEC_SUCCESS; |
533 | 0 | } |
534 | | |
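The reference kernel above is a direct implementation of Winograd F(4x4, 3x3): every 4x4 output tile Y is produced from a 6x6 input tile d and a 3x3 filter g as an element-wise product in the transformed domain, accumulated over input channels. Reading the per-row arithmetic back into matrix form (shown for orientation; these are the standard F(4x4, 3x3) transforms, not copied verbatim from the source):

$$
Y = A^\top\big[(G\,g\,G^\top)\odot(B^\top d\,B)\big]A,\qquad
B^\top=\begin{bmatrix}4&0&-5&0&1&0\\0&-4&-4&1&1&0\\0&4&-4&-1&1&0\\0&-2&-1&2&1&0\\0&2&-1&-2&1&0\\0&4&0&-5&0&1\end{bmatrix},\qquad
A^\top=\begin{bmatrix}1&1&1&1&1&0\\0&1&-1&2&-2&0\\0&1&1&4&4&0\\0&1&-1&8&-8&1\end{bmatrix}
$$

The BT.d / BT.d.B loops apply B^T to the 6x6 input tile from the left and the right, the 36 inner products over adim[2] channels carry out the element-wise multiply-accumulate in the transformed domain, and the final two blocks apply A^T to collapse the 6x6 product back to the 4x4 output tile (plus bias when present).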
535 | | #ifdef HAVE_SSE2 |
536 | | inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(const float* const w, const int* const dim, float* const gwtg) |
537 | 119 | { |
538 | 119 | const int jump_dim = dim[0] / 4; |
539 | 119 | const int dimCx4 = (dim[3] + 3) & -4; |
540 | 7.56k | parallel_for(k, jump_dim) { |
541 | 7.56k | int i, j; |
542 | 7.56k | float* gwtgz = gwtg + k * 4 * 36 * dimCx4; |
543 | 7.56k | const float* wz[] = { |
544 | 7.56k | w + (k * 4) * 9 * dim[3], |
545 | 7.56k | w + (k * 4 + 1) * 9 * dim[3], |
546 | 7.56k | w + (k * 4 + 2) * 9 * dim[3], |
547 | 7.56k | w + (k * 4 + 3) * 9 * dim[3], |
548 | 7.56k | }; |
549 | 2.88M | for (i = 0; i < dim[3]; i++) |
550 | 2.87M | { |
551 | 2.87M | float x9w[9 * 4] __attribute__ ((__aligned__(16))); |
552 | 25.8M | unroll_for(j, 9) { |
553 | 25.8M | x9w[j * 4] = wz[0][j * dim[3] + i]; |
554 | 25.8M | x9w[j * 4 + 1] = wz[1][j * dim[3] + i]; |
555 | 25.8M | x9w[j * 4 + 2] = wz[2][j * dim[3] + i]; |
556 | 25.8M | x9w[j * 4 + 3] = wz[3][j * dim[3] + i]; |
557 | 25.8M | } unroll_endfor |
558 | 2.87M | float g[18 * 4] __attribute__ ((__aligned__(16))); |
559 | 2.87M | __m128 x9w0 = _mm_load_ps(x9w); |
560 | 2.87M | __m128 x9w1 = _mm_load_ps(x9w + 4); |
561 | 2.87M | __m128 x9w2 = _mm_load_ps(x9w + 8); |
562 | 2.87M | __m128 x9w3 = _mm_load_ps(x9w + 12); |
563 | 2.87M | __m128 x9w4 = _mm_load_ps(x9w + 16); |
564 | 2.87M | __m128 x9w5 = _mm_load_ps(x9w + 20); |
565 | 2.87M | __m128 x9w6 = _mm_load_ps(x9w + 24); |
566 | 2.87M | __m128 x9w7 = _mm_load_ps(x9w + 28); |
567 | 2.87M | __m128 x9w8 = _mm_load_ps(x9w + 32); |
568 | | /* row 1 */ |
569 | 2.87M | __m128 c1_4 = _mm_set1_ps(1.0 / 4); |
570 | 2.87M | _mm_store_ps(g, _mm_mul_ps(x9w0, c1_4)); |
571 | 2.87M | _mm_store_ps(g + 4, _mm_mul_ps(x9w1, c1_4)); |
572 | 2.87M | _mm_store_ps(g + 8, _mm_mul_ps(x9w2, c1_4)); |
573 | | /* row 2 */ |
574 | 2.87M | __m128 cn1_6 = _mm_set1_ps(-1.0 / 6); |
575 | 2.87M | _mm_store_ps(g + 12, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6)); |
576 | 2.87M | _mm_store_ps(g + 16, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6)); |
577 | 2.87M | _mm_store_ps(g + 20, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6)); |
578 | | /* row 3 */ |
579 | 2.87M | _mm_store_ps(g + 24, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6)); |
580 | 2.87M | _mm_store_ps(g + 28, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6)); |
581 | 2.87M | _mm_store_ps(g + 32, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6)); |
582 | | /* row 6 */ |
583 | 2.87M | _mm_store_ps(g + 60, x9w6); |
584 | 2.87M | _mm_store_ps(g + 64, x9w7); |
585 | 2.87M | _mm_store_ps(g + 68, x9w8); |
586 | | /* w[x] * 2 */ |
587 | 2.87M | x9w3 = _mm_add_ps(x9w3, x9w3); |
588 | 2.87M | x9w4 = _mm_add_ps(x9w4, x9w4); |
589 | 2.87M | x9w5 = _mm_add_ps(x9w5, x9w5); |
590 | | /* w[x] * 4 */ |
591 | 2.87M | x9w6 = _mm_add_ps(x9w6, x9w6); |
592 | 2.87M | x9w6 = _mm_add_ps(x9w6, x9w6); |
593 | 2.87M | x9w7 = _mm_add_ps(x9w7, x9w7); |
594 | 2.87M | x9w7 = _mm_add_ps(x9w7, x9w7); |
595 | 2.87M | x9w8 = _mm_add_ps(x9w8, x9w8); |
596 | 2.87M | x9w8 = _mm_add_ps(x9w8, x9w8); |
597 | | /* row 4 */ |
598 | 2.87M | __m128 c1_24 = _mm_set1_ps(1.0 / 24); |
599 | 2.87M | _mm_store_ps(g + 36, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24)); |
600 | 2.87M | _mm_store_ps(g + 40, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24)); |
601 | 2.87M | _mm_store_ps(g + 44, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24)); |
602 | | /* row 5 */ |
603 | 2.87M | _mm_store_ps(g + 48, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24)); |
604 | 2.87M | _mm_store_ps(g + 52, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24)); |
605 | 2.87M | _mm_store_ps(g + 56, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24)); |
606 | 17.2M | unroll_for(j, 6) { |
607 | 17.2M | const float* const gz = g + j * 12; |
608 | 17.2M | float* const gwtgzu = gwtgz + j * 24 * dimCx4; |
609 | 17.2M | __m128 g0 = _mm_load_ps(gz); |
610 | 17.2M | __m128 g1 = _mm_load_ps(gz + 4); |
611 | 17.2M | __m128 g2 = _mm_load_ps(gz + 8); |
612 | 17.2M | _mm_store_ps(gwtgzu, _mm_mul_ps(g0, c1_4)); |
613 | 17.2M | _mm_store_ps(gwtgzu + 4 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), cn1_6)); |
614 | 17.2M | _mm_store_ps(gwtgzu + 8 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), cn1_6)); |
615 | 17.2M | _mm_store_ps(gwtgzu + 20 * dimCx4, g2); |
616 | | /* g[1] * 2 */ |
617 | 17.2M | g1 = _mm_add_ps(g1, g1); |
618 | | /* g[2] * 4 */ |
619 | 17.2M | g2 = _mm_add_ps(g2, g2); |
620 | 17.2M | g2 = _mm_add_ps(g2, g2); |
621 | 17.2M | _mm_store_ps(gwtgzu + 12 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), c1_24)); |
622 | 17.2M | _mm_store_ps(gwtgzu + 16 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), c1_24)); |
623 | 17.2M | } unroll_endfor |
624 | 2.87M | gwtgz += 4; |
625 | 2.87M | } |
626 | 7.56k | } parallel_endfor |
627 | 119 | } |
628 | | |
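The SSE2 filter transform above differs from the reference only in data layout: filters are processed four at a time, so each __m128 holds the same filter tap for output channels k..k+3, constant multiplies by 2 and 4 are built from repeated adds, and the per-tile channel stride is dimCx4 = (dim[3] + 3) & -4 so every 4-float group stays 16-byte aligned for _mm_load_ps/_mm_store_ps. A small sketch of the packing step (the helper name is hypothetical; it mirrors the x9w fill loop above):

/* Pack tap j of four consecutive 3x3 filters (input channel i of c channels) into one
 * 4-float group, so a single aligned load yields that tap for output channels k..k+3. */
static void pack_tap4(const float* const wz[4], const int j, const int c, const int i, float out4[4])
{
	int f;
	for (f = 0; f < 4; f++)
		out4[f] = wz[f][j * c + i];
}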
629 | | static int _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context) |
630 | 119 | { |
631 | 119 | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
632 | 119 | assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2); |
633 | 119 | const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1; |
634 | 119 | const int b_nd = ccv_nnc_tensor_nd(b->info.dim); |
635 | 119 | assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2); |
636 | 119 | const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1; |
637 | 119 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
638 | 119 | ccv_nnc_tensor_view_get_stride(a, astride); |
639 | 119 | int bstride[CCV_NNC_MAX_DIM_ALLOC]; |
640 | 119 | ccv_nnc_tensor_view_get_stride(b, bstride); |
641 | 119 | assert(hint.border.begin[0] <= 1); |
642 | 119 | assert(hint.border.begin[1] <= 1); |
643 | 119 | assert(w->info.dim[0] % 4 == 0); |
644 | 119 | assert(w->info.dim[1] == 3); |
645 | 119 | assert(w->info.dim[2] == 3); |
646 | 119 | const int jump_dim = (bdim[0] + 3) / 4; |
647 | 119 | const int dimCx4 = (adim[2] + 3) & -4; |
648 | | // allocating workspace memory for kernel reshaping and input reshaping. |
649 | 119 | float* workmem = 0; |
650 | | #if FOR_IS_PARALLEL |
651 | | // If we do parallel for, we need to allocate an input reshaping buffer for each block. |
652 | | workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY); |
653 | | #else |
654 | | // Otherwise, just one block. |
655 | 119 | workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY); |
656 | 119 | #endif |
657 | 119 | if (!workmem) |
658 | 0 | return CCV_NNC_EXEC_OOM; |
659 | | // Convert w to a 6x6 matrix by computing G.w.T(G), where T(G) is the transpose of G. |
660 | 119 | float* const gwtg = workmem; |
661 | 119 | float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0]; |
662 | 119 | memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]); |
663 | 119 | _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(w->data.f32, w->info.dim, gwtg); |
664 | | // kernel weight for one dim. |
665 | | // Workaround issues of dispatch_apply (cannot reference an on-stack array) |
666 | 119 | const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = { |
667 | 119 | w->info.dim[0], 6, 6, w->info.dim[3] |
668 | 119 | }; |
669 | 119 | const int* const tile_dim = tile_dim_s; |
670 | 119 | if (bias) |
671 | 118 | { |
672 | 118 | const float* const biasval = bias->data.f32; |
673 | | // This block will be executed in each for-loop; therefore, you can use it to generate some temporary variables. |
674 | 1.83k | parallel_for(i, jump_dim) { |
675 | 1.83k | const int y = i * 4; // i is unsigned. |
676 | 1.83k | int j, x, k, c; |
677 | 1.83k | int n[CCV_NNC_MAX_DIM]; |
678 | 1.83k | int m[CCV_NNC_MAX_DIM]; |
679 | 1.83k | int z[CCV_NNC_MAX_DIM]; |
680 | 1.83k | set_n_m_dim(y, 0, tile_dim, adim); |
681 | 1.83k | z[0] = ccv_min(y + 4, bdim[0]) - y; |
682 | 1.83k | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
683 | 1.83k | float* bp = b->data.f32 + y * bstride[1]; |
684 | 65.8k | for (x = 0; x < bdim[1]; x += 4) |
685 | 63.9k | { |
686 | 63.9k | set_n_m_dim(x, 1, tile_dim, adim); |
687 | 63.9k | z[1] = ccv_min(x + 4, bdim[1]) - x; |
688 | | #if FOR_IS_PARALLEL |
689 | | float* g = btdb + i * 36 * dimCx4; |
690 | | #else |
691 | 63.9k | float* g = btdb; |
692 | 63.9k | #endif |
693 | | // zero g such that we can have zero-padding. |
694 | 63.9k | memset(g, 0, sizeof(float) * 36 * dimCx4); |
695 | 63.9k | int dx, dy; |
696 | 63.9k | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
697 | 63.9k | float* gz = g + (n[0] * 6 + n[1]) * dimCx4; |
698 | 379k | unroll_for(dy, m[0], 6) { |
699 | 2.24M | unroll_for(dx, m[1], 6) { |
700 | 2.24M | float* const gzu = gz + (dy * 6 + dx) * dimCx4; |
701 | 139M | for (c = 0; c < adim[2]; c++) |
702 | 137M | gzu[c] = apz[dx * astride[2] + c]; |
703 | 2.24M | } unroll_endfor |
704 | 379k | apz += astride[1]; |
705 | 379k | } unroll_endfor |
706 | 1.08M | for (c = 0; c < adim[2]; c += 4) |
707 | 1.02M | { |
708 | 1.02M | float d[36 * 4] __attribute__ ((__aligned__(16))); |
709 | | /* BT.d */ |
710 | 6.14M | unroll_for(j, 6) { |
711 | | /* row 1 */ |
712 | 6.14M | const float* const gz = g + j * dimCx4; |
713 | 6.14M | float* dz = d + j * 4; |
714 | 6.14M | __m128 g0 = _mm_load_ps(gz); |
715 | 6.14M | __m128 g12 = _mm_load_ps(gz + 12 * dimCx4); |
716 | 6.14M | __m128 g18 = _mm_load_ps(gz + 18 * dimCx4); |
717 | 6.14M | __m128 g24 = _mm_load_ps(gz + 24 * dimCx4); |
718 | 6.14M | g0 = _mm_add_ps(g0, g0); |
719 | 6.14M | g0 = _mm_add_ps(g0, g0); |
720 | 6.14M | __m128 g12x2 = _mm_add_ps(g12, g12); |
721 | 6.14M | g12x2 = _mm_add_ps(g12x2, g12x2); |
722 | 6.14M | g12x2 = _mm_add_ps(g12x2, g12); |
723 | 6.14M | _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2)); |
724 | | /* row 2 */ |
725 | 6.14M | __m128 g6 = _mm_load_ps(gz + 6 * dimCx4); |
726 | 6.14M | __m128 g6x12 = _mm_add_ps(g6, g12); |
727 | 6.14M | g6x12 = _mm_add_ps(g6x12, g6x12); |
728 | 6.14M | g6x12 = _mm_add_ps(g6x12, g6x12); |
729 | 6.14M | _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12)); |
730 | | /* row 3 */ |
731 | 6.14M | g6x12 = _mm_sub_ps(g6, g12); |
732 | 6.14M | g6x12 = _mm_add_ps(g6x12, g6x12); |
733 | 6.14M | g6x12 = _mm_add_ps(g6x12, g6x12); |
734 | 6.14M | _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12)); |
735 | | /* row 4 */ |
736 | 6.14M | __m128 g18x6 = _mm_sub_ps(g18, g6); |
737 | 6.14M | g18x6 = _mm_add_ps(g18x6, g18x6); |
738 | 6.14M | _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6)); |
739 | | /* row 5 */ |
740 | 6.14M | _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6)); |
741 | | /* row 6 */ |
742 | 6.14M | __m128 g30 = _mm_load_ps(gz + 30 * dimCx4); |
743 | 6.14M | __m128 g18x2 = _mm_add_ps(g18, g18); |
744 | 6.14M | g18x2 = _mm_add_ps(g18x2, g18x2); |
745 | 6.14M | g18x2 = _mm_add_ps(g18, g18x2); |
746 | 6.14M | g6 = _mm_add_ps(g6, g6); |
747 | 6.14M | g6 = _mm_add_ps(g6, g6); |
748 | 6.14M | _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2)); |
749 | 6.14M | } unroll_endfor |
750 | | /* BT.d.B */ |
751 | 6.14M | unroll_for(j, 6) { |
752 | 6.14M | float* gz = g + j * 6 * dimCx4; |
753 | 6.14M | const float* const dz = d + j * 24; |
754 | 6.14M | __m128 d0 = _mm_load_ps(dz); |
755 | 6.14M | __m128 d1 = _mm_load_ps(dz + 4); |
756 | 6.14M | __m128 d2 = _mm_load_ps(dz + 8); |
757 | 6.14M | __m128 d3 = _mm_load_ps(dz + 12); |
758 | 6.14M | __m128 d4 = _mm_load_ps(dz + 16); |
759 | 6.14M | __m128 d5 = _mm_load_ps(dz + 20); |
760 | 6.14M | d0 = _mm_add_ps(d0, d0); |
761 | 6.14M | d0 = _mm_add_ps(d0, d0); |
762 | 6.14M | __m128 d2x5 = _mm_add_ps(d2, d2); |
763 | 6.14M | d2x5 = _mm_add_ps(d2x5, d2x5); |
764 | 6.14M | d2x5 = _mm_add_ps(d2, d2x5); |
765 | 6.14M | _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5)); |
766 | 6.14M | __m128 d1x2 = _mm_add_ps(d1, d2); |
767 | 6.14M | d1x2 = _mm_add_ps(d1x2, d1x2); |
768 | 6.14M | d1x2 = _mm_add_ps(d1x2, d1x2); |
769 | 6.14M | _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2)); |
770 | 6.14M | d1x2 = _mm_sub_ps(d1, d2); |
771 | 6.14M | d1x2 = _mm_add_ps(d1x2, d1x2); |
772 | 6.14M | d1x2 = _mm_add_ps(d1x2, d1x2); |
773 | 6.14M | _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2)); |
774 | 6.14M | __m128 d3x1 = _mm_sub_ps(d3, d1); |
775 | 6.14M | d3x1 = _mm_add_ps(d3x1, d3x1); |
776 | 6.14M | _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1)); |
777 | 6.14M | _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1)); |
778 | 6.14M | d1 = _mm_add_ps(d1, d1); |
779 | 6.14M | d1 = _mm_add_ps(d1, d1); |
780 | 6.14M | __m128 d3x5 = _mm_add_ps(d3, d3); |
781 | 6.14M | d3x5 = _mm_add_ps(d3x5, d3x5); |
782 | 6.14M | d3x5 = _mm_add_ps(d3, d3x5); |
783 | 6.14M | _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5)); |
784 | 6.14M | } unroll_endfor |
785 | | // move to the next channel |
786 | 1.02M | g += 4; |
787 | 1.02M | } |
788 | 63.9k | const float* wpz = gwtg; |
789 | 1.66M | for (k = 0; k < w->info.dim[0]; k += 4) |
790 | 1.60M | { |
791 | 1.60M | float q[36 * 4] __attribute__ ((__aligned__(16))); |
792 | | #if FOR_IS_PARALLEL |
793 | | g = btdb + i * 36 * dimCx4; |
794 | | #else |
795 | 1.60M | g = btdb; |
796 | 1.60M | #endif |
797 | 59.3M | for (j = 0; j < 36; j++) |
798 | 57.7M | { |
799 | 57.7M | __m128 v40 = _mm_setzero_ps(); |
800 | 57.7M | __m128 v41 = _mm_setzero_ps(); |
801 | 57.7M | __m128 v42 = _mm_setzero_ps(); |
802 | 57.7M | __m128 v43 = _mm_setzero_ps(); |
803 | 1.80G | for (c = 0; c < adim[2]; c += 4) |
804 | 1.74G | { |
805 | 1.74G | __m128 g4 = _mm_load_ps(g); |
806 | 1.74G | __m128 w40 = _mm_load_ps(wpz); |
807 | 1.74G | __m128 w41 = _mm_load_ps(wpz + 4); |
808 | 1.74G | __m128 w42 = _mm_load_ps(wpz + 8); |
809 | 1.74G | __m128 w43 = _mm_load_ps(wpz + 12); |
810 | 1.74G | __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00); |
811 | 1.74G | __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55); |
812 | 1.74G | __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA); |
813 | 1.74G | __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF); |
814 | 1.74G | v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40); |
815 | 1.74G | v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41); |
816 | 1.74G | v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42); |
817 | 1.74G | v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43); |
818 | 1.74G | g += 4; |
819 | 1.74G | wpz += 16; |
820 | 1.74G | } |
821 | 57.7M | v40 = _mm_add_ps(v40, v41); |
822 | 57.7M | v42 = _mm_add_ps(v42, v43); |
823 | 57.7M | _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42)); |
824 | 57.7M | } |
825 | 1.60M | float d[24 * 4] __attribute__ ((__aligned__(16))); |
826 | 9.62M | unroll_for(j, 6) { |
827 | 9.62M | const float* const qz = q + j * 4; |
828 | 9.62M | float* const dz = d + j * 4; |
829 | 9.62M | __m128 q0 = _mm_load_ps(qz); |
830 | 9.62M | __m128 q6 = _mm_load_ps(qz + 24); |
831 | 9.62M | __m128 q12 = _mm_load_ps(qz + 48); |
832 | 9.62M | __m128 q18 = _mm_load_ps(qz + 72); |
833 | 9.62M | __m128 q24 = _mm_load_ps(qz + 96); |
834 | 9.62M | __m128 qs6x12 = _mm_add_ps(q6, q12); |
835 | 9.62M | __m128 qs18x24 = _mm_add_ps(q18, q24); |
836 | 9.62M | __m128 qss = _mm_add_ps(qs6x12, q0); |
837 | | /* row 1 */ |
838 | 9.62M | _mm_store_ps(dz, _mm_add_ps(qss, qs18x24)); |
839 | 9.62M | __m128 qn6x12 = _mm_sub_ps(q6, q12); |
840 | 9.62M | __m128 qn18x24 = _mm_sub_ps(q18, q24); |
841 | 9.62M | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
842 | | /* row 2 */ |
843 | 9.62M | _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24)); |
844 | 9.62M | qs18x24 = _mm_add_ps(qs18x24, qs18x24); |
845 | 9.62M | qs18x24 = _mm_add_ps(qs18x24, qs18x24); |
846 | | /* row 3 */ |
847 | 9.62M | _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24)); |
848 | 9.62M | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
849 | 9.62M | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
850 | 9.62M | __m128 q30 = _mm_load_ps(qz + 120); |
851 | | /* row 4 */ |
852 | 9.62M | _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24)); |
853 | 9.62M | } unroll_endfor |
854 | 1.60M | float* bpz = bp + x * bstride[2] + k; |
855 | 1.60M | __m128 bias4 = _mm_loadu_ps(biasval + k); |
856 | 1.60M | switch (z[1]) { |
857 | 10.8k | case 1: |
858 | 35.1k | unroll_for(dy, z[0], 4) { |
859 | 35.1k | const float* const dz = d + dy * 24; |
860 | 35.1k | __m128 d0 = _mm_load_ps(dz); |
861 | 35.1k | __m128 d1 = _mm_load_ps(dz + 4); |
862 | 35.1k | __m128 d2 = _mm_load_ps(dz + 8); |
863 | 35.1k | __m128 d3 = _mm_load_ps(dz + 12); |
864 | 35.1k | __m128 d4 = _mm_load_ps(dz + 16); |
865 | 35.1k | __m128 ds1x2 = _mm_add_ps(d1, d2); |
866 | 35.1k | __m128 ds3x4 = _mm_add_ps(d3, d4); |
867 | 35.1k | ds1x2 = _mm_add_ps(ds1x2, bias4); |
868 | 35.1k | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
869 | 35.1k | bpz += bstride[1]; |
870 | 35.1k | } unroll_endfor |
871 | 10.8k | break; |
872 | 0 | case 2: |
873 | 0 | unroll_for(dy, z[0], 4) { |
874 | 0 | const float* const dz = d + dy * 24; |
875 | 0 | __m128 d0 = _mm_load_ps(dz); |
876 | 0 | __m128 d1 = _mm_load_ps(dz + 4); |
877 | 0 | __m128 d2 = _mm_load_ps(dz + 8); |
878 | 0 | __m128 d3 = _mm_load_ps(dz + 12); |
879 | 0 | __m128 d4 = _mm_load_ps(dz + 16); |
880 | 0 | __m128 ds1x2 = _mm_add_ps(d1, d2); |
881 | 0 | __m128 ds3x4 = _mm_add_ps(d3, d4); |
882 | 0 | ds1x2 = _mm_add_ps(ds1x2, bias4); |
883 | 0 | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
884 | 0 | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
885 | 0 | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
886 | 0 | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
887 | 0 | dn1x2 = _mm_add_ps(dn1x2, bias4); |
888 | 0 | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
889 | 0 | bpz += bstride[1]; |
890 | 0 | } unroll_endfor |
891 | 0 | break; |
892 | 63.1k | case 3: |
893 | 247k | unroll_for(dy, z[0], 4) { |
894 | 247k | const float* const dz = d + dy * 24; |
895 | 247k | __m128 d0 = _mm_load_ps(dz); |
896 | 247k | __m128 d1 = _mm_load_ps(dz + 4); |
897 | 247k | __m128 d2 = _mm_load_ps(dz + 8); |
898 | 247k | __m128 d3 = _mm_load_ps(dz + 12); |
899 | 247k | __m128 d4 = _mm_load_ps(dz + 16); |
900 | 247k | __m128 ds1x2 = _mm_add_ps(d1, d2); |
901 | 247k | __m128 ds3x4 = _mm_add_ps(d3, d4); |
902 | 247k | ds1x2 = _mm_add_ps(ds1x2, bias4); |
903 | 247k | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
904 | 247k | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
905 | 247k | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
906 | 247k | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
907 | 247k | dn1x2 = _mm_add_ps(dn1x2, bias4); |
908 | 247k | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
909 | 247k | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
910 | 247k | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
911 | 247k | _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4)); |
912 | 247k | bpz += bstride[1]; |
913 | 247k | } unroll_endfor |
914 | 63.1k | break; |
915 | 1.53M | case 4: |
916 | 6.04M | unroll_for(dy, z[0], 4) { |
917 | 6.04M | const float* const dz = d + dy * 24; |
918 | 6.04M | __m128 d0 = _mm_load_ps(dz); |
919 | 6.04M | __m128 d1 = _mm_load_ps(dz + 4); |
920 | 6.04M | __m128 d2 = _mm_load_ps(dz + 8); |
921 | 6.04M | __m128 d3 = _mm_load_ps(dz + 12); |
922 | 6.04M | __m128 d4 = _mm_load_ps(dz + 16); |
923 | 6.04M | __m128 ds1x2 = _mm_add_ps(d1, d2); |
924 | 6.04M | __m128 ds3x4 = _mm_add_ps(d3, d4); |
925 | 6.04M | ds1x2 = _mm_add_ps(ds1x2, bias4); |
926 | 6.04M | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
927 | 6.04M | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
928 | 6.04M | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
929 | 6.04M | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
930 | 6.04M | dn1x2 = _mm_add_ps(dn1x2, bias4); |
931 | 6.04M | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
932 | 6.04M | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
933 | 6.04M | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
934 | 6.04M | _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4)); |
935 | 6.04M | __m128 d5 = _mm_load_ps(dz + 20); |
936 | 6.04M | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
937 | 6.04M | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
938 | 6.04M | _mm_stream_ps(bpz + 3 * bstride[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4)); |
939 | 6.04M | bpz += bstride[1]; |
940 | 6.04M | } unroll_endfor |
941 | 1.53M | break; |
942 | 1.60M | }; |
943 | 1.60M | } |
944 | 63.9k | } |
945 | 1.83k | } parallel_endfor |
946 | 118 | } else { |
947 | | // This block will be executed in each for-loop; therefore, you can use it to generate some temporary variables. |
948 | 14 | parallel_for(i, jump_dim) { |
949 | 14 | const int y = i * 4; // i is unsigned. |
950 | 14 | int j, x, k, c; |
951 | 14 | int n[CCV_NNC_MAX_DIM]; |
952 | 14 | int m[CCV_NNC_MAX_DIM]; |
953 | 14 | int z[CCV_NNC_MAX_DIM]; |
954 | 14 | set_n_m_dim(y, 0, tile_dim, adim); |
955 | 14 | z[0] = ccv_min(y + 4, bdim[0]) - y; |
956 | 14 | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
957 | 14 | float* bp = b->data.f32 + y * bstride[1]; |
958 | 210 | for (x = 0; x < bdim[1]; x += 4) |
959 | 196 | { |
960 | 196 | set_n_m_dim(x, 1, tile_dim, adim); |
961 | 196 | z[1] = ccv_min(x + 4, bdim[1]) - x; |
962 | | #if FOR_IS_PARALLEL |
963 | | float* g = btdb + i * 36 * dimCx4; |
964 | | #else |
965 | 196 | float* g = btdb; |
966 | 196 | #endif |
967 | | // zero g such that we can have zero-padding. |
968 | 196 | memset(g, 0, sizeof(float) * 36 * dimCx4); |
969 | 196 | int dx, dy; |
970 | 196 | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
971 | 196 | float* gz = g + (n[0] * 6 + n[1]) * dimCx4; |
972 | 1.14k | unroll_for(dy, m[0], 6) { |
973 | 6.72k | unroll_for(dx, m[1], 6) { |
974 | 6.72k | float* const gzu = gz + (dy * 6 + dx) * dimCx4; |
975 | 867k | for (c = 0; c < adim[2]; c++) |
976 | 860k | gzu[c] = apz[dx * astride[2] + c]; |
977 | 6.72k | } unroll_endfor |
978 | 1.14k | apz += astride[1]; |
979 | 1.14k | } unroll_endfor |
980 | 6.46k | for (c = 0; c < adim[2]; c += 4) |
981 | 6.27k | { |
982 | 6.27k | float d[36 * 4] __attribute__ ((__aligned__(16))); |
983 | | /* BT.d */ |
984 | 37.6k | unroll_for(j, 6) { |
985 | | /* row 1 */ |
986 | 37.6k | const float* const gz = g + j * dimCx4; |
987 | 37.6k | float* dz = d + j * 4; |
988 | 37.6k | __m128 g0 = _mm_load_ps(gz); |
989 | 37.6k | __m128 g12 = _mm_load_ps(gz + 12 * dimCx4); |
990 | 37.6k | __m128 g18 = _mm_load_ps(gz + 18 * dimCx4); |
991 | 37.6k | __m128 g24 = _mm_load_ps(gz + 24 * dimCx4); |
992 | 37.6k | g0 = _mm_add_ps(g0, g0); |
993 | 37.6k | g0 = _mm_add_ps(g0, g0); |
994 | 37.6k | __m128 g12x2 = _mm_add_ps(g12, g12); |
995 | 37.6k | g12x2 = _mm_add_ps(g12x2, g12x2); |
996 | 37.6k | g12x2 = _mm_add_ps(g12x2, g12); |
997 | 37.6k | _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2)); |
998 | | /* row 2 */ |
999 | 37.6k | __m128 g6 = _mm_load_ps(gz + 6 * dimCx4); |
1000 | 37.6k | __m128 g6x12 = _mm_add_ps(g6, g12); |
1001 | 37.6k | g6x12 = _mm_add_ps(g6x12, g6x12); |
1002 | 37.6k | g6x12 = _mm_add_ps(g6x12, g6x12); |
1003 | 37.6k | _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12)); |
1004 | | /* row 3 */ |
1005 | 37.6k | g6x12 = _mm_sub_ps(g6, g12); |
1006 | 37.6k | g6x12 = _mm_add_ps(g6x12, g6x12); |
1007 | 37.6k | g6x12 = _mm_add_ps(g6x12, g6x12); |
1008 | 37.6k | _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12)); |
1009 | | /* row 4 */ |
1010 | 37.6k | __m128 g18x6 = _mm_sub_ps(g18, g6); |
1011 | 37.6k | g18x6 = _mm_add_ps(g18x6, g18x6); |
1012 | 37.6k | _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6)); |
1013 | | /* row 5 */ |
1014 | 37.6k | _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6)); |
1015 | | /* row 6 */ |
1016 | 37.6k | __m128 g30 = _mm_load_ps(gz + 30 * dimCx4); |
1017 | 37.6k | __m128 g18x2 = _mm_add_ps(g18, g18); |
1018 | 37.6k | g18x2 = _mm_add_ps(g18x2, g18x2); |
1019 | 37.6k | g18x2 = _mm_add_ps(g18, g18x2); |
1020 | 37.6k | g6 = _mm_add_ps(g6, g6); |
1021 | 37.6k | g6 = _mm_add_ps(g6, g6); |
1022 | 37.6k | _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2)); |
1023 | 37.6k | } unroll_endfor |
1024 | | /* BT.d.B */ |
1025 | 37.6k | unroll_for(j, 6) { |
1026 | 37.6k | float* gz = g + j * 6 * dimCx4; |
1027 | 37.6k | const float* const dz = d + j * 24; |
1028 | 37.6k | __m128 d0 = _mm_load_ps(dz); |
1029 | 37.6k | __m128 d1 = _mm_load_ps(dz + 4); |
1030 | 37.6k | __m128 d2 = _mm_load_ps(dz + 8); |
1031 | 37.6k | __m128 d3 = _mm_load_ps(dz + 12); |
1032 | 37.6k | __m128 d4 = _mm_load_ps(dz + 16); |
1033 | 37.6k | __m128 d5 = _mm_load_ps(dz + 20); |
1034 | 37.6k | d0 = _mm_add_ps(d0, d0); |
1035 | 37.6k | d0 = _mm_add_ps(d0, d0); |
1036 | 37.6k | __m128 d2x5 = _mm_add_ps(d2, d2); |
1037 | 37.6k | d2x5 = _mm_add_ps(d2x5, d2x5); |
1038 | 37.6k | d2x5 = _mm_add_ps(d2, d2x5); |
1039 | 37.6k | _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5)); |
1040 | 37.6k | __m128 d1x2 = _mm_add_ps(d1, d2); |
1041 | 37.6k | d1x2 = _mm_add_ps(d1x2, d1x2); |
1042 | 37.6k | d1x2 = _mm_add_ps(d1x2, d1x2); |
1043 | 37.6k | _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2)); |
1044 | 37.6k | d1x2 = _mm_sub_ps(d1, d2); |
1045 | 37.6k | d1x2 = _mm_add_ps(d1x2, d1x2); |
1046 | 37.6k | d1x2 = _mm_add_ps(d1x2, d1x2); |
1047 | 37.6k | _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2)); |
1048 | 37.6k | __m128 d3x1 = _mm_sub_ps(d3, d1); |
1049 | 37.6k | d3x1 = _mm_add_ps(d3x1, d3x1); |
1050 | 37.6k | _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1)); |
1051 | 37.6k | _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1)); |
1052 | 37.6k | d1 = _mm_add_ps(d1, d1); |
1053 | 37.6k | d1 = _mm_add_ps(d1, d1); |
1054 | 37.6k | __m128 d3x5 = _mm_add_ps(d3, d3); |
1055 | 37.6k | d3x5 = _mm_add_ps(d3x5, d3x5); |
1056 | 37.6k | d3x5 = _mm_add_ps(d3, d3x5); |
1057 | 37.6k | _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5)); |
1058 | 37.6k | } unroll_endfor |
1059 | | // move to the next channel |
1060 | 6.27k | g += 4; |
1061 | 6.27k | } |
1062 | 196 | const float* wpz = gwtg; |
1063 | 6.46k | for (k = 0; k < w->info.dim[0]; k += 4) |
1064 | 6.27k | { |
1065 | 6.27k | float q[36 * 4] __attribute__ ((__aligned__(16))); |
1066 | | #if FOR_IS_PARALLEL |
1067 | | g = btdb + i * 36 * dimCx4; |
1068 | | #else |
1069 | 6.27k | g = btdb; |
1070 | 6.27k | #endif |
1071 | 232k | for (j = 0; j < 36; j++) |
1072 | 225k | { |
1073 | 225k | __m128 v40 = _mm_setzero_ps(); |
1074 | 225k | __m128 v41 = _mm_setzero_ps(); |
1075 | 225k | __m128 v42 = _mm_setzero_ps(); |
1076 | 225k | __m128 v43 = _mm_setzero_ps(); |
1077 | 7.45M | for (c = 0; c < adim[2]; c += 4) |
1078 | 7.22M | { |
1079 | 7.22M | __m128 g4 = _mm_load_ps(g); |
1080 | 7.22M | __m128 w40 = _mm_load_ps(wpz); |
1081 | 7.22M | __m128 w41 = _mm_load_ps(wpz + 4); |
1082 | 7.22M | __m128 w42 = _mm_load_ps(wpz + 8); |
1083 | 7.22M | __m128 w43 = _mm_load_ps(wpz + 12); |
1084 | 7.22M | __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00); |
1085 | 7.22M | __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55); |
1086 | 7.22M | __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA); |
1087 | 7.22M | __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF); |
1088 | 7.22M | v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40); |
1089 | 7.22M | v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41); |
1090 | 7.22M | v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42); |
1091 | 7.22M | v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43); |
1092 | 7.22M | g += 4; |
1093 | 7.22M | wpz += 16; |
1094 | 7.22M | } |
1095 | 225k | v40 = _mm_add_ps(v40, v41); |
1096 | 225k | v42 = _mm_add_ps(v42, v43); |
1097 | 225k | _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42)); |
1098 | 225k | } |
1099 | 6.27k | float d[24 * 4] __attribute__ ((__aligned__(16))); |
1100 | 37.6k | unroll_for(j, 6) { |
1101 | 37.6k | const float* const qz = q + j * 4; |
1102 | 37.6k | float* const dz = d + j * 4; |
1103 | 37.6k | __m128 q0 = _mm_load_ps(qz); |
1104 | 37.6k | __m128 q6 = _mm_load_ps(qz + 24); |
1105 | 37.6k | __m128 q12 = _mm_load_ps(qz + 48); |
1106 | 37.6k | __m128 q18 = _mm_load_ps(qz + 72); |
1107 | 37.6k | __m128 q24 = _mm_load_ps(qz + 96); |
1108 | 37.6k | __m128 qs6x12 = _mm_add_ps(q6, q12); |
1109 | 37.6k | __m128 qs18x24 = _mm_add_ps(q18, q24); |
1110 | 37.6k | __m128 qss = _mm_add_ps(qs6x12, q0); |
1111 | | /* row 1 */ |
1112 | 37.6k | _mm_store_ps(dz, _mm_add_ps(qss, qs18x24)); |
1113 | 37.6k | __m128 qn6x12 = _mm_sub_ps(q6, q12); |
1114 | 37.6k | __m128 qn18x24 = _mm_sub_ps(q18, q24); |
1115 | 37.6k | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
1116 | | /* row 2 */ |
1117 | 37.6k | _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24)); |
1118 | 37.6k | qs18x24 = _mm_add_ps(qs18x24, qs18x24); |
1119 | 37.6k | qs18x24 = _mm_add_ps(qs18x24, qs18x24); |
1120 | | /* row 3 */ |
1121 | 37.6k | _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24)); |
1122 | 37.6k | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
1123 | 37.6k | qn18x24 = _mm_add_ps(qn18x24, qn18x24); |
1124 | 37.6k | __m128 q30 = _mm_load_ps(qz + 120); |
1125 | | /* row 4 */ |
1126 | 37.6k | _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24)); |
1127 | 37.6k | } unroll_endfor |
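// d holds AT.q (4 output rows x 6 columns, 4 channels per element); the switch on
// z[1] below applies the final multiplication by A and writes only the z[1] valid
// columns and z[0] valid rows, handling partial 4x4 tiles at the right/bottom border.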
1128 | 6.27k | float* bpz = bp + x * bstride[2] + k; |
1129 | 6.27k | switch (z[1]) { |
1130 | 0 | case 1: |
1131 | 0 | unroll_for(dy, z[0], 4) { |
1132 | 0 | const float* const dz = d + dy * 24; |
1133 | 0 | __m128 d0 = _mm_load_ps(dz); |
1134 | 0 | __m128 d1 = _mm_load_ps(dz + 4); |
1135 | 0 | __m128 d2 = _mm_load_ps(dz + 8); |
1136 | 0 | __m128 d3 = _mm_load_ps(dz + 12); |
1137 | 0 | __m128 d4 = _mm_load_ps(dz + 16); |
1138 | 0 | __m128 ds1x2 = _mm_add_ps(d1, d2); |
1139 | 0 | __m128 ds3x4 = _mm_add_ps(d3, d4); |
1140 | 0 | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
1141 | 0 | bpz += bstride[1]; |
1142 | 0 | } unroll_endfor |
1143 | 0 | break; |
1144 | 0 | case 2: |
1145 | 0 | unroll_for(dy, z[0], 4) { |
1146 | 0 | const float* const dz = d + dy * 24; |
1147 | 0 | __m128 d0 = _mm_load_ps(dz); |
1148 | 0 | __m128 d1 = _mm_load_ps(dz + 4); |
1149 | 0 | __m128 d2 = _mm_load_ps(dz + 8); |
1150 | 0 | __m128 d3 = _mm_load_ps(dz + 12); |
1151 | 0 | __m128 d4 = _mm_load_ps(dz + 16); |
1152 | 0 | __m128 ds1x2 = _mm_add_ps(d1, d2); |
1153 | 0 | __m128 ds3x4 = _mm_add_ps(d3, d4); |
1154 | 0 | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
1155 | 0 | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
1156 | 0 | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
1157 | 0 | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
1158 | 0 | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
1159 | 0 | bpz += bstride[1]; |
1160 | 0 | } unroll_endfor |
1161 | 0 | break; |
1162 | 0 | case 3: |
1163 | 0 | unroll_for(dy, z[0], 4) { |
1164 | 0 | const float* const dz = d + dy * 24; |
1165 | 0 | __m128 d0 = _mm_load_ps(dz); |
1166 | 0 | __m128 d1 = _mm_load_ps(dz + 4); |
1167 | 0 | __m128 d2 = _mm_load_ps(dz + 8); |
1168 | 0 | __m128 d3 = _mm_load_ps(dz + 12); |
1169 | 0 | __m128 d4 = _mm_load_ps(dz + 16); |
1170 | 0 | __m128 ds1x2 = _mm_add_ps(d1, d2); |
1171 | 0 | __m128 ds3x4 = _mm_add_ps(d3, d4); |
1172 | 0 | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
1173 | 0 | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
1174 | 0 | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
1175 | 0 | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
1176 | 0 | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
1177 | 0 | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
1178 | 0 | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
1179 | 0 | _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4)); |
1180 | 0 | bpz += bstride[1]; |
1181 | 0 | } unroll_endfor |
1182 | 0 | break; |
1183 | 6.27k | case 4: |
1184 | 25.0k | unroll_for(dy, z[0], 4) { |
1185 | 25.0k | const float* const dz = d + dy * 24; |
1186 | 25.0k | __m128 d0 = _mm_load_ps(dz); |
1187 | 25.0k | __m128 d1 = _mm_load_ps(dz + 4); |
1188 | 25.0k | __m128 d2 = _mm_load_ps(dz + 8); |
1189 | 25.0k | __m128 d3 = _mm_load_ps(dz + 12); |
1190 | 25.0k | __m128 d4 = _mm_load_ps(dz + 16); |
1191 | 25.0k | __m128 ds1x2 = _mm_add_ps(d1, d2); |
1192 | 25.0k | __m128 ds3x4 = _mm_add_ps(d3, d4); |
1193 | 25.0k | _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4))); |
1194 | 25.0k | __m128 dn1x2 = _mm_sub_ps(d1, d2); |
1195 | 25.0k | __m128 dn3x4 = _mm_sub_ps(d3, d4); |
1196 | 25.0k | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
1197 | 25.0k | _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4)); |
1198 | 25.0k | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
1199 | 25.0k | ds3x4 = _mm_add_ps(ds3x4, ds3x4); |
1200 | 25.0k | _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4)); |
1201 | 25.0k | __m128 d5 = _mm_load_ps(dz + 20); |
1202 | 25.0k | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
1203 | 25.0k | dn3x4 = _mm_add_ps(dn3x4, dn3x4); |
1204 | 25.0k | _mm_stream_ps(bpz + 3 * bstride[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4)); |
1205 | 25.0k | bpz += bstride[1]; |
1206 | 25.0k | } unroll_endfor |
1207 | 6.27k | break; |
1208 | 6.27k | }; |
1209 | 6.27k | } |
1210 | 196 | } |
1211 | 14 | } parallel_endfor |
1212 | 1 | } |
1213 | 119 | return CCV_NNC_EXEC_SUCCESS; |
1214 | 119 | } |
1215 | | #endif |
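/* For reference, a minimal scalar sketch of the inverse transform that the unrolled
 * loops and the switch above compute per output channel: out = AT.q.A with the
 * F(4x4,3x3) matrix
 *   AT = [ 1  1  1  1  1  0
 *          0  1 -1  2 -2  0
 *          0  1  1  4  4  0
 *          0  1 -1  8 -8  1 ],
 * which is exactly the sequence of adds and doublings in the SIMD code. The function
 * name and the plain row-major layout are illustrative only; the real kernels keep
 * four channels interleaved per tile element. */
static inline void winograd_4x4_3x3_output_tile_sketch(const float q[36], float out[16])
{
	float t[24]; /* AT.q : 4 rows x 6 columns */
	int i;
	for (i = 0; i < 6; i++)
	{
		const float q0 = q[i], q1 = q[6 + i], q2 = q[12 + i], q3 = q[18 + i], q4 = q[24 + i], q5 = q[30 + i];
		t[i] = q0 + q1 + q2 + q3 + q4;
		t[6 + i] = q1 - q2 + 2 * (q3 - q4);
		t[12 + i] = q1 + q2 + 4 * (q3 + q4);
		t[18 + i] = q1 - q2 + 8 * (q3 - q4) + q5;
	}
	for (i = 0; i < 4; i++)
	{
		const float d0 = t[i * 6], d1 = t[i * 6 + 1], d2 = t[i * 6 + 2], d3 = t[i * 6 + 3], d4 = t[i * 6 + 4], d5 = t[i * 6 + 5];
		out[i * 4] = d0 + d1 + d2 + d3 + d4;
		out[i * 4 + 1] = d1 - d2 + 2 * (d3 - d4);
		out[i * 4 + 2] = d1 + d2 + 4 * (d3 + d4);
		out[i * 4 + 3] = d1 - d2 + 8 * (d3 - d4) + d5;
	}
}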
1216 | | |
1217 | | #ifdef HAVE_NEON |
1218 | | inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_neon(const float* const w, const int* const dim, float* const gwtg) |
1219 | | { |
1220 | | const int jump_dim = dim[0] / 4; |
1221 | | const int dimCx4 = (dim[3] + 3) & -4; |
1222 | | parallel_for(k, jump_dim) { |
1223 | | int i, j; |
1224 | | float* gwtgz = gwtg + k * 4 * 36 * dimCx4; |
1225 | | const float* wz[] = { |
1226 | | w + (k * 4) * 9 * dim[3], |
1227 | | w + (k * 4 + 1) * 9 * dim[3], |
1228 | | w + (k * 4 + 2) * 9 * dim[3], |
1229 | | w + (k * 4 + 3) * 9 * dim[3], |
1230 | | }; |
1231 | | for (i = 0; i < dim[3]; i++) |
1232 | | { |
1233 | | float x9w[9 * 4] __attribute__ ((__aligned__(16))); |
1234 | | unroll_for(j, 9) { |
1235 | | x9w[j * 4] = wz[0][j * dim[3] + i]; |
1236 | | x9w[j * 4 + 1] = wz[1][j * dim[3] + i]; |
1237 | | x9w[j * 4 + 2] = wz[2][j * dim[3] + i]; |
1238 | | x9w[j * 4 + 3] = wz[3][j * dim[3] + i]; |
1239 | | } unroll_endfor |
1240 | | float g[18 * 4] __attribute__ ((__aligned__(16))); |
1241 | | float32x4_t x9w0 = vld1q_f32(x9w); |
1242 | | float32x4_t x9w1 = vld1q_f32(x9w + 4); |
1243 | | float32x4_t x9w2 = vld1q_f32(x9w + 8); |
1244 | | float32x4_t x9w3 = vld1q_f32(x9w + 12); |
1245 | | float32x4_t x9w4 = vld1q_f32(x9w + 16); |
1246 | | float32x4_t x9w5 = vld1q_f32(x9w + 20); |
1247 | | float32x4_t x9w6 = vld1q_f32(x9w + 24); |
1248 | | float32x4_t x9w7 = vld1q_f32(x9w + 28); |
1249 | | float32x4_t x9w8 = vld1q_f32(x9w + 32); |
1250 | | /* row 1 */ |
1251 | | float32x4_t c1_4 = vdupq_n_f32(1.0 / 4); |
1252 | | vst1q_f32(g, vmulq_f32(x9w0, c1_4)); |
1253 | | vst1q_f32(g + 4, vmulq_f32(x9w1, c1_4)); |
1254 | | vst1q_f32(g + 8, vmulq_f32(x9w2, c1_4)); |
1255 | | /* row 2 */ |
1256 | | float32x4_t cn1_6 = vdupq_n_f32(-1.0 / 6); |
1257 | | vst1q_f32(g + 12, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6)); |
1258 | | vst1q_f32(g + 16, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6)); |
1259 | | vst1q_f32(g + 20, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6)); |
1260 | | /* row 3 */ |
1261 | | vst1q_f32(g + 24, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6)); |
1262 | | vst1q_f32(g + 28, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6)); |
1263 | | vst1q_f32(g + 32, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6)); |
1264 | | /* row 6 */ |
1265 | | vst1q_f32(g + 60, x9w6); |
1266 | | vst1q_f32(g + 64, x9w7); |
1267 | | vst1q_f32(g + 68, x9w8); |
1268 | | /* w[x] * 2 */ |
1269 | | x9w3 = vaddq_f32(x9w3, x9w3); |
1270 | | x9w4 = vaddq_f32(x9w4, x9w4); |
1271 | | x9w5 = vaddq_f32(x9w5, x9w5); |
1272 | | /* w[x] * 4 */ |
1273 | | x9w6 = vaddq_f32(x9w6, x9w6); |
1274 | | x9w6 = vaddq_f32(x9w6, x9w6); |
1275 | | x9w7 = vaddq_f32(x9w7, x9w7); |
1276 | | x9w7 = vaddq_f32(x9w7, x9w7); |
1277 | | x9w8 = vaddq_f32(x9w8, x9w8); |
1278 | | x9w8 = vaddq_f32(x9w8, x9w8); |
1279 | | /* row 4 */ |
1280 | | float32x4_t c1_24 = vdupq_n_f32(1.0 / 24); |
1281 | | vst1q_f32(g + 36, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24)); |
1282 | | vst1q_f32(g + 40, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24)); |
1283 | | vst1q_f32(g + 44, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24)); |
1284 | | /* row 5 */ |
1285 | | vst1q_f32(g + 48, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24)); |
1286 | | vst1q_f32(g + 52, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24)); |
1287 | | vst1q_f32(g + 56, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24)); |
1288 | | unroll_for(j, 6) { |
1289 | | const float* const gz = g + j * 12; |
1290 | | float* const gwtgzu = gwtgz + j * 24 * dimCx4; |
1291 | | float32x4_t g0 = vld1q_f32(gz); |
1292 | | float32x4_t g1 = vld1q_f32(gz + 4); |
1293 | | float32x4_t g2 = vld1q_f32(gz + 8); |
1294 | | vst1q_f32(gwtgzu, vmulq_f32(g0, c1_4)); |
1295 | | vst1q_f32(gwtgzu + 4 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), cn1_6)); |
1296 | | vst1q_f32(gwtgzu + 8 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), cn1_6)); |
1297 | | vst1q_f32(gwtgzu + 20 * dimCx4, g2); |
1298 | | /* g[1] * 2 */ |
1299 | | g1 = vaddq_f32(g1, g1); |
1300 | | /* g[2] * 4 */ |
1301 | | g2 = vaddq_f32(g2, g2); |
1302 | | g2 = vaddq_f32(g2, g2); |
1303 | | vst1q_f32(gwtgzu + 12 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), c1_24)); |
1304 | | vst1q_f32(gwtgzu + 16 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), c1_24)); |
1305 | | } unroll_endfor |
1306 | | gwtgz += 4; |
1307 | | } |
1308 | | } parallel_endfor |
1309 | | } |
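/* A minimal scalar sketch of the input-tile transform btdb = BT.d.B that the
 * "BT.d" / "BT.d.B" blocks in the SSE2 and NEON kernels compute, using the
 * F(4x4,3x3) matrix
 *   BT = [ 4  0 -5  0  1  0
 *          0 -4 -4  1  1  0
 *          0  4 -4 -1  1  0
 *          0 -2 -1  2  1  0
 *          0  2 -1 -2  1  0
 *          0  4  0 -5  0  1 ].
 * The function name and the plain row-major 6x6 layout are illustrative only; the
 * real kernels keep four channels interleaved per tile element (dimCx4 stride). */
static inline void winograd_4x4_3x3_input_tile_sketch(const float d[36], float btdb[36])
{
	float t[36]; /* BT.d */
	int i;
	for (i = 0; i < 6; i++)
	{
		const float d0 = d[i], d1 = d[6 + i], d2 = d[12 + i], d3 = d[18 + i], d4 = d[24 + i], d5 = d[30 + i];
		t[i] = 4 * d0 - 5 * d2 + d4;
		t[6 + i] = -4 * (d1 + d2) + d3 + d4;
		t[12 + i] = 4 * (d1 - d2) - d3 + d4;
		t[18 + i] = 2 * (d3 - d1) - d2 + d4;
		t[24 + i] = 2 * (d1 - d3) - d2 + d4;
		t[30 + i] = 4 * d1 - 5 * d3 + d5;
	}
	for (i = 0; i < 6; i++)
	{
		const float t0 = t[i * 6], t1 = t[i * 6 + 1], t2 = t[i * 6 + 2], t3 = t[i * 6 + 3], t4 = t[i * 6 + 4], t5 = t[i * 6 + 5];
		btdb[i * 6] = 4 * t0 - 5 * t2 + t4;
		btdb[i * 6 + 1] = -4 * (t1 + t2) + t3 + t4;
		btdb[i * 6 + 2] = 4 * (t1 - t2) - t3 + t4;
		btdb[i * 6 + 3] = 2 * (t3 - t1) - t2 + t4;
		btdb[i * 6 + 4] = 2 * (t1 - t3) - t2 + t4;
		btdb[i * 6 + 5] = 4 * t1 - 5 * t3 + t5;
	}
}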
1310 | | |
1311 | | static int _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context) |
1312 | | { |
1313 | | const int a_nd = ccv_nnc_tensor_nd(a->info.dim); |
1314 | | assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2); |
1315 | | const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1; |
1316 | | const int b_nd = ccv_nnc_tensor_nd(b->info.dim); |
1317 | | assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2); |
1318 | | const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1; |
1319 | | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
1320 | | ccv_nnc_tensor_view_get_stride(a, astride); |
1321 | | int bstride[CCV_NNC_MAX_DIM_ALLOC]; |
1322 | | ccv_nnc_tensor_view_get_stride(b, bstride); |
1323 | | assert(hint.border.begin[0] <= 1); |
1324 | | assert(hint.border.begin[1] <= 1); |
1325 | | assert(w->info.dim[0] % 4 == 0); |
1326 | | assert(w->info.dim[1] == 3); |
1327 | | assert(w->info.dim[2] == 3); |
1328 | | const int jump_dim = (bdim[0] + 3) / 4; |
1329 | | const int dimCx4 = (adim[2] + 3) & -4; |
1330 | | // allocating workspace memory for kernel reshaping and input reshaping. |
1331 | | float* workmem = 0; |
1332 | | #if FOR_IS_PARALLEL |
1333 | | // If we do parallel for, we need to allocate input reshaping for each block. |
1334 | | workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY); |
1335 | | #else |
1336 | | // Otherwise, just one block. |
1337 | | workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY); |
1338 | | #endif |
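// Workspace layout: the first 36 * dimCx4 * w->info.dim[0] floats hold the
// transformed weights (gwtg); the remaining 36 * dimCx4 floats per block (one
// block per parallel iteration when FOR_IS_PARALLEL, otherwise a single block)
// hold the transformed input tile (btdb).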
1339 | | if (!workmem) |
1340 | | return CCV_NNC_EXEC_OOM; |
1341 | | // Convert w to a 6x6 matrix by computing G.w.T(G) // T for transpose.
1342 | | float* const gwtg = workmem; |
1343 | | float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0]; |
1344 | | memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]); |
1345 | | _ccv_nnc_winograd_4x4_3x3_gwtg_neon(w->data.f32, w->info.dim, gwtg); |
1346 | | // kernel weight for one dim. |
1347 | | // Workaround issues of dispatch_apply (cannot reference an on-stack array)
1348 | | const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = { |
1349 | | w->info.dim[0], 6, 6, w->info.dim[3] |
1350 | | }; |
1351 | | const int* const tile_dim = tile_dim_s; |
1352 | | if (bias) |
1353 | | { |
1354 | | const float* const biasval = bias->data.f32; |
1355 | | // This block will be called in each for-loop; therefore, you can use it to generate some temporary variables.
1356 | | parallel_for(i, jump_dim) { |
1357 | | const int y = i * 4; // i is unsigned. |
1358 | | int j, x, k, c; |
1359 | | int n[CCV_NNC_MAX_DIM]; |
1360 | | int m[CCV_NNC_MAX_DIM]; |
1361 | | int z[CCV_NNC_MAX_DIM]; |
1362 | | set_n_m_dim(y, 0, tile_dim, adim); |
1363 | | z[0] = ccv_min(y + 4, bdim[0]) - y; |
1364 | | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
1365 | | float* bp = b->data.f32 + y * bstride[1]; |
1366 | | for (x = 0; x < bdim[1]; x += 4) |
1367 | | { |
1368 | | set_n_m_dim(x, 1, tile_dim, adim); |
1369 | | z[1] = ccv_min(x + 4, bdim[1]) - x; |
1370 | | #if FOR_IS_PARALLEL |
1371 | | float* g = btdb + i * 36 * dimCx4; |
1372 | | #else |
1373 | | float* g = btdb; |
1374 | | #endif |
1375 | | // zero g such that we can have zero-padding. |
1376 | | memset(g, 0, sizeof(float) * 36 * dimCx4); |
1377 | | int dx, dy; |
1378 | | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
1379 | | float* gz = g + (n[0] * 6 + n[1]) * dimCx4; |
1380 | | unroll_for(dy, m[0], 6) { |
1381 | | unroll_for(dx, m[1], 6) { |
1382 | | float* const gzu = gz + (dy * 6 + dx) * dimCx4; |
1383 | | for (c = 0; c < adim[2]; c++) |
1384 | | gzu[c] = apz[dx * astride[2] + c]; |
1385 | | } unroll_endfor |
1386 | | apz += astride[1]; |
1387 | | } unroll_endfor |
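// g now holds the 6x6 input tile in channel-interleaved (dimCx4) layout; the
// rows/columns not covered by n[]/m[] keep the zeros from the memset above, which
// provides the implicit border padding.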
1388 | | for (c = 0; c < adim[2]; c += 4) |
1389 | | { |
1390 | | float d[36 * 4] __attribute__ ((__aligned__(16))); |
1391 | | /* BT.d */ |
1392 | | unroll_for(j, 6) { |
1393 | | /* row 1 */ |
1394 | | const float* const gz = g + j * dimCx4; |
1395 | | float* dz = d + j * 4; |
1396 | | float32x4_t g0 = vld1q_f32(gz); |
1397 | | float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4); |
1398 | | float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4); |
1399 | | float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4); |
1400 | | g0 = vaddq_f32(g0, g0); |
1401 | | g0 = vaddq_f32(g0, g0); |
1402 | | float32x4_t g12x2 = vaddq_f32(g12, g12); |
1403 | | g12x2 = vaddq_f32(g12x2, g12x2); |
1404 | | g12x2 = vaddq_f32(g12x2, g12); |
1405 | | vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2)); |
1406 | | /* row 2 */ |
1407 | | float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4); |
1408 | | float32x4_t g6x12 = vaddq_f32(g6, g12); |
1409 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1410 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1411 | | vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12)); |
1412 | | /* row 3 */ |
1413 | | g6x12 = vsubq_f32(g6, g12); |
1414 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1415 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1416 | | vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12)); |
1417 | | /* row 4 */ |
1418 | | float32x4_t g18x6 = vsubq_f32(g18, g6); |
1419 | | g18x6 = vaddq_f32(g18x6, g18x6); |
1420 | | vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6)); |
1421 | | /* row 5 */ |
1422 | | vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6)); |
1423 | | /* row 6 */ |
1424 | | float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4); |
1425 | | float32x4_t g18x2 = vaddq_f32(g18, g18); |
1426 | | g18x2 = vaddq_f32(g18x2, g18x2); |
1427 | | g18x2 = vaddq_f32(g18, g18x2); |
1428 | | g6 = vaddq_f32(g6, g6); |
1429 | | g6 = vaddq_f32(g6, g6); |
1430 | | vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2)); |
1431 | | } unroll_endfor |
1432 | | /* BT.d.B */ |
1433 | | unroll_for(j, 6) { |
1434 | | float* gz = g + j * 6 * dimCx4; |
1435 | | const float* const dz = d + j * 24; |
1436 | | float32x4_t d0 = vld1q_f32(dz); |
1437 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1438 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1439 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1440 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1441 | | float32x4_t d5 = vld1q_f32(dz + 20); |
1442 | | d0 = vaddq_f32(d0, d0); |
1443 | | d0 = vaddq_f32(d0, d0); |
1444 | | float32x4_t d2x5 = vaddq_f32(d2, d2); |
1445 | | d2x5 = vaddq_f32(d2x5, d2x5); |
1446 | | d2x5 = vaddq_f32(d2, d2x5); |
1447 | | vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5)); |
1448 | | float32x4_t d1x2 = vaddq_f32(d1, d2); |
1449 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1450 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1451 | | vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2)); |
1452 | | d1x2 = vsubq_f32(d1, d2); |
1453 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1454 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1455 | | vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2)); |
1456 | | float32x4_t d3x1 = vsubq_f32(d3, d1); |
1457 | | d3x1 = vaddq_f32(d3x1, d3x1); |
1458 | | vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1)); |
1459 | | vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1)); |
1460 | | d1 = vaddq_f32(d1, d1); |
1461 | | d1 = vaddq_f32(d1, d1); |
1462 | | float32x4_t d3x5 = vaddq_f32(d3, d3); |
1463 | | d3x5 = vaddq_f32(d3x5, d3x5); |
1464 | | d3x5 = vaddq_f32(d3, d3x5); |
1465 | | vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5)); |
1466 | | } unroll_endfor |
1467 | | // move to the next channel |
1468 | | g += 4; |
1469 | | } |
1470 | | const float* wpz = gwtg; |
1471 | | for (k = 0; k < w->info.dim[0]; k += 4) |
1472 | | { |
1473 | | float q[36 * 4] __attribute__ ((__aligned__(16))); |
1474 | | #if FOR_IS_PARALLEL |
1475 | | g = btdb + i * 36 * dimCx4; |
1476 | | #else |
1477 | | g = btdb; |
1478 | | #endif |
1479 | | for (j = 0; j < 36; j++) |
1480 | | { |
1481 | | float32x4_t v40 = vmovq_n_f32(0); |
1482 | | float32x4_t v41 = vmovq_n_f32(0); |
1483 | | float32x4_t v42 = vmovq_n_f32(0); |
1484 | | float32x4_t v43 = vmovq_n_f32(0); |
1485 | | for (c = 0; c < adim[2]; c += 4) |
1486 | | { |
1487 | | float32x2x2_t g4 = vld2_f32(g); |
1488 | | float32x4_t w40 = vld1q_f32(wpz); |
1489 | | float32x4_t w41 = vld1q_f32(wpz + 4); |
1490 | | float32x4_t w42 = vld1q_f32(wpz + 8); |
1491 | | float32x4_t w43 = vld1q_f32(wpz + 12); |
1492 | | float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0); |
1493 | | float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0); |
1494 | | float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1); |
1495 | | float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1); |
1496 | | v40 = vmlaq_f32(v40, w40, g40); |
1497 | | v41 = vmlaq_f32(v41, w41, g41); |
1498 | | v42 = vmlaq_f32(v42, w42, g42); |
1499 | | v43 = vmlaq_f32(v43, w43, g43); |
1500 | | g += 4; |
1501 | | wpz += 16; |
1502 | | } |
1503 | | v40 = vaddq_f32(v40, v41); |
1504 | | v42 = vaddq_f32(v42, v43); |
1505 | | vst1q_f32(q + j * 4, vaddq_f32(v40, v42)); |
1506 | | } |
1507 | | float d[24 * 4] __attribute__ ((__aligned__(16))); |
1508 | | unroll_for(j, 6) { |
1509 | | const float* const qz = q + j * 4; |
1510 | | float* const dz = d + j * 4; |
1511 | | float32x4_t q0 = vld1q_f32(qz); |
1512 | | float32x4_t q6 = vld1q_f32(qz + 24); |
1513 | | float32x4_t q12 = vld1q_f32(qz + 48); |
1514 | | float32x4_t q18 = vld1q_f32(qz + 72); |
1515 | | float32x4_t q24 = vld1q_f32(qz + 96); |
1516 | | float32x4_t qs6x12 = vaddq_f32(q6, q12); |
1517 | | float32x4_t qs18x24 = vaddq_f32(q18, q24); |
1518 | | float32x4_t qss = vaddq_f32(qs6x12, q0); |
1519 | | /* row 1 */ |
1520 | | vst1q_f32(dz, vaddq_f32(qss, qs18x24)); |
1521 | | float32x4_t qn6x12 = vsubq_f32(q6, q12); |
1522 | | float32x4_t qn18x24 = vsubq_f32(q18, q24); |
1523 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1524 | | /* row 2 */ |
1525 | | vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24)); |
1526 | | qs18x24 = vaddq_f32(qs18x24, qs18x24); |
1527 | | qs18x24 = vaddq_f32(qs18x24, qs18x24); |
1528 | | /* row 3 */ |
1529 | | vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24)); |
1530 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1531 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1532 | | float32x4_t q30 = vld1q_f32(qz + 120); |
1533 | | /* row 4 */ |
1534 | | vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24)); |
1535 | | } unroll_endfor |
1536 | | float* bpz = bp + x * bstride[2] + k; |
1537 | | float32x4_t bias4 = vld1q_f32(biasval + k); |
1538 | | switch (z[1]) { |
1539 | | case 1: |
1540 | | unroll_for(dy, z[0], 4) { |
1541 | | const float* const dz = d + dy * 24; |
1542 | | float32x4_t d0 = vld1q_f32(dz); |
1543 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1544 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1545 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1546 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1547 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1548 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1549 | | ds1x2 = vaddq_f32(ds1x2, bias4); |
1550 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1551 | | bpz += bstride[1]; |
1552 | | } unroll_endfor |
1553 | | break; |
1554 | | case 2: |
1555 | | unroll_for(dy, z[0], 4) { |
1556 | | const float* const dz = d + dy * 24; |
1557 | | float32x4_t d0 = vld1q_f32(dz); |
1558 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1559 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1560 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1561 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1562 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1563 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1564 | | ds1x2 = vaddq_f32(ds1x2, bias4); |
1565 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1566 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1567 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1568 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1569 | | dn1x2 = vaddq_f32(dn1x2, bias4); |
1570 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1571 | | bpz += bstride[1]; |
1572 | | } unroll_endfor |
1573 | | break; |
1574 | | case 3: |
1575 | | unroll_for(dy, z[0], 4) { |
1576 | | const float* const dz = d + dy * 24; |
1577 | | float32x4_t d0 = vld1q_f32(dz); |
1578 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1579 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1580 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1581 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1582 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1583 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1584 | | ds1x2 = vaddq_f32(ds1x2, bias4); |
1585 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1586 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1587 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1588 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1589 | | dn1x2 = vaddq_f32(dn1x2, bias4); |
1590 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1591 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1592 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1593 | | vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4)); |
1594 | | bpz += bstride[1]; |
1595 | | } unroll_endfor |
1596 | | break; |
1597 | | case 4: |
1598 | | unroll_for(dy, z[0], 4) { |
1599 | | const float* const dz = d + dy * 24; |
1600 | | float32x4_t d0 = vld1q_f32(dz); |
1601 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1602 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1603 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1604 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1605 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1606 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1607 | | ds1x2 = vaddq_f32(ds1x2, bias4); |
1608 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1609 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1610 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1611 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1612 | | dn1x2 = vaddq_f32(dn1x2, bias4); |
1613 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1614 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1615 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1616 | | vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4)); |
1617 | | float32x4_t d5 = vld1q_f32(dz + 20); |
1618 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1619 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1620 | | vst1q_f32(bpz + 3 * bstride[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4)); |
1621 | | bpz += bstride[1]; |
1622 | | } unroll_endfor |
1623 | | break; |
1624 | | }; |
1625 | | } |
1626 | | } |
1627 | | } parallel_endfor |
1628 | | } else { |
1629 | | // This block will be called in each for-loop; therefore, you can use it to generate some temporary variables.
1630 | | parallel_for(i, jump_dim) { |
1631 | | const int y = i * 4; // i is unsigned. |
1632 | | int j, x, k, c; |
1633 | | int n[CCV_NNC_MAX_DIM]; |
1634 | | int m[CCV_NNC_MAX_DIM]; |
1635 | | int z[CCV_NNC_MAX_DIM]; |
1636 | | set_n_m_dim(y, 0, tile_dim, adim); |
1637 | | z[0] = ccv_min(y + 4, bdim[0]) - y; |
1638 | | const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1]; |
1639 | | float* bp = b->data.f32 + y * bstride[1]; |
1640 | | for (x = 0; x < bdim[1]; x += 4) |
1641 | | { |
1642 | | set_n_m_dim(x, 1, tile_dim, adim); |
1643 | | z[1] = ccv_min(x + 4, bdim[1]) - x; |
1644 | | #if FOR_IS_PARALLEL |
1645 | | float* g = btdb + i * 36 * dimCx4; |
1646 | | #else |
1647 | | float* g = btdb; |
1648 | | #endif |
1649 | | // zero g such that we can have zero-padding. |
1650 | | memset(g, 0, sizeof(float) * 36 * dimCx4); |
1651 | | int dx, dy; |
1652 | | const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2]; |
1653 | | float* gz = g + (n[0] * 6 + n[1]) * dimCx4; |
1654 | | unroll_for(dy, m[0], 6) { |
1655 | | unroll_for(dx, m[1], 6) { |
1656 | | float* const gzu = gz + (dy * 6 + dx) * dimCx4; |
1657 | | for (c = 0; c < adim[2]; c++) |
1658 | | gzu[c] = apz[dx * astride[2] + c]; |
1659 | | } unroll_endfor |
1660 | | apz += astride[1]; |
1661 | | } unroll_endfor |
1662 | | for (c = 0; c < adim[2]; c += 4) |
1663 | | { |
1664 | | float d[36 * 4] __attribute__ ((__aligned__(16))); |
1665 | | /* BT.d */ |
1666 | | unroll_for(j, 6) { |
1667 | | /* row 1 */ |
1668 | | const float* const gz = g + j * dimCx4; |
1669 | | float* dz = d + j * 4; |
1670 | | float32x4_t g0 = vld1q_f32(gz); |
1671 | | float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4); |
1672 | | float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4); |
1673 | | float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4); |
1674 | | g0 = vaddq_f32(g0, g0); |
1675 | | g0 = vaddq_f32(g0, g0); |
1676 | | float32x4_t g12x2 = vaddq_f32(g12, g12); |
1677 | | g12x2 = vaddq_f32(g12x2, g12x2); |
1678 | | g12x2 = vaddq_f32(g12x2, g12); |
1679 | | vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2)); |
1680 | | /* row 2 */ |
1681 | | float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4); |
1682 | | float32x4_t g6x12 = vaddq_f32(g6, g12); |
1683 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1684 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1685 | | vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12)); |
1686 | | /* row 3 */ |
1687 | | g6x12 = vsubq_f32(g6, g12); |
1688 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1689 | | g6x12 = vaddq_f32(g6x12, g6x12); |
1690 | | vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12)); |
1691 | | /* row 4 */ |
1692 | | float32x4_t g18x6 = vsubq_f32(g18, g6); |
1693 | | g18x6 = vaddq_f32(g18x6, g18x6); |
1694 | | vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6)); |
1695 | | /* row 5 */ |
1696 | | vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6)); |
1697 | | /* row 6 */ |
1698 | | float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4); |
1699 | | float32x4_t g18x2 = vaddq_f32(g18, g18); |
1700 | | g18x2 = vaddq_f32(g18x2, g18x2); |
1701 | | g18x2 = vaddq_f32(g18, g18x2); |
1702 | | g6 = vaddq_f32(g6, g6); |
1703 | | g6 = vaddq_f32(g6, g6); |
1704 | | vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2)); |
1705 | | } unroll_endfor |
1706 | | /* BT.d.B */ |
1707 | | unroll_for(j, 6) { |
1708 | | float* gz = g + j * 6 * dimCx4; |
1709 | | const float* const dz = d + j * 24; |
1710 | | float32x4_t d0 = vld1q_f32(dz); |
1711 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1712 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1713 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1714 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1715 | | float32x4_t d5 = vld1q_f32(dz + 20); |
1716 | | d0 = vaddq_f32(d0, d0); |
1717 | | d0 = vaddq_f32(d0, d0); |
1718 | | float32x4_t d2x5 = vaddq_f32(d2, d2); |
1719 | | d2x5 = vaddq_f32(d2x5, d2x5); |
1720 | | d2x5 = vaddq_f32(d2, d2x5); |
1721 | | vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5)); |
1722 | | float32x4_t d1x2 = vaddq_f32(d1, d2); |
1723 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1724 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1725 | | vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2)); |
1726 | | d1x2 = vsubq_f32(d1, d2); |
1727 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1728 | | d1x2 = vaddq_f32(d1x2, d1x2); |
1729 | | vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2)); |
1730 | | float32x4_t d3x1 = vsubq_f32(d3, d1); |
1731 | | d3x1 = vaddq_f32(d3x1, d3x1); |
1732 | | vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1)); |
1733 | | vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1)); |
1734 | | d1 = vaddq_f32(d1, d1); |
1735 | | d1 = vaddq_f32(d1, d1); |
1736 | | float32x4_t d3x5 = vaddq_f32(d3, d3); |
1737 | | d3x5 = vaddq_f32(d3x5, d3x5); |
1738 | | d3x5 = vaddq_f32(d3, d3x5); |
1739 | | vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5)); |
1740 | | } unroll_endfor |
1741 | | // move to the next channel |
1742 | | g += 4; |
1743 | | } |
1744 | | const float* wpz = gwtg; |
1745 | | for (k = 0; k < w->info.dim[0]; k += 4) |
1746 | | { |
1747 | | float q[36 * 4] __attribute__ ((__aligned__(16))); |
1748 | | #if FOR_IS_PARALLEL |
1749 | | g = btdb + i * 36 * dimCx4; |
1750 | | #else |
1751 | | g = btdb; |
1752 | | #endif |
1753 | | for (j = 0; j < 36; j++) |
1754 | | { |
1755 | | float32x4_t v40 = vmovq_n_f32(0); |
1756 | | float32x4_t v41 = vmovq_n_f32(0); |
1757 | | float32x4_t v42 = vmovq_n_f32(0); |
1758 | | float32x4_t v43 = vmovq_n_f32(0); |
1759 | | for (c = 0; c < adim[2]; c += 4) |
1760 | | { |
1761 | | float32x2x2_t g4 = vld2_f32(g); |
1762 | | float32x4_t w40 = vld1q_f32(wpz); |
1763 | | float32x4_t w41 = vld1q_f32(wpz + 4); |
1764 | | float32x4_t w42 = vld1q_f32(wpz + 8); |
1765 | | float32x4_t w43 = vld1q_f32(wpz + 12); |
1766 | | float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0); |
1767 | | float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0); |
1768 | | float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1); |
1769 | | float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1); |
1770 | | v40 = vmlaq_f32(v40, w40, g40); |
1771 | | v41 = vmlaq_f32(v41, w41, g41); |
1772 | | v42 = vmlaq_f32(v42, w42, g42); |
1773 | | v43 = vmlaq_f32(v43, w43, g43); |
1774 | | g += 4; |
1775 | | wpz += 16; |
1776 | | } |
1777 | | v40 = vaddq_f32(v40, v41); |
1778 | | v42 = vaddq_f32(v42, v43); |
1779 | | vst1q_f32(q + j * 4, vaddq_f32(v40, v42)); |
1780 | | } |
1781 | | float d[24 * 4] __attribute__ ((__aligned__(16))); |
1782 | | unroll_for(j, 6) { |
1783 | | const float* const qz = q + j * 4; |
1784 | | float* const dz = d + j * 4; |
1785 | | float32x4_t q0 = vld1q_f32(qz); |
1786 | | float32x4_t q6 = vld1q_f32(qz + 24); |
1787 | | float32x4_t q12 = vld1q_f32(qz + 48); |
1788 | | float32x4_t q18 = vld1q_f32(qz + 72); |
1789 | | float32x4_t q24 = vld1q_f32(qz + 96); |
1790 | | float32x4_t qs6x12 = vaddq_f32(q6, q12); |
1791 | | float32x4_t qs18x24 = vaddq_f32(q18, q24); |
1792 | | float32x4_t qss = vaddq_f32(qs6x12, q0); |
1793 | | /* row 1 */ |
1794 | | vst1q_f32(dz, vaddq_f32(qss, qs18x24)); |
1795 | | float32x4_t qn6x12 = vsubq_f32(q6, q12); |
1796 | | float32x4_t qn18x24 = vsubq_f32(q18, q24); |
1797 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1798 | | /* row 2 */ |
1799 | | vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24)); |
1800 | | qs18x24 = vaddq_f32(qs18x24, qs18x24); |
1801 | | qs18x24 = vaddq_f32(qs18x24, qs18x24); |
1802 | | /* row 3 */ |
1803 | | vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24)); |
1804 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1805 | | qn18x24 = vaddq_f32(qn18x24, qn18x24); |
1806 | | float32x4_t q30 = vld1q_f32(qz + 120); |
1807 | | /* row 4 */ |
1808 | | vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24)); |
1809 | | } unroll_endfor |
1810 | | float* bpz = bp + x * bstride[2] + k; |
1811 | | switch (z[1]) { |
1812 | | case 1: |
1813 | | unroll_for(dy, z[0], 4) { |
1814 | | const float* const dz = d + dy * 24; |
1815 | | float32x4_t d0 = vld1q_f32(dz); |
1816 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1817 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1818 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1819 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1820 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1821 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1822 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1823 | | bpz += bstride[1]; |
1824 | | } unroll_endfor |
1825 | | break; |
1826 | | case 2: |
1827 | | unroll_for(dy, z[0], 4) { |
1828 | | const float* const dz = d + dy * 24; |
1829 | | float32x4_t d0 = vld1q_f32(dz); |
1830 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1831 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1832 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1833 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1834 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1835 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1836 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1837 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1838 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1839 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1840 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1841 | | bpz += bstride[1]; |
1842 | | } unroll_endfor |
1843 | | break; |
1844 | | case 3: |
1845 | | unroll_for(dy, z[0], 4) { |
1846 | | const float* const dz = d + dy * 24; |
1847 | | float32x4_t d0 = vld1q_f32(dz); |
1848 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1849 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1850 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1851 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1852 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1853 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1854 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1855 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1856 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1857 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1858 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1859 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1860 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1861 | | vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4)); |
1862 | | bpz += bstride[1]; |
1863 | | } unroll_endfor |
1864 | | break; |
1865 | | case 4: |
1866 | | unroll_for(dy, z[0], 4) { |
1867 | | const float* const dz = d + dy * 24; |
1868 | | float32x4_t d0 = vld1q_f32(dz); |
1869 | | float32x4_t d1 = vld1q_f32(dz + 4); |
1870 | | float32x4_t d2 = vld1q_f32(dz + 8); |
1871 | | float32x4_t d3 = vld1q_f32(dz + 12); |
1872 | | float32x4_t d4 = vld1q_f32(dz + 16); |
1873 | | float32x4_t ds1x2 = vaddq_f32(d1, d2); |
1874 | | float32x4_t ds3x4 = vaddq_f32(d3, d4); |
1875 | | vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4))); |
1876 | | float32x4_t dn1x2 = vsubq_f32(d1, d2); |
1877 | | float32x4_t dn3x4 = vsubq_f32(d3, d4); |
1878 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1879 | | vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4)); |
1880 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1881 | | ds3x4 = vaddq_f32(ds3x4, ds3x4); |
1882 | | vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4)); |
1883 | | float32x4_t d5 = vld1q_f32(dz + 20); |
1884 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1885 | | dn3x4 = vaddq_f32(dn3x4, dn3x4); |
1886 | | vst1q_f32(bpz + 3 * bstride[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4)); |
1887 | | bpz += bstride[1]; |
1888 | | } unroll_endfor |
1889 | | break; |
1890 | | }; |
1891 | | } |
1892 | | } |
1893 | | } parallel_endfor |
1894 | | } |
1895 | | return CCV_NNC_EXEC_SUCCESS; |
1896 | | } |
1897 | | #endif |
1898 | | |
1899 | | int _ccv_nnc_conv_forw_4x4_3x3_winograd_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context) |
1900 | 119 | { |
1901 | 119 | #if defined(HAVE_SSE2) |
1902 | 119 | if (w->info.dim[0] % 4 == 0) |
1903 | 119 | return _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(a, w, bias, hint, b, stream_context); |
1904 | | #elif defined(HAVE_NEON) |
1905 | | if (w->info.dim[0] % 4 == 0) |
1906 | | return _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(a, w, bias, hint, b, stream_context); |
1907 | | #endif |
1908 | 0 | return _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(a, w, bias, hint, b, stream_context); |
1909 | 119 | } |