Coverage Report

Created: 2025-02-24 17:43

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#if defined(HAVE_SSE2)
7
#include <xmmintrin.h>
8
#elif defined(HAVE_NEON)
9
#include <arm_neon.h>
10
#endif
11
#ifdef USE_OPENMP
12
#include <omp.h>
13
#endif
14
#ifdef USE_DISPATCH
15
#include <dispatch/dispatch.h>
16
#endif
17
#include "../_ccv_nnc_conv_cpu_opt.h"
18
19
#define set_n_m_dim(i, x, wd, ad) \
20
66.0k
  do { \
21
66.0k
    n[x] = ccv_max((i) * hint.stride.dim[x] - hint.border.begin[x], 0) - ((i) * hint.stride.dim[x] - hint.border.begin[x]); \
22
66.0k
    m[x] = wd[x + 1] - n[x] - ((i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1] - ccv_min(ad[x], (i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1])); \
23
66.0k
  } while (0)
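(Illustrative example, not taken from this run: with stride 1, hint.border.begin[0] = 1, and an input height adim[0] = 32, the tile anchored at output row y = 0 yields n[0] = max(-1, 0) - (-1) = 1 and m[0] = 6 - 1 - 0 = 5, i.e. one zero-padded row at the top of the 6-row input window and five rows copied from the input; the last tile at y = 28 yields n[0] = 0 and m[0] = 5, padding one row at the bottom; interior tiles get n[x] = 0 and m[x] = 6.)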
24
25
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_ref(const float* const w, const int c, float* gwtg)
26
0
{
27
0
  int i;
28
0
  for (i = 0; i < c; i++)
29
0
  {
30
0
    float g[18];
31
    /*
32
     * a0, b1, c2
33
     * d3, e4, f5
34
     * g6, h7, i8
35
     * {{a/4, b/4, c/4},
36
     * {1/6 (-a - d - g), 1/6 (-b - e - h), 1/6 (-c - f - i)},
37
     * {1/6 (-a + d - g), 1/6 (-b + e - h), 1/6 (-c + f - i)},
38
     * {1/24 (a + 2 d + 4 g), 1/24 (b + 2 e + 4 h), 1/24 (c + 2 f + 4 i)},
39
     * {1/24 (a - 2 d + 4 g), 1/24 (b - 2 e + 4 h), 1/24 (c - 2 f + 4 i)},
40
     * {g, h, i}}
41
     */
42
    /* row 1 */
43
0
    g[0] = w[i] / 4;
44
0
    g[1] = w[c + i] / 4;
45
0
    g[2] = w[2 * c + i] / 4;
46
    /* row 2 */
47
0
    g[3] = -(w[i] + w[3 * c + i] + w[6 * c + i]) / 6;
48
0
    g[4] = -(w[c + i] + w[4 * c + i] + w[7 * c + i]) / 6;
49
0
    g[5] = -(w[2 * c + i] + w[5 * c + i] + w[8 * c + i]) / 6;
50
    /* row 3 */
51
0
    g[6] = (-w[i] + w[3 * c + i] - w[6 * c + i]) / 6;
52
0
    g[7] = (-w[c + i] + w[4 * c + i] - w[7 * c + i]) / 6;
53
0
    g[8] = (-w[2 * c + i] + w[5 * c + i] - w[8 * c + i]) / 6;
54
    /* row 4 */
55
0
    g[9] = (w[i] + 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
56
0
    g[10] = (w[c + i] + 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
57
0
    g[11] = (w[2 * c + i] + 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
58
    /* row 5 */
59
0
    g[12] = (w[i] - 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
60
0
    g[13] = (w[c + i] - 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
61
0
    g[14] = (w[2 * c + i] - 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
62
    /* row 6 */
63
0
    g[15] = w[6 * c + i];
64
0
    g[16] = w[7 * c + i];
65
0
    g[17] = w[8 * c + i];
66
    /*
67
     * a0, b1, c2
68
     * d3, e4, f5
69
     * g6, h7, i8
70
     * j9, k10,l11
71
     * m12,n13,o14
72
     * p15,q16,r17
73
     * {{a/4, 1/6 (-a - b - c), 1/6 (-a + b - c), 1/24 (a + 2 b + 4 c), 1/24 (a - 2 b + 4 c), c},
74
     * {d/4, 1/6 (-d - e - f), 1/6 (-d + e - f), 1/24 (d + 2 e + 4 f), 1/24 (d - 2 e + 4 f), f},
75
     * {g/4, 1/6 (-g - h - i), 1/6 (-g + h - i), 1/24 (g + 2 h + 4 i), 1/24 (g - 2 h + 4 i), i},
76
     * {j/4, 1/6 (-j - k - l), 1/6 (-j + k - l), 1/24 (j + 2 k + 4 l), 1/24 (j - 2 k + 4 l), l},
77
     * {m/4, 1/6 (-m - n - o), 1/6 (-m + n - o), 1/24 (m + 2 n + 4 o), 1/24 (m - 2 n + 4 o), o},
78
     * {p/4, 1/6 (-p - q - r), 1/6 (-p + q - r), 1/24 (p + 2 q + 4 r), 1/24 (p - 2 q + 4 r), r}}
79
     */
80
    /* row 1 */
81
0
    gwtg[0] = g[0] / 4;
82
0
    gwtg[c] = -(g[0] + g[1] + g[2]) / 6;
83
0
    gwtg[2 * c] = (-g[0] + g[1] - g[2]) / 6;
84
0
    gwtg[3 * c] = (g[0] + 2 * g[1] + 4 * g[2]) / 24;
85
0
    gwtg[4 * c] = (g[0] - 2 * g[1] + 4 * g[2]) / 24;
86
0
    gwtg[5 * c] = g[2];
87
    /* row 2 */
88
0
    gwtg[6 * c] = g[3] / 4;
89
0
    gwtg[7 * c] = -(g[3] + g[4] + g[5]) / 6;
90
0
    gwtg[8 * c] = (-g[3] + g[4] - g[5]) / 6;
91
0
    gwtg[9 * c] = (g[3] + 2 * g[4] + 4 * g[5]) / 24;
92
0
    gwtg[10 * c] = (g[3] - 2 * g[4] + 4 * g[5]) / 24;
93
0
    gwtg[11 * c] = g[5];
94
    /* row 3 */
95
0
    gwtg[12 * c] = g[6] / 4;
96
0
    gwtg[13 * c] = -(g[6] + g[7] + g[8]) / 6;
97
0
    gwtg[14 * c] = (-g[6] + g[7] - g[8]) / 6;
98
0
    gwtg[15 * c] = (g[6] + 2 * g[7] + 4 * g[8]) / 24;
99
0
    gwtg[16 * c] = (g[6] - 2 * g[7] + 4 * g[8]) / 24;
100
0
    gwtg[17 * c] = g[8];
101
    /* row 4 */
102
0
    gwtg[18 * c] = g[9] / 4;
103
0
    gwtg[19 * c] = -(g[9] + g[10] + g[11]) / 6;
104
0
    gwtg[20 * c] = (-g[9] + g[10] - g[11]) / 6;
105
0
    gwtg[21 * c] = (g[9] + 2 * g[10] + 4 * g[11]) / 24;
106
0
    gwtg[22 * c] = (g[9] - 2 * g[10] + 4 * g[11]) / 24;
107
0
    gwtg[23 * c] = g[11];
108
    /* row 5 */
109
0
    gwtg[24 * c] = g[12] / 4;
110
0
    gwtg[25 * c] = -(g[12] + g[13] + g[14]) / 6;
111
0
    gwtg[26 * c] = (-g[12] + g[13] - g[14]) / 6;
112
0
    gwtg[27 * c] = (g[12] + 2 * g[13] + 4 * g[14]) / 24;
113
0
    gwtg[28 * c] = (g[12] - 2 * g[13] + 4 * g[14]) / 24;
114
0
    gwtg[29 * c] = g[14];
115
    /* row 6 */
116
0
    gwtg[30 * c] = g[15] / 4;
117
0
    gwtg[31 * c] = -(g[15] + g[16] + g[17]) / 6;
118
0
    gwtg[32 * c] = (-g[15] + g[16] - g[17]) / 6;
119
0
    gwtg[33 * c] = (g[15] + 2 * g[16] + 4 * g[17]) / 24;
120
0
    gwtg[34 * c] = (g[15] - 2 * g[16] + 4 * g[17]) / 24;
121
0
    gwtg[35 * c] = g[17];
122
0
    ++gwtg;
123
0
  }
124
0
}
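For reference, and not part of the measured file: the unrolled arithmetic above is the kernel transform G.w.T(G) of Winograd F(4x4,3x3). A minimal single-channel sketch with the 6x3 G matrix written out explicitly (it ignores the channel-interleaved gwtg layout the function above uses):

static void winograd_gwtg_single(const float w[9], float gwtg[36])
{
  /* Rows of G reproduce the /4, /6 and /24 coefficients above. */
  static const float G[6][3] = {
    { 1.f / 4, 0, 0 },
    { -1.f / 6, -1.f / 6, -1.f / 6 },
    { -1.f / 6, 1.f / 6, -1.f / 6 },
    { 1.f / 24, 2.f / 24, 4.f / 24 },
    { 1.f / 24, -2.f / 24, 4.f / 24 },
    { 0, 0, 1 },
  };
  float gw[18]; /* G.w (6x3) */
  int i, j, k;
  for (i = 0; i < 6; i++)
    for (j = 0; j < 3; j++)
    {
      float s = 0;
      for (k = 0; k < 3; k++)
        s += G[i][k] * w[k * 3 + j];
      gw[i * 3 + j] = s;
    }
  for (i = 0; i < 6; i++) /* (G.w).T(G) (6x6) */
    for (j = 0; j < 6; j++)
    {
      float s = 0;
      for (k = 0; k < 3; k++)
        s += gw[i * 3 + k] * G[j][k];
      gwtg[i * 6 + j] = s;
    }
}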
125
126
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
127
0
{
128
0
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
129
0
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
130
0
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
131
0
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
132
0
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
133
0
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
134
0
  int astride[CCV_NNC_MAX_DIM_ALLOC];
135
0
  ccv_nnc_tensor_view_get_stride(a, astride);
136
0
  int bstride[CCV_NNC_MAX_DIM_ALLOC];
137
0
  ccv_nnc_tensor_view_get_stride(b, bstride);
138
0
  assert(hint.border.begin[0] <= 1);
139
0
  assert(hint.border.begin[1] <= 1);
140
0
  assert(w->info.dim[1] == 3);
141
0
  assert(w->info.dim[2] == 3);
142
0
  const int jump_dim = (bdim[0] + 3) / 4;
143
0
  float* workmem;
144
  // allocating workspace memory for kernel reshaping and input reshaping.
145
#if FOR_IS_PARALLEL
146
  // If we do parallel for, we need to allocate input reshaping for each block.
147
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] * jump_dim + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
148
#else
149
  // Otherwise, just one block.
150
0
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
151
0
#endif
152
0
  if (!workmem)
153
0
    return CCV_NNC_EXEC_OOM;
154
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
155
0
  float* const gwtg = workmem;
156
0
  float* const btdb = workmem + 36 * w->info.dim[0] * w->info.dim[3];
157
0
  parallel_for(k, w->info.dim[0]) {
158
0
    _ccv_nnc_winograd_4x4_3x3_gwtg_ref(w->data.f32 + k * w->info.dim[3] * w->info.dim[2] * w->info.dim[1], w->info.dim[3], gwtg + k * 36 * w->info.dim[3]);
159
0
  } parallel_endfor
160
  // kernel weight for one dim.
161
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
162
0
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
163
0
    w->info.dim[0], 6, 6, w->info.dim[3]
164
0
  };
165
0
  const int* const tile_dim = tile_dim_s;
166
  // This block will be called in each for-loop iteration; therefore, you can use it to generate some temporary variables.
167
0
  if (bias)
168
0
  {
169
0
    const float* const biasval = bias->data.f32;
170
0
    parallel_for(i, jump_dim) {
171
0
      const int y = i * 4; // i is unsigned.
172
0
      int j, x, k, c;
173
0
      int n[CCV_NNC_MAX_DIM];
174
0
      int m[CCV_NNC_MAX_DIM];
175
0
      int z[CCV_NNC_MAX_DIM];
176
0
      set_n_m_dim(y, 0, tile_dim, adim);
177
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
178
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
179
0
      float* bp = b->data.f32 + y * bstride[1];
180
0
      for (x = 0; x < bdim[1]; x += 4)
181
0
      {
182
0
        set_n_m_dim(x, 1, tile_dim, adim);
183
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
184
#if FOR_IS_PARALLEL
185
        float* g = btdb + i * 36 * adim[2];
186
#else
187
0
        float* g = btdb;
188
0
#endif
189
        // zero g such that we can have zero-padding.
190
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
191
0
        int dx, dy;
192
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
193
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
194
0
        unroll_for(dy, m[0], 6) {
195
0
          unroll_for(dx, m[1], 6) {
196
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
197
0
            for (c = 0; c < adim[2]; c++)
198
0
              gzu[c] = apz[dx * astride[2] + c];
199
0
          } unroll_endfor
200
0
          apz += astride[1];
201
0
        } unroll_endfor
202
0
        for (c = 0; c < adim[2]; c++)
203
0
        {
204
          /*
205
           * a0, a1, a2, a3, a4, a5,
206
           * b6, b7, b8, b9, b10,b11,
207
           * c12,c13,c14,c15,c16,c17,
208
           * d18,d19,d20,d21,d22,d23,
209
           * e24,e25,e26,e27,e28,e29,
210
           * f30,f31,f32,f33,f34,f35
211
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
212
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
213
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
214
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
215
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
216
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
217
           */
218
0
          float d[36];
219
          /* BT.d */
220
0
          unroll_for(j, 6) {
221
0
            float g0 = g[j * adim[2]];
222
0
            float g12 = g[(12 + j) * adim[2]];
223
0
            float g24 = g[(24 + j) * adim[2]];
224
            /* row 1 */
225
0
            d[j] = 4 * g0 - 5 * g12 + g24;
226
0
            float g6 = g[(6 + j) * adim[2]];
227
0
            float g18 = g[(18 + j) * adim[2]];
228
            /* row 2 */
229
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
230
            /* row 3 */
231
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
232
            /* row 4 */
233
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
234
            /* row 5 */
235
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
236
0
            float g30 = g[(30 + j) * adim[2]];
237
            /* row 6 */
238
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
239
0
          } unroll_endfor
240
          /*
241
           * a0, a1, a2, a3, a4, a5,
242
           * b6, b7, b8, b9, b10,b11,
243
           * c12,c13,c14,c15,c16,c17,
244
           * d18,d19,d20,d21,d22,d23,
245
           * e24,e25,e26,e27,e28,e29,
246
           * f30,f31,f32,f33,f34,f35
247
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
248
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
249
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
250
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
251
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
252
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
253
           */
254
          /* BT.d.B */
255
0
          unroll_for(j, 6) {
256
            /* row 1 - 6 */
257
0
            float* const gz = g + j * 6 * adim[2];
258
0
            float* const dz = d + j * 6;
259
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
260
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
261
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
262
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
263
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
264
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
265
0
          } unroll_endfor
266
          // move to the next channel
267
0
          ++g;
268
0
        }
269
0
        const float* wpz = gwtg;
270
0
        for (k = 0; k < w->info.dim[0]; k++)
271
0
        {
272
0
          float q[36];
273
#if FOR_IS_PARALLEL
274
          g = btdb + i * 36 * adim[2];
275
#else
276
0
          g = btdb;
277
0
#endif
278
0
          for (j = 0; j < 36; j++)
279
0
          {
280
0
            float b = 0;
281
0
            for (c = 0; c < adim[2]; c++)
282
0
              b += g[c] * wpz[c];
283
0
            q[j] = b;
284
0
            g += adim[2];
285
0
            wpz += adim[2];
286
0
          }
287
          /*
288
           * a0, a1, a2, a3, a4, a5,
289
           * b6, b7, b8, b9, b10,b11,
290
           * c12,c13,c14,c15,c16,c17,
291
           * d18,d19,d20,d21,d22,d23,
292
           * e24,e25,e26,e27,e28,e29,
293
           * f30,f31,f32,f33,f34,f35
294
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
295
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
296
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
297
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
298
           */
299
0
          float d[24];
300
          /* row 1 */
301
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
302
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
303
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
304
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
305
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
306
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
307
          /* row 2 */
308
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
309
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
310
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
311
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
312
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
313
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
314
          /* row 3 */
315
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
316
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
317
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
318
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
319
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
320
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
321
          /* row 4 */
322
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
323
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
324
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
325
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
326
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
327
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
328
          /*
329
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
330
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
331
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
332
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
333
           */
334
0
          float* bpz = bp + x * bstride[2] + k;
335
0
          unroll_for(dy, z[0], 4) {
336
0
            float r[] = {
337
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4] + biasval[k],
338
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + biasval[k],
339
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]) + biasval[k],
340
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5] + biasval[k],
341
0
            };
342
0
            unroll_for(dx, z[1], 4) {
343
0
              bpz[dx * bstride[2]] = r[dx];
344
0
            } unroll_endfor
345
0
            bpz += bstride[1];
346
0
          } unroll_endfor
347
0
        }
348
0
      }
349
0
    } parallel_endfor
350
0
  } else {
351
0
    parallel_for(i, jump_dim) {
352
0
      const int y = i * 4; // i is unsigned.
353
0
      int j, x, k, c;
354
0
      int n[CCV_NNC_MAX_DIM];
355
0
      int m[CCV_NNC_MAX_DIM];
356
0
      int z[CCV_NNC_MAX_DIM];
357
0
      set_n_m_dim(y, 0, tile_dim, adim);
358
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
359
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
360
0
      float* bp = b->data.f32 + y * bstride[1];
361
0
      for (x = 0; x < bdim[1]; x += 4)
362
0
      {
363
0
        set_n_m_dim(x, 1, tile_dim, adim);
364
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
365
#if FOR_IS_PARALLEL
366
        float* g = btdb + i * 36 * adim[2];
367
#else
368
0
        float* g = btdb;
369
0
#endif
370
        // zero g such that we can have zero-padding.
371
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
372
0
        int dx, dy;
373
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
374
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
375
0
        unroll_for(dy, m[0], 6) {
376
0
          unroll_for(dx, m[1], 6) {
377
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
378
0
            for (c = 0; c < adim[2]; c++)
379
0
              gzu[c] = apz[dx * astride[2] + c];
380
0
          } unroll_endfor
381
0
          apz += astride[1];
382
0
        } unroll_endfor
383
0
        for (c = 0; c < adim[2]; c++)
384
0
        {
385
          /*
386
           * a0, a1, a2, a3, a4, a5,
387
           * b6, b7, b8, b9, b10,b11,
388
           * c12,c13,c14,c15,c16,c17,
389
           * d18,d19,d20,d21,d22,d23,
390
           * e24,e25,e26,e27,e28,e29,
391
           * f30,f31,f32,f33,f34,f35
392
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
393
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
394
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
395
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
396
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
397
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
398
           */
399
0
          float d[36];
400
          /* BT.d */
401
0
          unroll_for(j, 6) {
402
0
            float g0 = g[j * adim[2]];
403
0
            float g12 = g[(12 + j) * adim[2]];
404
0
            float g24 = g[(24 + j) * adim[2]];
405
            /* row 1 */
406
0
            d[j] = 4 * g0 - 5 * g12 + g24;
407
0
            float g6 = g[(6 + j) * adim[2]];
408
0
            float g18 = g[(18 + j) * adim[2]];
409
            /* row 2 */
410
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
411
            /* row 3 */
412
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
413
            /* row 4 */
414
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
415
            /* row 5 */
416
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
417
0
            float g30 = g[(30 + j) * adim[2]];
418
            /* row 6 */
419
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
420
0
          } unroll_endfor
421
          /*
422
           * a0, a1, a2, a3, a4, a5,
423
           * b6, b7, b8, b9, b10,b11,
424
           * c12,c13,c14,c15,c16,c17,
425
           * d18,d19,d20,d21,d22,d23,
426
           * e24,e25,e26,e27,e28,e29,
427
           * f30,f31,f32,f33,f34,f35
428
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
429
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
430
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
431
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
432
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
433
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
434
           */
435
          /* BT.d.B */
436
0
          unroll_for(j, 6) {
437
            /* row 1 - 6 */
438
0
            float* const gz = g + j * 6 * adim[2];
439
0
            float* const dz = d + j * 6;
440
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
441
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
442
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
443
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
444
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
445
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
446
0
          } unroll_endfor
447
          // move to the next channel
448
0
          ++g;
449
0
        }
450
0
        const float* wpz = gwtg;
451
0
        for (k = 0; k < w->info.dim[0]; k++)
452
0
        {
453
0
          float q[36];
454
#if FOR_IS_PARALLEL
455
          g = btdb + i * 36 * adim[2];
456
#else
457
0
          g = btdb;
458
0
#endif
459
0
          for (j = 0; j < 36; j++)
460
0
          {
461
0
            float b = 0;
462
0
            for (c = 0; c < adim[2]; c++)
463
0
              b += g[c] * wpz[c];
464
0
            q[j] = b;
465
0
            g += adim[2];
466
0
            wpz += adim[2];
467
0
          }
468
          /*
469
           * a0, a1, a2, a3, a4, a5,
470
           * b6, b7, b8, b9, b10,b11,
471
           * c12,c13,c14,c15,c16,c17,
472
           * d18,d19,d20,d21,d22,d23,
473
           * e24,e25,e26,e27,e28,e29,
474
           * f30,f31,f32,f33,f34,f35
475
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
476
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
477
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
478
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
479
           */
480
0
          float d[24];
481
          /* row 1 */
482
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
483
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
484
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
485
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
486
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
487
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
488
          /* row 2 */
489
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
490
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
491
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
492
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
493
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
494
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
495
          /* row 3 */
496
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
497
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
498
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
499
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
500
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
501
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
502
          /* row 4 */
503
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
504
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
505
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
506
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
507
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
508
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
509
          /*
510
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
511
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
512
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
513
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
514
           */
515
0
          float* bpz = bp + x * bstride[2] + k;
516
0
          unroll_for(dy, z[0], 4) {
517
0
            float r[] = {
518
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4],
519
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]),
520
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]),
521
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5],
522
0
            };
523
0
            unroll_for(dx, z[1], 4) {
524
0
              bpz[dx * bstride[2]] = r[dx];
525
0
            } unroll_endfor
526
0
            bpz += bstride[1];
527
0
          } unroll_endfor
528
0
        }
529
0
      }
530
0
    } parallel_endfor
531
0
  }
532
0
  return CCV_NNC_EXEC_SUCCESS;
533
0
}
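For reference, and not part of the measured file: the d[24]/r[4] arithmetic inside the loops above is the output transform T(A).M.A of F(4x4,3x3), collapsing each 6x6 product tile into a 4x4 output tile. A single-tile sketch, with bias and the channel-interleaved layout omitted:

static void winograd_output_single(const float q[36], float out[16])
{
  /* A is the 6x4 inverse-transform matrix; T(A).q matches the d[24] rows
   * and (T(A).q).A matches the r[4] columns computed above. */
  static const float A[6][4] = {
    { 1, 0, 0, 0 },
    { 1, 1, 1, 1 },
    { 1, -1, 1, -1 },
    { 1, 2, 4, 8 },
    { 1, -2, 4, -8 },
    { 0, 0, 0, 1 },
  };
  float m[24]; /* T(A).q (4x6) */
  int i, j, k;
  for (i = 0; i < 4; i++)
    for (j = 0; j < 6; j++)
    {
      float s = 0;
      for (k = 0; k < 6; k++)
        s += A[k][i] * q[k * 6 + j];
      m[i * 6 + j] = s;
    }
  for (i = 0; i < 4; i++) /* (T(A).q).A (4x4 output tile) */
    for (j = 0; j < 4; j++)
    {
      float s = 0;
      for (k = 0; k < 6; k++)
        s += m[i * 6 + k] * A[k][j];
      out[i * 4 + j] = s;
    }
}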
534
535
#ifdef HAVE_SSE2
536
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(const float* const w, const int* const dim, float* const gwtg)
537
119
{
538
119
  const int jump_dim = dim[0] / 4;
539
119
  const int dimCx4 = (dim[3] + 3) & -4;
540
7.56k
  parallel_for(k, jump_dim) {
541
7.56k
    int i, j;
542
7.56k
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
543
7.56k
    const float* wz[] = {
544
7.56k
      w + (k * 4) * 9 * dim[3],
545
7.56k
      w + (k * 4 + 1) * 9 * dim[3],
546
7.56k
      w + (k * 4 + 2) * 9 * dim[3],
547
7.56k
      w + (k * 4 + 3) * 9 * dim[3],
548
7.56k
    };
549
2.88M
    for (i = 0; i < dim[3]; i++)
550
2.87M
    {
551
2.87M
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
552
25.8M
      unroll_for(j, 9) {
553
25.8M
        x9w[j * 4] = wz[0][j * dim[3] + i];
554
25.8M
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
555
25.8M
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
556
25.8M
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
557
25.8M
      } unroll_endfor
558
2.87M
      float g[18 * 4] __attribute__ ((__aligned__(16)));
559
2.87M
      __m128 x9w0 = _mm_load_ps(x9w);
560
2.87M
      __m128 x9w1 = _mm_load_ps(x9w + 4);
561
2.87M
      __m128 x9w2 = _mm_load_ps(x9w + 8);
562
2.87M
      __m128 x9w3 = _mm_load_ps(x9w + 12);
563
2.87M
      __m128 x9w4 = _mm_load_ps(x9w + 16);
564
2.87M
      __m128 x9w5 = _mm_load_ps(x9w + 20);
565
2.87M
      __m128 x9w6 = _mm_load_ps(x9w + 24);
566
2.87M
      __m128 x9w7 = _mm_load_ps(x9w + 28);
567
2.87M
      __m128 x9w8 = _mm_load_ps(x9w + 32);
568
      /* row 1 */
569
2.87M
      __m128 c1_4 = _mm_set1_ps(1.0 / 4);
570
2.87M
      _mm_store_ps(g, _mm_mul_ps(x9w0, c1_4));
571
2.87M
      _mm_store_ps(g + 4, _mm_mul_ps(x9w1, c1_4));
572
2.87M
      _mm_store_ps(g + 8, _mm_mul_ps(x9w2, c1_4));
573
      /* row 2 */
574
2.87M
      __m128 cn1_6 = _mm_set1_ps(-1.0 / 6);
575
2.87M
      _mm_store_ps(g + 12, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
576
2.87M
      _mm_store_ps(g + 16, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
577
2.87M
      _mm_store_ps(g + 20, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
578
      /* row 3 */
579
2.87M
      _mm_store_ps(g + 24, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
580
2.87M
      _mm_store_ps(g + 28, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
581
2.87M
      _mm_store_ps(g + 32, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
582
      /* row 6 */
583
2.87M
      _mm_store_ps(g + 60, x9w6);
584
2.87M
      _mm_store_ps(g + 64, x9w7);
585
2.87M
      _mm_store_ps(g + 68, x9w8);
586
      /* w[x] * 2 */
587
2.87M
      x9w3 = _mm_add_ps(x9w3, x9w3);
588
2.87M
      x9w4 = _mm_add_ps(x9w4, x9w4);
589
2.87M
      x9w5 = _mm_add_ps(x9w5, x9w5);
590
      /* w[x] * 4 */
591
2.87M
      x9w6 = _mm_add_ps(x9w6, x9w6);
592
2.87M
      x9w6 = _mm_add_ps(x9w6, x9w6);
593
2.87M
      x9w7 = _mm_add_ps(x9w7, x9w7);
594
2.87M
      x9w7 = _mm_add_ps(x9w7, x9w7);
595
2.87M
      x9w8 = _mm_add_ps(x9w8, x9w8);
596
2.87M
      x9w8 = _mm_add_ps(x9w8, x9w8);
597
      /* row 4 */
598
2.87M
      __m128 c1_24 = _mm_set1_ps(1.0 / 24);
599
2.87M
      _mm_store_ps(g + 36, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
600
2.87M
      _mm_store_ps(g + 40, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
601
2.87M
      _mm_store_ps(g + 44, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
602
      /* row 5 */
603
2.87M
      _mm_store_ps(g + 48, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
604
2.87M
      _mm_store_ps(g + 52, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
605
2.87M
      _mm_store_ps(g + 56, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
606
17.2M
      unroll_for(j, 6) {
607
17.2M
        const float* const gz = g + j * 12;
608
17.2M
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
609
17.2M
        __m128 g0 = _mm_load_ps(gz);
610
17.2M
        __m128 g1 = _mm_load_ps(gz + 4);
611
17.2M
        __m128 g2 = _mm_load_ps(gz + 8);
612
17.2M
        _mm_store_ps(gwtgzu, _mm_mul_ps(g0, c1_4));
613
17.2M
        _mm_store_ps(gwtgzu + 4 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), cn1_6));
614
17.2M
        _mm_store_ps(gwtgzu + 8 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), cn1_6));
615
17.2M
        _mm_store_ps(gwtgzu + 20 * dimCx4, g2);
616
        /* g[1] * 2 */
617
17.2M
        g1 = _mm_add_ps(g1, g1);
618
        /* g[2] * 4 */
619
17.2M
        g2 = _mm_add_ps(g2, g2);
620
17.2M
        g2 = _mm_add_ps(g2, g2);
621
17.2M
        _mm_store_ps(gwtgzu + 12 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), c1_24));
622
17.2M
        _mm_store_ps(gwtgzu + 16 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), c1_24));
623
17.2M
      } unroll_endfor
624
2.87M
      gwtgz += 4;
625
2.87M
    }
626
7.56k
  } parallel_endfor
627
119
}
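Not part of the measured file: the SSE2 kernel above multiplies by the small constants 2, 4 and 5 with repeated _mm_add_ps instead of loading extra multiplier constants (see how x9w6, g12x2 and d2x5 are built). A minimal sketch of that idiom:

#include <xmmintrin.h>

static inline __m128 scale_by_4(__m128 x)
{
  x = _mm_add_ps(x, x); /* x *= 2 */
  x = _mm_add_ps(x, x); /* x *= 4 */
  return x;
}

static inline __m128 scale_by_5(__m128 x)
{
  return _mm_add_ps(scale_by_4(x), x); /* 4x + x */
}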
628
629
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
630
119
{
631
119
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
632
119
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
633
119
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
634
119
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
635
119
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
636
119
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
637
119
  int astride[CCV_NNC_MAX_DIM_ALLOC];
638
119
  ccv_nnc_tensor_view_get_stride(a, astride);
639
119
  int bstride[CCV_NNC_MAX_DIM_ALLOC];
640
119
  ccv_nnc_tensor_view_get_stride(b, bstride);
641
119
  assert(hint.border.begin[0] <= 1);
642
119
  assert(hint.border.begin[1] <= 1);
643
119
  assert(w->info.dim[0] % 4 == 0);
644
119
  assert(w->info.dim[1] == 3);
645
119
  assert(w->info.dim[2] == 3);
646
119
  const int jump_dim = (bdim[0] + 3) / 4;
647
119
  const int dimCx4 = (adim[2] + 3) & -4;
648
  // allocating workspace memory for kernel reshaping and input reshaping.
649
119
  float* workmem = 0;
650
#if FOR_IS_PARALLEL
651
  // If we do parallel for, we need to allocate input reshaping for each block.
652
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
653
#else
654
  // Otherwise, just one block.
655
119
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
656
119
#endif
657
119
  if (!workmem)
658
0
    return CCV_NNC_EXEC_OOM;
659
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
660
119
  float* const gwtg = workmem;
661
119
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
662
119
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
663
119
  _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(w->data.f32, w->info.dim, gwtg);
664
  // kernel weight for one dim.
665
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
666
119
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
667
119
    w->info.dim[0], 6, 6, w->info.dim[3]
668
119
  };
669
119
  const int* const tile_dim = tile_dim_s;
670
119
  if (bias)
671
118
  {
672
118
    const float* const biasval = bias->data.f32;
673
    // This block will be called in each for-loop iteration; therefore, you can use it to generate some temporary variables.
674
1.83k
    parallel_for(i, jump_dim) {
675
1.83k
      const int y = i * 4; // i is unsigned.
676
1.83k
      int j, x, k, c;
677
1.83k
      int n[CCV_NNC_MAX_DIM];
678
1.83k
      int m[CCV_NNC_MAX_DIM];
679
1.83k
      int z[CCV_NNC_MAX_DIM];
680
1.83k
      set_n_m_dim(y, 0, tile_dim, adim);
681
1.83k
      z[0] = ccv_min(y + 4, bdim[0]) - y;
682
1.83k
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
683
1.83k
      float* bp = b->data.f32 + y * bstride[1];
684
65.8k
      for (x = 0; x < bdim[1]; x += 4)
685
63.9k
      {
686
63.9k
        set_n_m_dim(x, 1, tile_dim, adim);
687
63.9k
        z[1] = ccv_min(x + 4, bdim[1]) - x;
688
#if FOR_IS_PARALLEL
689
        float* g = btdb + i * 36 * dimCx4;
690
#else
691
63.9k
        float* g = btdb;
692
63.9k
#endif
693
        // zero g such that we can have zero-padding.
694
63.9k
        memset(g, 0, sizeof(float) * 36 * dimCx4);
695
63.9k
        int dx, dy;
696
63.9k
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
697
63.9k
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
698
379k
        unroll_for(dy, m[0], 6) {
699
2.24M
          unroll_for(dx, m[1], 6) {
700
2.24M
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
701
139M
            for (c = 0; c < adim[2]; c++)
702
137M
              gzu[c] = apz[dx * astride[2] + c];
703
2.24M
          } unroll_endfor
704
379k
          apz += astride[1];
705
379k
        } unroll_endfor
706
1.08M
        for (c = 0; c < adim[2]; c += 4)
707
1.02M
        {
708
1.02M
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
709
          /* BT.d */
710
6.14M
          unroll_for(j, 6) {
711
            /* row 1 */
712
6.14M
            const float* const gz = g + j * dimCx4;
713
6.14M
            float* dz = d + j * 4;
714
6.14M
            __m128 g0 = _mm_load_ps(gz);
715
6.14M
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
716
6.14M
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
717
6.14M
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
718
6.14M
            g0 = _mm_add_ps(g0, g0);
719
6.14M
            g0 = _mm_add_ps(g0, g0);
720
6.14M
            __m128 g12x2 = _mm_add_ps(g12, g12);
721
6.14M
            g12x2 = _mm_add_ps(g12x2, g12x2);
722
6.14M
            g12x2 = _mm_add_ps(g12x2, g12);
723
6.14M
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
724
            /* row 2 */
725
6.14M
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
726
6.14M
            __m128 g6x12 = _mm_add_ps(g6, g12);
727
6.14M
            g6x12 = _mm_add_ps(g6x12, g6x12);
728
6.14M
            g6x12 = _mm_add_ps(g6x12, g6x12);
729
6.14M
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
730
            /* row 3 */
731
6.14M
            g6x12 = _mm_sub_ps(g6, g12);
732
6.14M
            g6x12 = _mm_add_ps(g6x12, g6x12);
733
6.14M
            g6x12 = _mm_add_ps(g6x12, g6x12);
734
6.14M
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
735
            /* row 4 */
736
6.14M
            __m128 g18x6 = _mm_sub_ps(g18, g6);
737
6.14M
            g18x6 = _mm_add_ps(g18x6, g18x6);
738
6.14M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
739
            /* row 5 */
740
6.14M
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
741
            /* row 6 */
742
6.14M
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
743
6.14M
            __m128 g18x2 = _mm_add_ps(g18, g18);
744
6.14M
            g18x2 = _mm_add_ps(g18x2, g18x2);
745
6.14M
            g18x2 = _mm_add_ps(g18, g18x2);
746
6.14M
            g6 = _mm_add_ps(g6, g6);
747
6.14M
            g6 = _mm_add_ps(g6, g6);
748
6.14M
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
749
6.14M
          } unroll_endfor
750
          /* BT.d.B */
751
6.14M
          unroll_for(j, 6) {
752
6.14M
            float* gz = g + j * 6 * dimCx4;
753
6.14M
            const float* const dz = d + j * 24;
754
6.14M
            __m128 d0 = _mm_load_ps(dz);
755
6.14M
            __m128 d1 = _mm_load_ps(dz + 4);
756
6.14M
            __m128 d2 = _mm_load_ps(dz + 8);
757
6.14M
            __m128 d3 = _mm_load_ps(dz + 12);
758
6.14M
            __m128 d4 = _mm_load_ps(dz + 16);
759
6.14M
            __m128 d5 = _mm_load_ps(dz + 20);
760
6.14M
            d0 = _mm_add_ps(d0, d0);
761
6.14M
            d0 = _mm_add_ps(d0, d0);
762
6.14M
            __m128 d2x5 = _mm_add_ps(d2, d2);
763
6.14M
            d2x5 = _mm_add_ps(d2x5, d2x5);
764
6.14M
            d2x5 = _mm_add_ps(d2, d2x5);
765
6.14M
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
766
6.14M
            __m128 d1x2 = _mm_add_ps(d1, d2);
767
6.14M
            d1x2 = _mm_add_ps(d1x2, d1x2);
768
6.14M
            d1x2 = _mm_add_ps(d1x2, d1x2);
769
6.14M
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
770
6.14M
            d1x2 = _mm_sub_ps(d1, d2);
771
6.14M
            d1x2 = _mm_add_ps(d1x2, d1x2);
772
6.14M
            d1x2 = _mm_add_ps(d1x2, d1x2);
773
6.14M
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
774
6.14M
            __m128 d3x1 = _mm_sub_ps(d3, d1);
775
6.14M
            d3x1 = _mm_add_ps(d3x1, d3x1);
776
6.14M
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
777
6.14M
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
778
6.14M
            d1 = _mm_add_ps(d1, d1);
779
6.14M
            d1 = _mm_add_ps(d1, d1);
780
6.14M
            __m128 d3x5 = _mm_add_ps(d3, d3);
781
6.14M
            d3x5 = _mm_add_ps(d3x5, d3x5);
782
6.14M
            d3x5 = _mm_add_ps(d3, d3x5);
783
6.14M
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
784
6.14M
          } unroll_endfor
785
          // move to the next channel
786
1.02M
          g += 4;
787
1.02M
        }
788
63.9k
        const float* wpz = gwtg;
789
1.66M
        for (k = 0; k < w->info.dim[0]; k += 4)
790
1.60M
        {
791
1.60M
          float q[36 * 4] __attribute__ ((__aligned__(16)));
792
#if FOR_IS_PARALLEL
793
          g = btdb + i * 36 * dimCx4;
794
#else
795
1.60M
          g = btdb;
796
1.60M
#endif
797
59.3M
          for (j = 0; j < 36; j++)
798
57.7M
          {
799
57.7M
            __m128 v40 = _mm_setzero_ps();
800
57.7M
            __m128 v41 = _mm_setzero_ps();
801
57.7M
            __m128 v42 = _mm_setzero_ps();
802
57.7M
            __m128 v43 = _mm_setzero_ps();
803
1.80G
            for (c = 0; c < adim[2]; c += 4)
804
1.74G
            {
805
1.74G
              __m128 g4 = _mm_load_ps(g);
806
1.74G
              __m128 w40 = _mm_load_ps(wpz);
807
1.74G
              __m128 w41 = _mm_load_ps(wpz + 4);
808
1.74G
              __m128 w42 = _mm_load_ps(wpz + 8);
809
1.74G
              __m128 w43 = _mm_load_ps(wpz + 12);
810
1.74G
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
811
1.74G
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
812
1.74G
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
813
1.74G
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
814
1.74G
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
815
1.74G
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
816
1.74G
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
817
1.74G
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
818
1.74G
              g += 4;
819
1.74G
              wpz += 16;
820
1.74G
            }
821
57.7M
            v40 = _mm_add_ps(v40, v41);
822
57.7M
            v42 = _mm_add_ps(v42, v43);
823
57.7M
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
824
57.7M
          }
825
1.60M
          float d[24 * 4] __attribute__ ((__aligned__(16)));
826
9.62M
          unroll_for(j, 6) {
827
9.62M
            const float* const qz = q + j * 4;
828
9.62M
            float* const dz = d + j * 4;
829
9.62M
            __m128 q0 = _mm_load_ps(qz);
830
9.62M
            __m128 q6 = _mm_load_ps(qz + 24);
831
9.62M
            __m128 q12 = _mm_load_ps(qz + 48);
832
9.62M
            __m128 q18 = _mm_load_ps(qz + 72);
833
9.62M
            __m128 q24 = _mm_load_ps(qz + 96);
834
9.62M
            __m128 qs6x12 = _mm_add_ps(q6, q12);
835
9.62M
            __m128 qs18x24 = _mm_add_ps(q18, q24);
836
9.62M
            __m128 qss = _mm_add_ps(qs6x12, q0);
837
            /* row 1 */
838
9.62M
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
839
9.62M
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
840
9.62M
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
841
9.62M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
842
            /* row 2 */
843
9.62M
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
844
9.62M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
845
9.62M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
846
            /* row 3 */
847
9.62M
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
848
9.62M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
849
9.62M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
850
9.62M
            __m128 q30 = _mm_load_ps(qz + 120);
851
            /* row 4 */
852
9.62M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
853
9.62M
          } unroll_endfor
854
1.60M
          float* bpz = bp + x * bstride[2] + k;
855
1.60M
          __m128 bias4 = _mm_loadu_ps(biasval + k);
856
1.60M
          switch (z[1]) {
857
10.8k
            case 1:
858
35.1k
              unroll_for(dy, z[0], 4) {
859
35.1k
                const float* const dz = d + dy * 24;
860
35.1k
                __m128 d0 = _mm_load_ps(dz);
861
35.1k
                __m128 d1 = _mm_load_ps(dz + 4);
862
35.1k
                __m128 d2 = _mm_load_ps(dz + 8);
863
35.1k
                __m128 d3 = _mm_load_ps(dz + 12);
864
35.1k
                __m128 d4 = _mm_load_ps(dz + 16);
865
35.1k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
866
35.1k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
867
35.1k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
868
35.1k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
869
35.1k
                bpz += bstride[1];
870
35.1k
              } unroll_endfor
871
10.8k
              break;
872
0
            case 2:
873
0
              unroll_for(dy, z[0], 4) {
874
0
                const float* const dz = d + dy * 24;
875
0
                __m128 d0 = _mm_load_ps(dz);
876
0
                __m128 d1 = _mm_load_ps(dz + 4);
877
0
                __m128 d2 = _mm_load_ps(dz + 8);
878
0
                __m128 d3 = _mm_load_ps(dz + 12);
879
0
                __m128 d4 = _mm_load_ps(dz + 16);
880
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
881
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
882
0
                ds1x2 = _mm_add_ps(ds1x2, bias4);
883
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
884
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
885
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
886
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
887
0
                dn1x2 = _mm_add_ps(dn1x2, bias4);
888
0
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
889
0
                bpz += bstride[1];
890
0
              } unroll_endfor
891
0
              break;
892
63.1k
            case 3:
893
247k
              unroll_for(dy, z[0], 4) {
894
247k
                const float* const dz = d + dy * 24;
895
247k
                __m128 d0 = _mm_load_ps(dz);
896
247k
                __m128 d1 = _mm_load_ps(dz + 4);
897
247k
                __m128 d2 = _mm_load_ps(dz + 8);
898
247k
                __m128 d3 = _mm_load_ps(dz + 12);
899
247k
                __m128 d4 = _mm_load_ps(dz + 16);
900
247k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
901
247k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
902
247k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
903
247k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
904
247k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
905
247k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
906
247k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
907
247k
                dn1x2 = _mm_add_ps(dn1x2, bias4);
908
247k
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
909
247k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
910
247k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
911
247k
                _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4));
912
247k
                bpz += bstride[1];
913
247k
              } unroll_endfor
914
63.1k
              break;
915
1.53M
            case 4:
916
6.04M
              unroll_for(dy, z[0], 4) {
917
6.04M
                const float* const dz = d + dy * 24;
918
6.04M
                __m128 d0 = _mm_load_ps(dz);
919
6.04M
                __m128 d1 = _mm_load_ps(dz + 4);
920
6.04M
                __m128 d2 = _mm_load_ps(dz + 8);
921
6.04M
                __m128 d3 = _mm_load_ps(dz + 12);
922
6.04M
                __m128 d4 = _mm_load_ps(dz + 16);
923
6.04M
                __m128 ds1x2 = _mm_add_ps(d1, d2);
924
6.04M
                __m128 ds3x4 = _mm_add_ps(d3, d4);
925
6.04M
                ds1x2 = _mm_add_ps(ds1x2, bias4);
926
6.04M
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
927
6.04M
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
928
6.04M
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
929
6.04M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
930
6.04M
                dn1x2 = _mm_add_ps(dn1x2, bias4);
931
6.04M
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
932
6.04M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
933
6.04M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
934
6.04M
                _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4));
935
6.04M
                __m128 d5 = _mm_load_ps(dz + 20);
936
6.04M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
937
6.04M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
938
6.04M
                _mm_stream_ps(bpz + 3 * bstride[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
939
6.04M
                bpz += bstride[1];
940
6.04M
              } unroll_endfor
941
1.53M
              break;
942
1.60M
          };
943
1.60M
        }
944
63.9k
      }
945
1.83k
    } parallel_endfor
946
118
  } else {
947
    // This block will be called in each for-loop iteration; therefore, you can use it to generate some temporary variables.
948
14
    parallel_for(i, jump_dim) {
949
14
      const int y = i * 4; // i is unsigned.
950
14
      int j, x, k, c;
951
14
      int n[CCV_NNC_MAX_DIM];
952
14
      int m[CCV_NNC_MAX_DIM];
953
14
      int z[CCV_NNC_MAX_DIM];
954
14
      set_n_m_dim(y, 0, tile_dim, adim);
955
14
      z[0] = ccv_min(y + 4, bdim[0]) - y;
956
14
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
957
14
      float* bp = b->data.f32 + y * bstride[1];
958
210
      for (x = 0; x < bdim[1]; x += 4)
959
196
      {
960
196
        set_n_m_dim(x, 1, tile_dim, adim);
961
196
        z[1] = ccv_min(x + 4, bdim[1]) - x;
962
#if FOR_IS_PARALLEL
963
        float* g = btdb + i * 36 * dimCx4;
964
#else
965
196
        float* g = btdb;
966
196
#endif
967
        // zero g such that we can have zero-padding.
968
196
        memset(g, 0, sizeof(float) * 36 * dimCx4);
969
196
        int dx, dy;
970
196
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
971
196
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
972
1.14k
        unroll_for(dy, m[0], 6) {
973
6.72k
          unroll_for(dx, m[1], 6) {
974
6.72k
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
975
867k
            for (c = 0; c < adim[2]; c++)
976
860k
              gzu[c] = apz[dx * astride[2] + c];
977
6.72k
          } unroll_endfor
978
1.14k
          apz += astride[1];
979
1.14k
        } unroll_endfor
980
6.46k
        for (c = 0; c < adim[2]; c += 4)
981
6.27k
        {
982
6.27k
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
983
          /* BT.d */
984
37.6k
          unroll_for(j, 6) {
985
            /* row 1 */
986
37.6k
            const float* const gz = g + j * dimCx4;
987
37.6k
            float* dz = d + j * 4;
988
37.6k
            __m128 g0 = _mm_load_ps(gz);
989
37.6k
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
990
37.6k
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
991
37.6k
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
992
37.6k
            g0 = _mm_add_ps(g0, g0);
993
37.6k
            g0 = _mm_add_ps(g0, g0);
994
37.6k
            __m128 g12x2 = _mm_add_ps(g12, g12);
995
37.6k
            g12x2 = _mm_add_ps(g12x2, g12x2);
996
37.6k
            g12x2 = _mm_add_ps(g12x2, g12);
997
37.6k
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
998
            /* row 2 */
999
37.6k
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
1000
37.6k
            __m128 g6x12 = _mm_add_ps(g6, g12);
1001
37.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1002
37.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1003
37.6k
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
1004
            /* row 3 */
1005
37.6k
            g6x12 = _mm_sub_ps(g6, g12);
1006
37.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1007
37.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1008
37.6k
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
1009
            /* row 4 */
1010
37.6k
            __m128 g18x6 = _mm_sub_ps(g18, g6);
1011
37.6k
            g18x6 = _mm_add_ps(g18x6, g18x6);
1012
37.6k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
1013
            /* row 5 */
1014
37.6k
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
1015
            /* row 6 */
1016
37.6k
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
1017
37.6k
            __m128 g18x2 = _mm_add_ps(g18, g18);
1018
37.6k
            g18x2 = _mm_add_ps(g18x2, g18x2);
1019
37.6k
            g18x2 = _mm_add_ps(g18, g18x2);
1020
37.6k
            g6 = _mm_add_ps(g6, g6);
1021
37.6k
            g6 = _mm_add_ps(g6, g6);
1022
37.6k
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
1023
37.6k
          } unroll_endfor
1024
          /* BT.d.B */
1025
37.6k
          unroll_for(j, 6) {
1026
37.6k
            float* gz = g + j * 6 * dimCx4;
1027
37.6k
            const float* const dz = d + j * 24;
1028
37.6k
            __m128 d0 = _mm_load_ps(dz);
1029
37.6k
            __m128 d1 = _mm_load_ps(dz + 4);
1030
37.6k
            __m128 d2 = _mm_load_ps(dz + 8);
1031
37.6k
            __m128 d3 = _mm_load_ps(dz + 12);
1032
37.6k
            __m128 d4 = _mm_load_ps(dz + 16);
1033
37.6k
            __m128 d5 = _mm_load_ps(dz + 20);
1034
37.6k
            d0 = _mm_add_ps(d0, d0);
1035
37.6k
            d0 = _mm_add_ps(d0, d0);
1036
37.6k
            __m128 d2x5 = _mm_add_ps(d2, d2);
1037
37.6k
            d2x5 = _mm_add_ps(d2x5, d2x5);
1038
37.6k
            d2x5 = _mm_add_ps(d2, d2x5);
1039
37.6k
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
1040
37.6k
            __m128 d1x2 = _mm_add_ps(d1, d2);
1041
37.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1042
37.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1043
37.6k
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
1044
37.6k
            d1x2 = _mm_sub_ps(d1, d2);
1045
37.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1046
37.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1047
37.6k
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
1048
37.6k
            __m128 d3x1 = _mm_sub_ps(d3, d1);
1049
37.6k
            d3x1 = _mm_add_ps(d3x1, d3x1);
1050
37.6k
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
1051
37.6k
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
1052
37.6k
            d1 = _mm_add_ps(d1, d1);
1053
37.6k
            d1 = _mm_add_ps(d1, d1);
1054
37.6k
            __m128 d3x5 = _mm_add_ps(d3, d3);
1055
37.6k
            d3x5 = _mm_add_ps(d3x5, d3x5);
1056
37.6k
            d3x5 = _mm_add_ps(d3, d3x5);
1057
37.6k
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
1058
37.6k
          } unroll_endfor
1059
          // move to the next channel
1060
6.27k
          g += 4;
1061
6.27k
        }
1062
196
        const float* wpz = gwtg;
1063
6.46k, 6.27k
        for (k = 0; k < w->info.dim[0]; k += 4)
1064
6.27k
        {
1065
6.27k
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1066
#if FOR_IS_PARALLEL
1067
          g = btdb + i * 36 * dimCx4;
1068
#else
1069
6.27k
          g = btdb;
1070
6.27k
#endif
1071
232k, 225k
          for (j = 0; j < 36; j++)
1072
225k
          {
1073
225k
            __m128 v40 = _mm_setzero_ps();
1074
225k
            __m128 v41 = _mm_setzero_ps();
1075
225k
            __m128 v42 = _mm_setzero_ps();
1076
225k
            __m128 v43 = _mm_setzero_ps();
1077
7.45M, 7.22M
            for (c = 0; c < adim[2]; c += 4)
1078
7.22M
            {
1079
7.22M
              __m128 g4 = _mm_load_ps(g);
1080
7.22M
              __m128 w40 = _mm_load_ps(wpz);
1081
7.22M
              __m128 w41 = _mm_load_ps(wpz + 4);
1082
7.22M
              __m128 w42 = _mm_load_ps(wpz + 8);
1083
7.22M
              __m128 w43 = _mm_load_ps(wpz + 12);
1084
7.22M
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
1085
7.22M
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
1086
7.22M
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
1087
7.22M
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
1088
7.22M
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
1089
7.22M
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
1090
7.22M
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
1091
7.22M
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
1092
7.22M
              g += 4;
1093
7.22M
              wpz += 16;
1094
7.22M
            }
1095
225k
            v40 = _mm_add_ps(v40, v41);
1096
225k
            v42 = _mm_add_ps(v42, v43);
1097
225k
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
1098
225k
          }
1099
6.27k
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1100
37.6k
          unroll_for(j, 6) {
1101
37.6k
            const float* const qz = q + j * 4;
1102
37.6k
            float* const dz = d + j * 4;
1103
37.6k
            __m128 q0 = _mm_load_ps(qz);
1104
37.6k
            __m128 q6 = _mm_load_ps(qz + 24);
1105
37.6k
            __m128 q12 = _mm_load_ps(qz + 48);
1106
37.6k
            __m128 q18 = _mm_load_ps(qz + 72);
1107
37.6k
            __m128 q24 = _mm_load_ps(qz + 96);
1108
37.6k
            __m128 qs6x12 = _mm_add_ps(q6, q12);
1109
37.6k
            __m128 qs18x24 = _mm_add_ps(q18, q24);
1110
37.6k
            __m128 qss = _mm_add_ps(qs6x12, q0);
1111
            /* row 1 */
1112
37.6k
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
1113
37.6k
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
1114
37.6k
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
1115
37.6k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1116
            /* row 2 */
1117
37.6k
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
1118
37.6k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1119
37.6k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1120
            /* row 3 */
1121
37.6k
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
1122
37.6k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1123
37.6k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1124
37.6k
            __m128 q30 = _mm_load_ps(qz + 120);
1125
            /* row 4 */
1126
37.6k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
1127
37.6k
          } unroll_endfor
1128
6.27k
          float* bpz = bp + x * bstride[2] + k;
1129
6.27k
          switch (z[1]) {
1130
0
            case 1:
1131
0
              unroll_for(dy, z[0], 4) {
1132
0
                const float* const dz = d + dy * 24;
1133
0
                __m128 d0 = _mm_load_ps(dz);
1134
0
                __m128 d1 = _mm_load_ps(dz + 4);
1135
0
                __m128 d2 = _mm_load_ps(dz + 8);
1136
0
                __m128 d3 = _mm_load_ps(dz + 12);
1137
0
                __m128 d4 = _mm_load_ps(dz + 16);
1138
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1139
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1140
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1141
0
                bpz += bstride[1];
1142
0
              } unroll_endfor
1143
0
              break;
1144
0
            case 2:
1145
0
              unroll_for(dy, z[0], 4) {
1146
0
                const float* const dz = d + dy * 24;
1147
0
                __m128 d0 = _mm_load_ps(dz);
1148
0
                __m128 d1 = _mm_load_ps(dz + 4);
1149
0
                __m128 d2 = _mm_load_ps(dz + 8);
1150
0
                __m128 d3 = _mm_load_ps(dz + 12);
1151
0
                __m128 d4 = _mm_load_ps(dz + 16);
1152
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1153
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1154
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1155
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1156
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1157
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1158
0
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
1159
0
                bpz += bstride[1];
1160
0
              } unroll_endfor
1161
0
              break;
1162
0
            case 3:
1163
0
              unroll_for(dy, z[0], 4) {
1164
0
                const float* const dz = d + dy * 24;
1165
0
                __m128 d0 = _mm_load_ps(dz);
1166
0
                __m128 d1 = _mm_load_ps(dz + 4);
1167
0
                __m128 d2 = _mm_load_ps(dz + 8);
1168
0
                __m128 d3 = _mm_load_ps(dz + 12);
1169
0
                __m128 d4 = _mm_load_ps(dz + 16);
1170
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1171
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1172
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1173
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1174
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1175
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1176
0
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
1177
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1178
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1179
0
                _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4));
1180
0
                bpz += bstride[1];
1181
0
              } unroll_endfor
1182
0
              break;
1183
6.27k
            case 4:
1184
25.0k
              unroll_for(dy, z[0], 4) {
1185
25.0k
                const float* const dz = d + dy * 24;
1186
25.0k
                __m128 d0 = _mm_load_ps(dz);
1187
25.0k
                __m128 d1 = _mm_load_ps(dz + 4);
1188
25.0k
                __m128 d2 = _mm_load_ps(dz + 8);
1189
25.0k
                __m128 d3 = _mm_load_ps(dz + 12);
1190
25.0k
                __m128 d4 = _mm_load_ps(dz + 16);
1191
25.0k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1192
25.0k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1193
25.0k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1194
25.0k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1195
25.0k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1196
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1197
25.0k
                _mm_stream_ps(bpz + bstride[2], _mm_add_ps(dn1x2, dn3x4));
1198
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1199
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1200
25.0k
                _mm_stream_ps(bpz + 2 * bstride[2], _mm_add_ps(ds1x2, ds3x4));
1201
25.0k
                __m128 d5 = _mm_load_ps(dz + 20);
1202
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1203
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1204
25.0k
                _mm_stream_ps(bpz + 3 * bstride[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
1205
25.0k
                bpz += bstride[1];
1206
25.0k
              } unroll_endfor
1207
6.27k
              break;
1208
6.27k
          };
1209
6.27k
        }
1210
196
      }
1211
14
    } parallel_endfor
1212
1
  }
1213
119
  return CCV_NNC_EXEC_SUCCESS;
1214
119
}
1215
#endif
1216
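
For reference, the BT.d and BT.d.B passes in the SSE2 path above (and the identical NEON passes below) implement the standard F(4x4,3x3) Winograd input transform. The following is a minimal scalar sketch assuming one dense, row-major 6x6 input tile for a single channel, with the SIMD channel packing and the dimCx4 strides left out; the function name is illustrative and not part of this file. The repeated vector doublings in the intrinsics correspond to the x2, x4 and 5x = 4x + x coefficients below.

static void winograd_4x4_3x3_input_transform_ref(const float d[36], float out[36])
{
  /* BT.d : combine the 6 rows within each column. */
  float t[36];
  int i;
  for (i = 0; i < 6; i++)
  {
    const float d0 = d[i], d1 = d[6 + i], d2 = d[12 + i], d3 = d[18 + i], d4 = d[24 + i], d5 = d[30 + i];
    t[i] = 4 * d0 - 5 * d2 + d4;
    t[6 + i] = -4 * (d1 + d2) + d3 + d4;
    t[12 + i] = 4 * (d1 - d2) - d3 + d4;
    t[18 + i] = 2 * (d3 - d1) - d2 + d4;
    t[24 + i] = -2 * (d3 - d1) - d2 + d4;
    t[30 + i] = 4 * d1 - 5 * d3 + d5;
  }
  /* (BT.d).B : the same combination applied to the 6 columns of each row. */
  for (i = 0; i < 6; i++)
  {
    const float* const r = t + i * 6;
    float* const o = out + i * 6;
    o[0] = 4 * r[0] - 5 * r[2] + r[4];
    o[1] = -4 * (r[1] + r[2]) + r[3] + r[4];
    o[2] = 4 * (r[1] - r[2]) - r[3] + r[4];
    o[3] = 2 * (r[3] - r[1]) - r[2] + r[4];
    o[4] = -2 * (r[3] - r[1]) - r[2] + r[4];
    o[5] = 4 * r[1] - 5 * r[3] + r[5];
  }
}
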
1217
#ifdef HAVE_NEON
1218
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_neon(const float* const w, const int* const dim, float* const gwtg)
1219
{
1220
  const int jump_dim = dim[0] / 4;
1221
  const int dimCx4 = (dim[3] + 3) & -4;
1222
  parallel_for(k, jump_dim) {
1223
    int i, j;
1224
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
1225
    const float* wz[] = {
1226
      w + (k * 4) * 9 * dim[3],
1227
      w + (k * 4 + 1) * 9 * dim[3],
1228
      w + (k * 4 + 2) * 9 * dim[3],
1229
      w + (k * 4 + 3) * 9 * dim[3],
1230
    };
1231
    for (i = 0; i < dim[3]; i++)
1232
    {
1233
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
1234
      unroll_for(j, 9) {
1235
        x9w[j * 4] = wz[0][j * dim[3] + i];
1236
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
1237
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
1238
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
1239
      } unroll_endfor
1240
      float g[18 * 4] __attribute__ ((__aligned__(16)));
1241
      float32x4_t x9w0 = vld1q_f32(x9w);
1242
      float32x4_t x9w1 = vld1q_f32(x9w + 4);
1243
      float32x4_t x9w2 = vld1q_f32(x9w + 8);
1244
      float32x4_t x9w3 = vld1q_f32(x9w + 12);
1245
      float32x4_t x9w4 = vld1q_f32(x9w + 16);
1246
      float32x4_t x9w5 = vld1q_f32(x9w + 20);
1247
      float32x4_t x9w6 = vld1q_f32(x9w + 24);
1248
      float32x4_t x9w7 = vld1q_f32(x9w + 28);
1249
      float32x4_t x9w8 = vld1q_f32(x9w + 32);
1250
      /* row 1 */
1251
      float32x4_t c1_4 = vdupq_n_f32(1.0 / 4);
1252
      vst1q_f32(g, vmulq_f32(x9w0, c1_4));
1253
      vst1q_f32(g + 4, vmulq_f32(x9w1, c1_4));
1254
      vst1q_f32(g + 8, vmulq_f32(x9w2, c1_4));
1255
      /* row 2 */
1256
      float32x4_t cn1_6 = vdupq_n_f32(-1.0 / 6);
1257
      vst1q_f32(g + 12, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1258
      vst1q_f32(g + 16, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1259
      vst1q_f32(g + 20, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1260
      /* row 3 */
1261
      vst1q_f32(g + 24, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1262
      vst1q_f32(g + 28, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1263
      vst1q_f32(g + 32, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1264
      /* row 6 */
1265
      vst1q_f32(g + 60, x9w6);
1266
      vst1q_f32(g + 64, x9w7);
1267
      vst1q_f32(g + 68, x9w8);
1268
      /* w[x] * 2 */
1269
      x9w3 = vaddq_f32(x9w3, x9w3);
1270
      x9w4 = vaddq_f32(x9w4, x9w4);
1271
      x9w5 = vaddq_f32(x9w5, x9w5);
1272
      /* w[x] * 4 */
1273
      x9w6 = vaddq_f32(x9w6, x9w6);
1274
      x9w6 = vaddq_f32(x9w6, x9w6);
1275
      x9w7 = vaddq_f32(x9w7, x9w7);
1276
      x9w7 = vaddq_f32(x9w7, x9w7);
1277
      x9w8 = vaddq_f32(x9w8, x9w8);
1278
      x9w8 = vaddq_f32(x9w8, x9w8);
1279
      /* row 4 */
1280
      float32x4_t c1_24 = vdupq_n_f32(1.0 / 24);
1281
      vst1q_f32(g + 36, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1282
      vst1q_f32(g + 40, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1283
      vst1q_f32(g + 44, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1284
      /* row 5 */
1285
      vst1q_f32(g + 48, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1286
      vst1q_f32(g + 52, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1287
      vst1q_f32(g + 56, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1288
      unroll_for(j, 6) {
1289
        const float* const gz = g + j * 12;
1290
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
1291
        float32x4_t g0 = vld1q_f32(gz);
1292
        float32x4_t g1 = vld1q_f32(gz + 4);
1293
        float32x4_t g2 = vld1q_f32(gz + 8);
1294
        vst1q_f32(gwtgzu, vmulq_f32(g0, c1_4));
1295
        vst1q_f32(gwtgzu + 4 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1296
        vst1q_f32(gwtgzu + 8 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1297
        vst1q_f32(gwtgzu + 20 * dimCx4, g2);
1298
        /* g[1] * 2 */
1299
        g1 = vaddq_f32(g1, g1);
1300
        /* g[2] * 4 */
1301
        g2 = vaddq_f32(g2, g2);
1302
        g2 = vaddq_f32(g2, g2);
1303
        vst1q_f32(gwtgzu + 12 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), c1_24));
1304
        vst1q_f32(gwtgzu + 16 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), c1_24));
1305
      } unroll_endfor
1306
      gwtgz += 4;
1307
    }
1308
  } parallel_endfor
1309
}
1310
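
The transform-domain multiply inside the forward passes walks the packed kernel buffer (gwtg) and the packed input buffer (btdb) in lockstep. Below is a scalar sketch of that inner product under the layout the code implies: per 6x6 tile position, btdb stores dimCx4 zero-padded input channels, and gwtg stores, for each group of 4 output channels, 4 weights (one per output channel) for every padded input channel. The helper name and signature are illustrative only.

static void winograd_pointwise_multiply_ref(const float* const gwtg, const float* const btdb, const int dimCx4, float q[36 * 4])
{
  int j, c, o;
  for (j = 0; j < 36; j++) /* one 6x6 tile position at a time */
    for (o = 0; o < 4; o++) /* 4 output channels in the current group */
    {
      float acc = 0;
      for (c = 0; c < dimCx4; c++) /* input channels, zero-padded up to dimCx4 */
        acc += gwtg[(j * dimCx4 + c) * 4 + o] * btdb[j * dimCx4 + c];
      q[j * 4 + o] = acc;
    }
}

Zero-padding both buffers out to dimCx4 (the memset calls) is what makes it safe to accumulate over the padded tail channels here and in the vectorized loops.
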
1311
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1312
{
1313
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
1314
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
1315
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
1316
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
1317
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
1318
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
1319
  int astride[CCV_NNC_MAX_DIM_ALLOC];
1320
  ccv_nnc_tensor_view_get_stride(a, astride);
1321
  int bstride[CCV_NNC_MAX_DIM_ALLOC];
1322
  ccv_nnc_tensor_view_get_stride(b, bstride);
1323
  assert(hint.border.begin[0] <= 1);
1324
  assert(hint.border.begin[1] <= 1);
1325
  assert(w->info.dim[0] % 4 == 0);
1326
  assert(w->info.dim[1] == 3);
1327
  assert(w->info.dim[2] == 3);
1328
  const int jump_dim = (bdim[0] + 3) / 4;
1329
  const int dimCx4 = (adim[2] + 3) & -4;
1330
  // allocating workspace memory for kernel reshaping and input reshaping.
1331
  float* workmem = 0;
1332
#if FOR_IS_PARALLEL
1333
  // If we do parallel for, we need to allocate an input reshaping buffer for each block.
1334
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1335
#else
1336
  // Otherwise, just one block.
1337
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1338
#endif
1339
  if (!workmem)
1340
    return CCV_NNC_EXEC_OOM;
1341
  // Convert w to a 6x6 matrix by computing G.w.T(G) // T for transpose.
1342
  float* const gwtg = workmem;
1343
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
1344
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
1345
  _ccv_nnc_winograd_4x4_3x3_gwtg_neon(w->data.f32, w->info.dim, gwtg);
1346
  // kernel weight for one dim.
1347
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
1348
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
1349
    w->info.dim[0], 6, 6, w->info.dim[3]
1350
  };
1351
  const int* const tile_dim = tile_dim_s;
1352
  if (bias)
1353
  {
1354
    const float* const biasval = bias->data.f32;
1355
    // This block will be executed in each for-loop iteration; therefore, you can use it to generate some temporary variables.
1356
    parallel_for(i, jump_dim) {
1357
      const int y = i * 4; // i is unsigned.
1358
      int j, x, k, c;
1359
      int n[CCV_NNC_MAX_DIM];
1360
      int m[CCV_NNC_MAX_DIM];
1361
      int z[CCV_NNC_MAX_DIM];
1362
      set_n_m_dim(y, 0, tile_dim, adim);
1363
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1364
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
1365
      float* bp = b->data.f32 + y * bstride[1];
1366
      for (x = 0; x < bdim[1]; x += 4)
1367
      {
1368
        set_n_m_dim(x, 1, tile_dim, adim);
1369
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1370
#if FOR_IS_PARALLEL
1371
        float* g = btdb + i * 36 * dimCx4;
1372
#else
1373
        float* g = btdb;
1374
#endif
1375
        // zero g such that we can have zero-padding.
1376
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1377
        int dx, dy;
1378
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
1379
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1380
        unroll_for(dy, m[0], 6) {
1381
          unroll_for(dx, m[1], 6) {
1382
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1383
            for (c = 0; c < adim[2]; c++)
1384
              gzu[c] = apz[dx * astride[2] + c];
1385
          } unroll_endfor
1386
          apz += astride[1];
1387
        } unroll_endfor
1388
        for (c = 0; c < adim[2]; c += 4)
1389
        {
1390
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1391
          /* BT.d */
1392
          unroll_for(j, 6) {
1393
            /* row 1 */
1394
            const float* const gz = g + j * dimCx4;
1395
            float* dz = d + j * 4;
1396
            float32x4_t g0 = vld1q_f32(gz);
1397
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1398
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1399
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1400
            g0 = vaddq_f32(g0, g0);
1401
            g0 = vaddq_f32(g0, g0);
1402
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1403
            g12x2 = vaddq_f32(g12x2, g12x2);
1404
            g12x2 = vaddq_f32(g12x2, g12);
1405
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1406
            /* row 2 */
1407
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1408
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1409
            g6x12 = vaddq_f32(g6x12, g6x12);
1410
            g6x12 = vaddq_f32(g6x12, g6x12);
1411
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1412
            /* row 3 */
1413
            g6x12 = vsubq_f32(g6, g12);
1414
            g6x12 = vaddq_f32(g6x12, g6x12);
1415
            g6x12 = vaddq_f32(g6x12, g6x12);
1416
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1417
            /* row 4 */
1418
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1419
            g18x6 = vaddq_f32(g18x6, g18x6);
1420
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1421
            /* row 5 */
1422
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1423
            /* row 6 */
1424
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1425
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1426
            g18x2 = vaddq_f32(g18x2, g18x2);
1427
            g18x2 = vaddq_f32(g18, g18x2);
1428
            g6 = vaddq_f32(g6, g6);
1429
            g6 = vaddq_f32(g6, g6);
1430
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1431
          } unroll_endfor
1432
          /* BT.d.B */
1433
          unroll_for(j, 6) {
1434
            float* gz = g + j * 6 * dimCx4;
1435
            const float* const dz = d + j * 24;
1436
            float32x4_t d0 = vld1q_f32(dz);
1437
            float32x4_t d1 = vld1q_f32(dz + 4);
1438
            float32x4_t d2 = vld1q_f32(dz + 8);
1439
            float32x4_t d3 = vld1q_f32(dz + 12);
1440
            float32x4_t d4 = vld1q_f32(dz + 16);
1441
            float32x4_t d5 = vld1q_f32(dz + 20);
1442
            d0 = vaddq_f32(d0, d0);
1443
            d0 = vaddq_f32(d0, d0);
1444
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1445
            d2x5 = vaddq_f32(d2x5, d2x5);
1446
            d2x5 = vaddq_f32(d2, d2x5);
1447
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1448
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1449
            d1x2 = vaddq_f32(d1x2, d1x2);
1450
            d1x2 = vaddq_f32(d1x2, d1x2);
1451
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1452
            d1x2 = vsubq_f32(d1, d2);
1453
            d1x2 = vaddq_f32(d1x2, d1x2);
1454
            d1x2 = vaddq_f32(d1x2, d1x2);
1455
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1456
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1457
            d3x1 = vaddq_f32(d3x1, d3x1);
1458
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1459
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1460
            d1 = vaddq_f32(d1, d1);
1461
            d1 = vaddq_f32(d1, d1);
1462
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1463
            d3x5 = vaddq_f32(d3x5, d3x5);
1464
            d3x5 = vaddq_f32(d3, d3x5);
1465
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1466
          } unroll_endfor
1467
          // move to the next channel
1468
          g += 4;
1469
        }
1470
        const float* wpz = gwtg;
1471
        for (k = 0; k < w->info.dim[0]; k += 4)
1472
        {
1473
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1474
#if FOR_IS_PARALLEL
1475
          g = btdb + i * 36 * dimCx4;
1476
#else
1477
          g = btdb;
1478
#endif
1479
          for (j = 0; j < 36; j++)
1480
          {
1481
            float32x4_t v40 = vmovq_n_f32(0);
1482
            float32x4_t v41 = vmovq_n_f32(0);
1483
            float32x4_t v42 = vmovq_n_f32(0);
1484
            float32x4_t v43 = vmovq_n_f32(0);
1485
            for (c = 0; c < adim[2]; c += 4)
1486
            {
1487
              float32x2x2_t g4 = vld2_f32(g);
1488
              float32x4_t w40 = vld1q_f32(wpz);
1489
              float32x4_t w41 = vld1q_f32(wpz + 4);
1490
              float32x4_t w42 = vld1q_f32(wpz + 8);
1491
              float32x4_t w43 = vld1q_f32(wpz + 12);
1492
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1493
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1494
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1495
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1496
              v40 = vmlaq_f32(v40, w40, g40);
1497
              v41 = vmlaq_f32(v41, w41, g41);
1498
              v42 = vmlaq_f32(v42, w42, g42);
1499
              v43 = vmlaq_f32(v43, w43, g43);
1500
              g += 4;
1501
              wpz += 16;
1502
            }
1503
            v40 = vaddq_f32(v40, v41);
1504
            v42 = vaddq_f32(v42, v43);
1505
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1506
          }
1507
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1508
          unroll_for(j, 6) {
1509
            const float* const qz = q + j * 4;
1510
            float* const dz = d + j * 4;
1511
            float32x4_t q0 = vld1q_f32(qz);
1512
            float32x4_t q6 = vld1q_f32(qz + 24);
1513
            float32x4_t q12 = vld1q_f32(qz + 48);
1514
            float32x4_t q18 = vld1q_f32(qz + 72);
1515
            float32x4_t q24 = vld1q_f32(qz + 96);
1516
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1517
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1518
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1519
            /* row 1 */
1520
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1521
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1522
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1523
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1524
            /* row 2 */
1525
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1526
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1527
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1528
            /* row 3 */
1529
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1530
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1531
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1532
            float32x4_t q30 = vld1q_f32(qz + 120);
1533
            /* row 4 */
1534
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1535
          } unroll_endfor
1536
          float* bpz = bp + x * bstride[2] + k;
1537
          float32x4_t bias4 = vld1q_f32(biasval + k);
1538
          switch (z[1]) {
1539
            case 1:
1540
              unroll_for(dy, z[0], 4) {
1541
                const float* const dz = d + dy * 24;
1542
                float32x4_t d0 = vld1q_f32(dz);
1543
                float32x4_t d1 = vld1q_f32(dz + 4);
1544
                float32x4_t d2 = vld1q_f32(dz + 8);
1545
                float32x4_t d3 = vld1q_f32(dz + 12);
1546
                float32x4_t d4 = vld1q_f32(dz + 16);
1547
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1548
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1549
                ds1x2 = vaddq_f32(ds1x2, bias4);
1550
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1551
                bpz += bstride[1];
1552
              } unroll_endfor
1553
              break;
1554
            case 2:
1555
              unroll_for(dy, z[0], 4) {
1556
                const float* const dz = d + dy * 24;
1557
                float32x4_t d0 = vld1q_f32(dz);
1558
                float32x4_t d1 = vld1q_f32(dz + 4);
1559
                float32x4_t d2 = vld1q_f32(dz + 8);
1560
                float32x4_t d3 = vld1q_f32(dz + 12);
1561
                float32x4_t d4 = vld1q_f32(dz + 16);
1562
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1563
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1564
                ds1x2 = vaddq_f32(ds1x2, bias4);
1565
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1566
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1567
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1568
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1569
                dn1x2 = vaddq_f32(dn1x2, bias4);
1570
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1571
                bpz += bstride[1];
1572
              } unroll_endfor
1573
              break;
1574
            case 3:
1575
              unroll_for(dy, z[0], 4) {
1576
                const float* const dz = d + dy * 24;
1577
                float32x4_t d0 = vld1q_f32(dz);
1578
                float32x4_t d1 = vld1q_f32(dz + 4);
1579
                float32x4_t d2 = vld1q_f32(dz + 8);
1580
                float32x4_t d3 = vld1q_f32(dz + 12);
1581
                float32x4_t d4 = vld1q_f32(dz + 16);
1582
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1583
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1584
                ds1x2 = vaddq_f32(ds1x2, bias4);
1585
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1586
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1587
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1588
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1589
                dn1x2 = vaddq_f32(dn1x2, bias4);
1590
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1591
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1592
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1593
                vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4));
1594
                bpz += bstride[1];
1595
              } unroll_endfor
1596
              break;
1597
            case 4:
1598
              unroll_for(dy, z[0], 4) {
1599
                const float* const dz = d + dy * 24;
1600
                float32x4_t d0 = vld1q_f32(dz);
1601
                float32x4_t d1 = vld1q_f32(dz + 4);
1602
                float32x4_t d2 = vld1q_f32(dz + 8);
1603
                float32x4_t d3 = vld1q_f32(dz + 12);
1604
                float32x4_t d4 = vld1q_f32(dz + 16);
1605
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1606
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1607
                ds1x2 = vaddq_f32(ds1x2, bias4);
1608
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1609
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1610
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1611
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1612
                dn1x2 = vaddq_f32(dn1x2, bias4);
1613
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1614
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1615
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1616
                vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4));
1617
                float32x4_t d5 = vld1q_f32(dz + 20);
1618
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1619
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1620
                vst1q_f32(bpz + 3 * bstride[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1621
                bpz += bstride[1];
1622
              } unroll_endfor
1623
              break;
1624
          };
1625
        }
1626
      }
1627
    } parallel_endfor
1628
  } else {
1629
    // This block will be executed in each for-loop iteration; therefore, you can use it to generate some temporary variables.
1630
    parallel_for(i, jump_dim) {
1631
      const int y = i * 4; // i is unsigned.
1632
      int j, x, k, c;
1633
      int n[CCV_NNC_MAX_DIM];
1634
      int m[CCV_NNC_MAX_DIM];
1635
      int z[CCV_NNC_MAX_DIM];
1636
      set_n_m_dim(y, 0, tile_dim, adim);
1637
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1638
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * astride[1];
1639
      float* bp = b->data.f32 + y * bstride[1];
1640
      for (x = 0; x < bdim[1]; x += 4)
1641
      {
1642
        set_n_m_dim(x, 1, tile_dim, adim);
1643
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1644
#if FOR_IS_PARALLEL
1645
        float* g = btdb + i * 36 * dimCx4;
1646
#else
1647
        float* g = btdb;
1648
#endif
1649
        // zero g such that we can have zero-padding.
1650
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1651
        int dx, dy;
1652
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * astride[2];
1653
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1654
        unroll_for(dy, m[0], 6) {
1655
          unroll_for(dx, m[1], 6) {
1656
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1657
            for (c = 0; c < adim[2]; c++)
1658
              gzu[c] = apz[dx * astride[2] + c];
1659
          } unroll_endfor
1660
          apz += astride[1];
1661
        } unroll_endfor
1662
        for (c = 0; c < adim[2]; c += 4)
1663
        {
1664
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1665
          /* BT.d */
1666
          unroll_for(j, 6) {
1667
            /* row 1 */
1668
            const float* const gz = g + j * dimCx4;
1669
            float* dz = d + j * 4;
1670
            float32x4_t g0 = vld1q_f32(gz);
1671
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1672
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1673
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1674
            g0 = vaddq_f32(g0, g0);
1675
            g0 = vaddq_f32(g0, g0);
1676
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1677
            g12x2 = vaddq_f32(g12x2, g12x2);
1678
            g12x2 = vaddq_f32(g12x2, g12);
1679
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1680
            /* row 2 */
1681
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1682
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1683
            g6x12 = vaddq_f32(g6x12, g6x12);
1684
            g6x12 = vaddq_f32(g6x12, g6x12);
1685
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1686
            /* row 3 */
1687
            g6x12 = vsubq_f32(g6, g12);
1688
            g6x12 = vaddq_f32(g6x12, g6x12);
1689
            g6x12 = vaddq_f32(g6x12, g6x12);
1690
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1691
            /* row 4 */
1692
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1693
            g18x6 = vaddq_f32(g18x6, g18x6);
1694
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1695
            /* row 5 */
1696
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1697
            /* row 6 */
1698
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1699
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1700
            g18x2 = vaddq_f32(g18x2, g18x2);
1701
            g18x2 = vaddq_f32(g18, g18x2);
1702
            g6 = vaddq_f32(g6, g6);
1703
            g6 = vaddq_f32(g6, g6);
1704
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1705
          } unroll_endfor
1706
          /* BT.d.B */
1707
          unroll_for(j, 6) {
1708
            float* gz = g + j * 6 * dimCx4;
1709
            const float* const dz = d + j * 24;
1710
            float32x4_t d0 = vld1q_f32(dz);
1711
            float32x4_t d1 = vld1q_f32(dz + 4);
1712
            float32x4_t d2 = vld1q_f32(dz + 8);
1713
            float32x4_t d3 = vld1q_f32(dz + 12);
1714
            float32x4_t d4 = vld1q_f32(dz + 16);
1715
            float32x4_t d5 = vld1q_f32(dz + 20);
1716
            d0 = vaddq_f32(d0, d0);
1717
            d0 = vaddq_f32(d0, d0);
1718
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1719
            d2x5 = vaddq_f32(d2x5, d2x5);
1720
            d2x5 = vaddq_f32(d2, d2x5);
1721
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1722
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1723
            d1x2 = vaddq_f32(d1x2, d1x2);
1724
            d1x2 = vaddq_f32(d1x2, d1x2);
1725
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1726
            d1x2 = vsubq_f32(d1, d2);
1727
            d1x2 = vaddq_f32(d1x2, d1x2);
1728
            d1x2 = vaddq_f32(d1x2, d1x2);
1729
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1730
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1731
            d3x1 = vaddq_f32(d3x1, d3x1);
1732
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1733
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1734
            d1 = vaddq_f32(d1, d1);
1735
            d1 = vaddq_f32(d1, d1);
1736
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1737
            d3x5 = vaddq_f32(d3x5, d3x5);
1738
            d3x5 = vaddq_f32(d3, d3x5);
1739
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1740
          } unroll_endfor
1741
          // move to the next channel
1742
          g += 4;
1743
        }
1744
        const float* wpz = gwtg;
1745
        for (k = 0; k < w->info.dim[0]; k += 4)
1746
        {
1747
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1748
#if FOR_IS_PARALLEL
1749
          g = btdb + i * 36 * dimCx4;
1750
#else
1751
          g = btdb;
1752
#endif
1753
          for (j = 0; j < 36; j++)
1754
          {
1755
            float32x4_t v40 = vmovq_n_f32(0);
1756
            float32x4_t v41 = vmovq_n_f32(0);
1757
            float32x4_t v42 = vmovq_n_f32(0);
1758
            float32x4_t v43 = vmovq_n_f32(0);
1759
            for (c = 0; c < adim[2]; c += 4)
1760
            {
1761
              float32x2x2_t g4 = vld2_f32(g);
1762
              float32x4_t w40 = vld1q_f32(wpz);
1763
              float32x4_t w41 = vld1q_f32(wpz + 4);
1764
              float32x4_t w42 = vld1q_f32(wpz + 8);
1765
              float32x4_t w43 = vld1q_f32(wpz + 12);
1766
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1767
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1768
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1769
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1770
              v40 = vmlaq_f32(v40, w40, g40);
1771
              v41 = vmlaq_f32(v41, w41, g41);
1772
              v42 = vmlaq_f32(v42, w42, g42);
1773
              v43 = vmlaq_f32(v43, w43, g43);
1774
              g += 4;
1775
              wpz += 16;
1776
            }
1777
            v40 = vaddq_f32(v40, v41);
1778
            v42 = vaddq_f32(v42, v43);
1779
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1780
          }
1781
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1782
          unroll_for(j, 6) {
1783
            const float* const qz = q + j * 4;
1784
            float* const dz = d + j * 4;
1785
            float32x4_t q0 = vld1q_f32(qz);
1786
            float32x4_t q6 = vld1q_f32(qz + 24);
1787
            float32x4_t q12 = vld1q_f32(qz + 48);
1788
            float32x4_t q18 = vld1q_f32(qz + 72);
1789
            float32x4_t q24 = vld1q_f32(qz + 96);
1790
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1791
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1792
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1793
            /* row 1 */
1794
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1795
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1796
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1797
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1798
            /* row 2 */
1799
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1800
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1801
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1802
            /* row 3 */
1803
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1804
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1805
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1806
            float32x4_t q30 = vld1q_f32(qz + 120);
1807
            /* row 4 */
1808
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1809
          } unroll_endfor
1810
          float* bpz = bp + x * bstride[2] + k;
1811
          switch (z[1]) {
1812
            case 1:
1813
              unroll_for(dy, z[0], 4) {
1814
                const float* const dz = d + dy * 24;
1815
                float32x4_t d0 = vld1q_f32(dz);
1816
                float32x4_t d1 = vld1q_f32(dz + 4);
1817
                float32x4_t d2 = vld1q_f32(dz + 8);
1818
                float32x4_t d3 = vld1q_f32(dz + 12);
1819
                float32x4_t d4 = vld1q_f32(dz + 16);
1820
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1821
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1822
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1823
                bpz += bstride[1];
1824
              } unroll_endfor
1825
              break;
1826
            case 2:
1827
              unroll_for(dy, z[0], 4) {
1828
                const float* const dz = d + dy * 24;
1829
                float32x4_t d0 = vld1q_f32(dz);
1830
                float32x4_t d1 = vld1q_f32(dz + 4);
1831
                float32x4_t d2 = vld1q_f32(dz + 8);
1832
                float32x4_t d3 = vld1q_f32(dz + 12);
1833
                float32x4_t d4 = vld1q_f32(dz + 16);
1834
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1835
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1836
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1837
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1838
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1839
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1840
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1841
                bpz += bstride[1];
1842
              } unroll_endfor
1843
              break;
1844
            case 3:
1845
              unroll_for(dy, z[0], 4) {
1846
                const float* const dz = d + dy * 24;
1847
                float32x4_t d0 = vld1q_f32(dz);
1848
                float32x4_t d1 = vld1q_f32(dz + 4);
1849
                float32x4_t d2 = vld1q_f32(dz + 8);
1850
                float32x4_t d3 = vld1q_f32(dz + 12);
1851
                float32x4_t d4 = vld1q_f32(dz + 16);
1852
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1853
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1854
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1855
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1856
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1857
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1858
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1859
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1860
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1861
                vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4));
1862
                bpz += bstride[1];
1863
              } unroll_endfor
1864
              break;
1865
            case 4:
1866
              unroll_for(dy, z[0], 4) {
1867
                const float* const dz = d + dy * 24;
1868
                float32x4_t d0 = vld1q_f32(dz);
1869
                float32x4_t d1 = vld1q_f32(dz + 4);
1870
                float32x4_t d2 = vld1q_f32(dz + 8);
1871
                float32x4_t d3 = vld1q_f32(dz + 12);
1872
                float32x4_t d4 = vld1q_f32(dz + 16);
1873
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1874
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1875
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1876
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1877
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1878
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1879
                vst1q_f32(bpz + bstride[2], vaddq_f32(dn1x2, dn3x4));
1880
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1881
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1882
                vst1q_f32(bpz + 2 * bstride[2], vaddq_f32(ds1x2, ds3x4));
1883
                float32x4_t d5 = vld1q_f32(dz + 20);
1884
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1885
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1886
                vst1q_f32(bpz + 3 * bstride[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1887
                bpz += bstride[1];
1888
              } unroll_endfor
1889
              break;
1890
          };
1891
        }
1892
      }
1893
    } parallel_endfor
1894
  }
1895
  return CCV_NNC_EXEC_SUCCESS;
1896
}
1897
#endif
1898
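
Both SIMD paths finish a tile with the F(4x4,3x3) output transform: a 4x6 partial result (AT.q) followed by a per-row column combination (.A) that the switch over z[1] writes out, clipped to the valid output width. A minimal scalar sketch for one dense 6x6 tile of transform-domain products and a single output channel, bias omitted; the helper name is illustrative.

static void winograd_4x4_3x3_output_transform_ref(const float m[36], float out[16])
{
  /* AT.m : 4 rows x 6 columns. */
  float t[24];
  int i;
  for (i = 0; i < 6; i++)
  {
    const float m0 = m[i], m1 = m[6 + i], m2 = m[12 + i], m3 = m[18 + i], m4 = m[24 + i], m5 = m[30 + i];
    t[i] = m0 + m1 + m2 + m3 + m4;
    t[6 + i] = (m1 - m2) + 2 * (m3 - m4);
    t[12 + i] = (m1 + m2) + 4 * (m3 + m4);
    t[18 + i] = (m1 - m2) + 8 * (m3 - m4) + m5;
  }
  /* (AT.m).A : collapse the 6 columns of each row into 4 outputs. */
  for (i = 0; i < 4; i++)
  {
    const float* const r = t + i * 6;
    float* const o = out + i * 4;
    o[0] = r[0] + r[1] + r[2] + r[3] + r[4];
    o[1] = (r[1] - r[2]) + 2 * (r[3] - r[4]);
    o[2] = (r[1] + r[2]) + 4 * (r[3] + r[4]);
    o[3] = (r[1] - r[2]) + 8 * (r[3] - r[4]) + r[5];
  }
}

In the bias path the bias vector is folded into ds1x2 and dn1x2 before the column combination, so each output element picks the bias up exactly once.
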
1899
int _ccv_nnc_conv_forw_4x4_3x3_winograd_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1900
119
{
1901
119
#if defined(HAVE_SSE2)
1902
119
  if (w->info.dim[0] % 4 == 0)
1903
119
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(a, w, bias, hint, b, stream_context);
1904
#elif defined(HAVE_NEON)
1905
  if (w->info.dim[0] % 4 == 0)
1906
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(a, w, bias, hint, b, stream_context);
1907
#endif
1908
0
  return _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(a, w, bias, hint, b, stream_context);
1909
119
}