Coverage Report

Created: 2022-08-03 23:52

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/cmd/convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c
Line | Count | Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#if defined(HAVE_SSE2)
7
#include <xmmintrin.h>
8
#elif defined(HAVE_NEON)
9
#include <arm_neon.h>
10
#endif
11
#ifdef USE_OPENMP
12
#include <omp.h>
13
#endif
14
#ifdef USE_DISPATCH
15
#include <dispatch/dispatch.h>
16
#endif
17
#include "../_ccv_nnc_conv_cpu_opt.h"
18
19
#define set_n_m_dim(i, x, wd, ad) \
20
56.4k
  do { \
21
56.4k
    n[x] = ccv_max((i) * hint.stride.dim[x] - hint.border.begin[x], 0) - ((i) * hint.stride.dim[x] - hint.border.begin[x]); \
22
56.4k
    m[x] = wd[x + 1] - n[x] - ((i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1] - ccv_min(ad[x], (i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1])); \
23
56.3k
  } while (0)
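
A note on the macro above: set_n_m_dim clips one axis of a 6x6 input tile against the padded border and the input extent. n[x] counts the leading rows/columns of the tile that fall inside the padding, and m[x] counts how many rows/columns of the tile overlap valid input. A minimal standalone sketch of the same arithmetic (hypothetical helper, plain C, not part of the library):

    /* Scalar equivalent of set_n_m_dim for a single axis. */
    static void tile_clip(int i, int stride, int border, int tile, int extent, int* n, int* m)
    {
        const int start = i * stride - border;         /* tile origin in input coordinates */
        *n = start < 0 ? -start : 0;                   /* rows lost to the top/left padding */
        const int over = start + tile > extent ? start + tile - extent : 0; /* rows past the input edge */
        *m = tile - *n - over;                         /* rows that read real input data */
    }

    /* Example: tile origin i = 0 with stride 1, border 1, tile 6, input extent 21:
     * start = -1, so n = 1 and m = 6 - 1 - 0 = 5 of the 6 tile rows carry real data. */
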
24
25
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_ref(const float* const w, const int c, float* gwtg)
26
0
{
27
0
  int i;
28
0
  for (i = 0; i < c; i++)
29
0
  {
30
0
    float g[18];
31
    /*
32
     * a0, b1, c2
33
     * d3, e4, f5
34
     * g6, h7, i8
35
     * {{a/4, b/4, c/4},
36
     * {1/6 (-a - d - g), 1/6 (-b - e - h), 1/6 (-c - f - i)},
37
     * {1/6 (-a + d - g), 1/6 (-b + e - h), 1/6 (-c + f - i)},
38
     * {1/24 (a + 2 d + 4 g), 1/24 (b + 2 e + 4 h), 1/24 (c + 2 f + 4 i)},
39
     * {1/24 (a - 2 d + 4 g), 1/24 (b - 2 e + 4 h), 1/24 (c - 2 f + 4 i)},
40
     * {g, h, i}}
41
     */
42
    /* row 1 */
43
0
    g[0] = w[i] / 4;
44
0
    g[1] = w[c + i] / 4;
45
0
    g[2] = w[2 * c + i] / 4;
46
    /* row 2 */
47
0
    g[3] = -(w[i] + w[3 * c + i] + w[6 * c + i]) / 6;
48
0
    g[4] = -(w[c + i] + w[4 * c + i] + w[7 * c + i]) / 6;
49
0
    g[5] = -(w[2 * c + i] + w[5 * c + i] + w[8 * c + i]) / 6;
50
    /* row 3 */
51
0
    g[6] = (-w[i] + w[3 * c + i] - w[6 * c + i]) / 6;
52
0
    g[7] = (-w[c + i] + w[4 * c + i] - w[7 * c + i]) / 6;
53
0
    g[8] = (-w[2 * c + i] + w[5 * c + i] - w[8 * c + i]) / 6;
54
    /* row 4 */
55
0
    g[9] = (w[i] + 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
56
0
    g[10] = (w[c + i] + 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
57
0
    g[11] = (w[2 * c + i] + 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
58
    /* row 5 */
59
0
    g[12] = (w[i] - 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
60
0
    g[13] = (w[c + i] - 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
61
0
    g[14] = (w[2 * c + i] - 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
62
    /* row 6 */
63
0
    g[15] = w[6 * c + i];
64
0
    g[16] = w[7 * c + i];
65
0
    g[17] = w[8 * c + i];
66
    /*
67
     * a0, b1, c2
68
     * d3, e4, f5
69
     * g6, h7, i8
70
     * j9, k10,l11
71
     * m12,n13,o14
72
     * p15,q16,r17
73
     * {{a/4, 1/6 (-a - b - c), 1/6 (-a + b - c), 1/24 (a + 2 b + 4 c), 1/24 (a - 2 b + 4 c), c},
74
     * {d/4, 1/6 (-d - e - f), 1/6 (-d + e - f), 1/24 (d + 2 e + 4 f), 1/24 (d - 2 e + 4 f), f},
75
     * {g/4, 1/6 (-g - h - i), 1/6 (-g + h - i), 1/24 (g + 2 h + 4 i), 1/24 (g - 2 h + 4 i), i},
76
     * {j/4, 1/6 (-j - k - l), 1/6 (-j + k - l), 1/24 (j + 2 k + 4 l), 1/24 (j - 2 k + 4 l), l},
77
     * {m/4, 1/6 (-m - n - o), 1/6 (-m + n - o), 1/24 (m + 2 n + 4 o), 1/24 (m - 2 n + 4 o), o},
78
     * {p/4, 1/6 (-p - q - r), 1/6 (-p + q - r), 1/24 (p + 2 q + 4 r), 1/24 (p - 2 q + 4 r), r}}
79
     */
80
    /* row 1 */
81
0
    gwtg[0] = g[0] / 4;
82
0
    gwtg[c] = -(g[0] + g[1] + g[2]) / 6;
83
0
    gwtg[2 * c] = (-g[0] + g[1] - g[2]) / 6;
84
0
    gwtg[3 * c] = (g[0] + 2 * g[1] + 4 * g[2]) / 24;
85
0
    gwtg[4 * c] = (g[0] - 2 * g[1] + 4 * g[2]) / 24;
86
0
    gwtg[5 * c] = g[2];
87
    /* row 2 */
88
0
    gwtg[6 * c] = g[3] / 4;
89
0
    gwtg[7 * c] = -(g[3] + g[4] + g[5]) / 6;
90
0
    gwtg[8 * c] = (-g[3] + g[4] - g[5]) / 6;
91
0
    gwtg[9 * c] = (g[3] + 2 * g[4] + 4 * g[5]) / 24;
92
0
    gwtg[10 * c] = (g[3] - 2 * g[4] + 4 * g[5]) / 24;
93
0
    gwtg[11 * c] = g[5];
94
    /* row 3 */
95
0
    gwtg[12 * c] = g[6] / 4;
96
0
    gwtg[13 * c] = -(g[6] + g[7] + g[8]) / 6;
97
0
    gwtg[14 * c] = (-g[6] + g[7] - g[8]) / 6;
98
0
    gwtg[15 * c] = (g[6] + 2 * g[7] + 4 * g[8]) / 24;
99
0
    gwtg[16 * c] = (g[6] - 2 * g[7] + 4 * g[8]) / 24;
100
0
    gwtg[17 * c] = g[8];
101
    /* row 4 */
102
0
    gwtg[18 * c] = g[9] / 4;
103
0
    gwtg[19 * c] = -(g[9] + g[10] + g[11]) / 6;
104
0
    gwtg[20 * c] = (-g[9] + g[10] - g[11]) / 6;
105
0
    gwtg[21 * c] = (g[9] + 2 * g[10] + 4 * g[11]) / 24;
106
0
    gwtg[22 * c] = (g[9] - 2 * g[10] + 4 * g[11]) / 24;
107
0
    gwtg[23 * c] = g[11];
108
    /* row 5 */
109
0
    gwtg[24 * c] = g[12] / 4;
110
0
    gwtg[25 * c] = -(g[12] + g[13] + g[14]) / 6;
111
0
    gwtg[26 * c] = (-g[12] + g[13] - g[14]) / 6;
112
0
    gwtg[27 * c] = (g[12] + 2 * g[13] + 4 * g[14]) / 24;
113
0
    gwtg[28 * c] = (g[12] - 2 * g[13] + 4 * g[14]) / 24;
114
0
    gwtg[29 * c] = g[14];
115
    /* row 6 */
116
0
    gwtg[30 * c] = g[15] / 4;
117
0
    gwtg[31 * c] = -(g[15] + g[16] + g[17]) / 6;
118
0
    gwtg[32 * c] = (-g[15] + g[16] - g[17]) / 6;
119
0
    gwtg[33 * c] = (g[15] + 2 * g[16] + 4 * g[17]) / 24;
120
0
    gwtg[34 * c] = (g[15] - 2 * g[16] + 4 * g[17]) / 24;
121
0
    gwtg[35 * c] = g[17];
122
0
    ++gwtg;
123
0
  }
124
0
}
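
The reference kernel transform above writes out G.w.T(G) by hand, channel by channel. As a cross-check, the same transform for a single 3x3 kernel can be phrased as two small matrix products; a hedged sketch using the 6x3 G whose rows can be read off the code above (it does not reproduce the channel-interleaved layout the function actually writes):

    /* Sketch: u = G * w * G^T for one 3x3 kernel w (row-major); u is the 6x6 Winograd-domain kernel. */
    static void winograd_gwg_3x3(const float w[9], float u[36])
    {
        static const float G[6][3] = {
            {  1.f / 4,          0,        0 },
            { -1.f / 6,   -1.f / 6, -1.f / 6 },
            { -1.f / 6,    1.f / 6, -1.f / 6 },
            {  1.f / 24,  1.f / 12,  1.f / 6 },
            {  1.f / 24, -1.f / 12,  1.f / 6 },
            {         0,          0,        1 },
        };
        float t[6][3];
        int i, j, k;
        for (i = 0; i < 6; i++) /* t = G * w */
            for (j = 0; j < 3; j++)
            {
                t[i][j] = 0;
                for (k = 0; k < 3; k++)
                    t[i][j] += G[i][k] * w[k * 3 + j];
            }
        for (i = 0; i < 6; i++) /* u = t * G^T */
            for (j = 0; j < 6; j++)
            {
                float s = 0;
                for (k = 0; k < 3; k++)
                    s += t[i][k] * G[j][k];
                u[i * 6 + j] = s;
            }
    }
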
125
126
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
127
0
{
128
0
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
129
0
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
130
0
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
131
0
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
132
0
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
133
0
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
134
0
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
135
0
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
136
0
  assert(hint.border.begin[0] <= 1);
137
0
  assert(hint.border.begin[1] <= 1);
138
0
  assert(w->info.dim[1] == 3);
139
0
  assert(w->info.dim[2] == 3);
140
0
  const int jump_dim = (bdim[0] + 3) / 4;
141
0
  float* workmem;
142
  // allocating workspace memory for kernel reshaping and input reshaping.
143
0
#if FOR_IS_PARALLEL
144
  // If we do parallel for, we need to allocate input reshaping for each block.
145
0
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] * jump_dim + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
146
#else
147
  // Otherwise, just one block.
148
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
149
#endif
150
0
  if (!workmem)
151
0
    return CCV_NNC_EXEC_OOM;
152
  // Convert w to a 6x6 matrix by computing G.w.T(G) (T for transpose).
153
0
  float* const gwtg = workmem;
154
0
  float* const btdb = workmem + 36 * w->info.dim[0] * w->info.dim[3];
155
0
  parallel_for(k, w->info.dim[0]) {
156
0
    _ccv_nnc_winograd_4x4_3x3_gwtg_ref(w->data.f32 + k * w->info.dim[3] * w->info.dim[2] * w->info.dim[1], w->info.dim[3], gwtg + k * 36 * w->info.dim[3]);
157
0
  } parallel_endfor
158
  // kernel weight for one dim.
159
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
160
0
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
161
0
    w->info.dim[0], 6, 6, w->info.dim[3]
162
0
  };
163
0
  const int* const tile_dim = tile_dim_s;
164
  // This block will be executed in each for-loop iteration, so you can use it to generate some temporary variables.
165
0
  if (bias)
166
0
  {
167
0
    const float* const biasval = bias->data.f32;
168
0
    parallel_for(i, jump_dim) {
169
0
      const int y = i * 4; // i is unsigned.
170
0
      int j, x, k, c;
171
0
      int n[CCV_NNC_MAX_DIM];
172
0
      int m[CCV_NNC_MAX_DIM];
173
0
      int z[CCV_NNC_MAX_DIM];
174
0
      set_n_m_dim(y, 0, tile_dim, adim);
175
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
176
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
177
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
178
0
      for (x = 0; x < bdim[1]; x += 4)
179
0
      {
180
0
        set_n_m_dim(x, 1, tile_dim, adim);
181
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
182
0
#if FOR_IS_PARALLEL
183
0
        float* g = btdb + i * 36 * adim[2];
184
#else
185
        float* g = btdb;
186
#endif
187
        // zero g such that we can have zero-padding.
188
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
189
0
        int dx, dy;
190
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
191
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
192
0
        unroll_for(dy, m[0], 6) {
193
0
          unroll_for(dx, m[1], 6) {
194
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
195
0
            for (c = 0; c < adim[2]; c++)
196
0
              gzu[c] = apz[dx * ainc[2] + c];
197
0
          } unroll_endfor
198
0
          apz += ainc[1] * ainc[2];
199
0
        } unroll_endfor
200
0
        for (c = 0; c < adim[2]; c++)
201
0
        {
202
          /*
203
           * a0, a1, a2, a3, a4, a5,
204
           * b6, b7, b8, b9, b10,b11,
205
           * c12,c13,c14,c15,c16,c17,
206
           * d18,d19,d20,d21,d22,d23,
207
           * e24,e25,e26,e27,e28,e29,
208
           * f30,f31,f32,f33,f34,f35
209
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
210
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
211
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
212
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
213
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
214
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
215
           */
216
0
          float d[36];
217
          /* BT.d */
218
0
          unroll_for(j, 6) {
219
0
            float g0 = g[j * adim[2]];
220
0
            float g12 = g[(12 + j) * adim[2]];
221
0
            float g24 = g[(24 + j) * adim[2]];
222
            /* row 1 */
223
0
            d[j] = 4 * g0 - 5 * g12 + g24;
224
0
            float g6 = g[(6 + j) * adim[2]];
225
0
            float g18 = g[(18 + j) * adim[2]];
226
            /* row 2 */
227
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
228
            /* row 3 */
229
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
230
            /* row 4 */
231
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
232
            /* row 5 */
233
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
234
0
            float g30 = g[(30 + j) * adim[2]];
235
            /* row 6 */
236
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
237
0
          } unroll_endfor
238
          /*
239
           * a0, a1, a2, a3, a4, a5,
240
           * b6, b7, b8, b9, b10,b11,
241
           * c12,c13,c14,c15,c16,c17,
242
           * d18,d19,d20,d21,d22,d23,
243
           * e24,e25,e26,e27,e28,e29,
244
           * f30,f31,f32,f33,f34,f35
245
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
246
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
247
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
248
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
249
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
250
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
251
           */
252
          /* BT.d.B */
253
0
          unroll_for(j, 6) {
254
            /* row 1 - 6 */
255
0
            float* const gz = g + j * 6 * adim[2];
256
0
            float* const dz = d + j * 6;
257
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
258
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
259
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
260
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
261
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
262
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
263
0
          } unroll_endfor
264
          // move to the next channel
265
0
          ++g;
266
0
        }
267
0
        const float* wpz = gwtg;
268
0
        for (k = 0; k < w->info.dim[0]; k++)
269
0
        {
270
0
          float q[36];
271
0
#if FOR_IS_PARALLEL
272
0
          g = btdb + i * 36 * adim[2];
273
#else
274
          g = btdb;
275
#endif
276
0
          for (j = 0; j < 36; j++)
277
0
          {
278
0
            float b = 0;
279
0
            for (c = 0; c < adim[2]; c++)
280
0
              b += g[c] * wpz[c];
281
0
            q[j] = b;
282
0
            g += adim[2];
283
0
            wpz += adim[2];
284
0
          }
285
          /*
286
           * a0, a1, a2, a3, a4, a5,
287
           * b6, b7, b8, b9, b10,b11,
288
           * c12,c13,c14,c15,c16,c17,
289
           * d18,d19,d20,d21,d22,d23,
290
           * e24,e25,e26,e27,e28,e29,
291
           * f30,f31,f32,f33,f34,f35
292
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
293
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
294
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
295
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
296
           */
297
0
          float d[24];
298
          /* row 1 */
299
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
300
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
301
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
302
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
303
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
304
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
305
          /* row 2 */
306
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
307
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
308
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
309
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
310
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
311
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
312
          /* row 3 */
313
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
314
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
315
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
316
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
317
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
318
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
319
          /* row 4 */
320
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
321
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
322
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
323
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
324
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
325
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
326
          /*
327
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
328
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
329
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
330
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
331
           */
332
0
          float* bpz = bp + x * binc[2] + k;
333
0
          unroll_for(dy, z[0], 4) {
334
0
            float r[] = {
335
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4] + biasval[k],
336
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + biasval[k],
337
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]) + biasval[k],
338
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5] + biasval[k],
339
0
            };
340
0
            unroll_for(dx, z[1], 4) {
341
0
              bpz[dx * binc[2]] = r[dx];
342
0
            } unroll_endfor
343
0
            bpz += binc[1] * binc[2];
344
0
          } unroll_endfor
345
0
        }
346
0
      }
347
0
    } parallel_endfor
348
0
  } else {
349
0
    parallel_for(i, jump_dim) {
350
0
      const int y = i * 4; // i is unsigned.
351
0
      int j, x, k, c;
352
0
      int n[CCV_NNC_MAX_DIM];
353
0
      int m[CCV_NNC_MAX_DIM];
354
0
      int z[CCV_NNC_MAX_DIM];
355
0
      set_n_m_dim(y, 0, tile_dim, adim);
356
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
357
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
358
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
359
0
      for (x = 0; x < bdim[1]; x += 4)
360
0
      {
361
0
        set_n_m_dim(x, 1, tile_dim, adim);
362
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
363
0
#if FOR_IS_PARALLEL
364
0
        float* g = btdb + i * 36 * adim[2];
365
#else
366
        float* g = btdb;
367
#endif
368
        // zero g such that we can have zero-padding.
369
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
370
0
        int dx, dy;
371
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
372
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
373
0
        unroll_for(dy, m[0], 6) {
374
0
          unroll_for(dx, m[1], 6) {
375
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
376
0
            for (c = 0; c < adim[2]; c++)
377
0
              gzu[c] = apz[dx * ainc[2] + c];
378
0
          } unroll_endfor
379
0
          apz += ainc[1] * ainc[2];
380
0
        } unroll_endfor
381
0
        for (c = 0; c < adim[2]; c++)
382
0
        {
383
          /*
384
           * a0, a1, a2, a3, a4, a5,
385
           * b6, b7, b8, b9, b10,b11,
386
           * c12,c13,c14,c15,c16,c17,
387
           * d18,d19,d20,d21,d22,d23,
388
           * e24,e25,e26,e27,e28,e29,
389
           * f30,f31,f32,f33,f34,f35
390
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
391
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
392
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
393
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
394
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
395
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
396
           */
397
0
          float d[36];
398
          /* BT.d */
399
0
          unroll_for(j, 6) {
400
0
            float g0 = g[j * adim[2]];
401
0
            float g12 = g[(12 + j) * adim[2]];
402
0
            float g24 = g[(24 + j) * adim[2]];
403
            /* row 1 */
404
0
            d[j] = 4 * g0 - 5 * g12 + g24;
405
0
            float g6 = g[(6 + j) * adim[2]];
406
0
            float g18 = g[(18 + j) * adim[2]];
407
            /* row 2 */
408
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
409
            /* row 3 */
410
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
411
            /* row 4 */
412
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
413
            /* row 5 */
414
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
415
0
            float g30 = g[(30 + j) * adim[2]];
416
            /* row 6 */
417
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
418
0
          } unroll_endfor
419
          /*
420
           * a0, a1, a2, a3, a4, a5,
421
           * b6, b7, b8, b9, b10,b11,
422
           * c12,c13,c14,c15,c16,c17,
423
           * d18,d19,d20,d21,d22,d23,
424
           * e24,e25,e26,e27,e28,e29,
425
           * f30,f31,f32,f33,f34,f35
426
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
427
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
428
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
429
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
430
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
431
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
432
           */
433
          /* BT.d.B */
434
0
          unroll_for(j, 6) {
435
            /* row 1 - 6 */
436
0
            float* const gz = g + j * 6 * adim[2];
437
0
            float* const dz = d + j * 6;
438
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
439
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
440
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
441
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
442
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
443
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
444
0
          } unroll_endfor
445
          // move to the next channel
446
0
          ++g;
447
0
        }
448
0
        const float* wpz = gwtg;
449
0
        for (k = 0; k < w->info.dim[0]; k++)
450
0
        {
451
0
          float q[36];
452
0
#if FOR_IS_PARALLEL
453
0
          g = btdb + i * 36 * adim[2];
454
#else
455
          g = btdb;
456
#endif
457
0
          for (j = 0; j < 36; j++)
458
0
          {
459
0
            float b = 0;
460
0
            for (c = 0; c < adim[2]; c++)
461
0
              b += g[c] * wpz[c];
462
0
            q[j] = b;
463
0
            g += adim[2];
464
0
            wpz += adim[2];
465
0
          }
466
          /*
467
           * a0, a1, a2, a3, a4, a5,
468
           * b6, b7, b8, b9, b10,b11,
469
           * c12,c13,c14,c15,c16,c17,
470
           * d18,d19,d20,d21,d22,d23,
471
           * e24,e25,e26,e27,e28,e29,
472
           * f30,f31,f32,f33,f34,f35
473
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
474
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
475
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
476
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
477
           */
478
0
          float d[24];
479
          /* row 1 */
480
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
481
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
482
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
483
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
484
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
485
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
486
          /* row 2 */
487
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
488
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
489
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
490
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
491
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
492
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
493
          /* row 3 */
494
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
495
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
496
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
497
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
498
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
499
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
500
          /* row 4 */
501
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
502
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
503
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
504
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
505
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
506
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
507
          /*
508
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
509
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
510
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
511
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
512
           */
513
0
          float* bpz = bp + x * binc[2] + k;
514
0
          unroll_for(dy, z[0], 4) {
515
0
            float r[] = {
516
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4],
517
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]),
518
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]),
519
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5],
520
0
            };
521
0
            unroll_for(dx, z[1], 4) {
522
0
              bpz[dx * binc[2]] = r[dx];
523
0
            } unroll_endfor
524
0
            bpz += binc[1] * binc[2];
525
0
          } unroll_endfor
526
0
        }
527
0
      }
528
0
    } parallel_endfor
529
0
  }
530
0
  return CCV_NNC_EXEC_SUCCESS;
531
0
}
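
For reference, the per-tile arithmetic in the function above matches the standard F(4x4, 3x3) Winograd transforms. Reading the coefficients off the loops, the input transform (the BT.d and BT.d.B passes) and the output transform (the final reduction into the 4x4 block, plus bias) use the following matrices; this is a documentation sketch only, the code never materializes them:

    /* Input transform: each 6x6 input tile d becomes BT * d * B. */
    static const float BT[6][6] = {
        { 4,  0, -5,  0, 1, 0 },
        { 0, -4, -4,  1, 1, 0 },
        { 0,  4, -4, -1, 1, 0 },
        { 0, -2, -1,  2, 1, 0 },
        { 0,  2, -1, -2, 1, 0 },
        { 0,  4,  0, -5, 0, 1 },
    };
    /* Output transform: each accumulated 6x6 tile q becomes the 4x4 output AT * q * A. */
    static const float AT[4][6] = {
        { 1, 1,  1, 1,  1, 0 },
        { 0, 1, -1, 2, -2, 0 },
        { 0, 1,  1, 4,  4, 0 },
        { 0, 1, -1, 8, -8, 1 },
    };
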
532
533
#ifdef HAVE_SSE2
534
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(const float* const w, const int* const dim, float* const gwtg)
535
100
{
536
100
  const int jump_dim = dim[0] / 4;
537
100
  const int dimCx4 = (dim[3] + 3) & -4;
538
8.23k
  parallel_for(k, jump_dim) {
539
8.23k
    int i, j;
540
8.23k
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
541
8.23k
    const float* wz[] = {
542
8.23k
      w + (k * 4) * 9 * dim[3],
543
8.23k
      w + (k * 4 + 1) * 9 * dim[3],
544
8.23k
      w + (k * 4 + 2) * 9 * dim[3],
545
8.23k
      w + (k * 4 + 3) * 9 * dim[3],
546
8.23k
    };
547
1.70M
    for (i = 0; i < dim[3]; i++)
548
1.69M
    {
549
1.69M
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
550
14.4M
      unroll_for(j, 9) {
551
14.4M
        x9w[j * 4] = wz[0][j * dim[3] + i];
552
14.4M
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
553
14.4M
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
554
14.4M
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
555
14.4M
      } unroll_endfor
556
1.69M
      float g[18 * 4] __attribute__ ((__aligned__(16)));
557
1.69M
      __m128 x9w0 = _mm_load_ps(x9w);
558
1.69M
      __m128 x9w1 = _mm_load_ps(x9w + 4);
559
1.69M
      __m128 x9w2 = _mm_load_ps(x9w + 8);
560
1.69M
      __m128 x9w3 = _mm_load_ps(x9w + 12);
561
1.69M
      __m128 x9w4 = _mm_load_ps(x9w + 16);
562
1.69M
      __m128 x9w5 = _mm_load_ps(x9w + 20);
563
1.69M
      __m128 x9w6 = _mm_load_ps(x9w + 24);
564
1.69M
      __m128 x9w7 = _mm_load_ps(x9w + 28);
565
1.69M
      __m128 x9w8 = _mm_load_ps(x9w + 32);
566
      /* row 1 */
567
1.69M
      __m128 c1_4 = _mm_set1_ps(1.0 / 4);
568
1.69M
      _mm_store_ps(g, _mm_mul_ps(x9w0, c1_4));
569
1.69M
      _mm_store_ps(g + 4, _mm_mul_ps(x9w1, c1_4));
570
1.69M
      _mm_store_ps(g + 8, _mm_mul_ps(x9w2, c1_4));
571
      /* row 2 */
572
1.69M
      __m128 cn1_6 = _mm_set1_ps(-1.0 / 6);
573
1.69M
      _mm_store_ps(g + 12, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
574
1.69M
      _mm_store_ps(g + 16, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
575
1.69M
      _mm_store_ps(g + 20, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
576
      /* row 3 */
577
1.69M
      _mm_store_ps(g + 24, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
578
1.69M
      _mm_store_ps(g + 28, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
579
1.69M
      _mm_store_ps(g + 32, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
580
      /* row 6 */
581
1.69M
      _mm_store_ps(g + 60, x9w6);
582
1.69M
      _mm_store_ps(g + 64, x9w7);
583
1.69M
      _mm_store_ps(g + 68, x9w8);
584
      /* w[x] * 2 */
585
1.69M
      x9w3 = _mm_add_ps(x9w3, x9w3);
586
1.69M
      x9w4 = _mm_add_ps(x9w4, x9w4);
587
1.69M
      x9w5 = _mm_add_ps(x9w5, x9w5);
588
      /* w[x] * 4 */
589
1.69M
      x9w6 = _mm_add_ps(x9w6, x9w6);
590
1.69M
      x9w6 = _mm_add_ps(x9w6, x9w6);
591
1.69M
      x9w7 = _mm_add_ps(x9w7, x9w7);
592
1.69M
      x9w7 = _mm_add_ps(x9w7, x9w7);
593
1.69M
      x9w8 = _mm_add_ps(x9w8, x9w8);
594
1.69M
      x9w8 = _mm_add_ps(x9w8, x9w8);
595
      /* row 4 */
596
1.69M
      __m128 c1_24 = _mm_set1_ps(1.0 / 24);
597
1.69M
      _mm_store_ps(g + 36, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
598
1.69M
      _mm_store_ps(g + 40, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
599
1.69M
      _mm_store_ps(g + 44, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
600
      /* row 5 */
601
1.69M
      _mm_store_ps(g + 48, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
602
1.69M
      _mm_store_ps(g + 52, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
603
1.69M
      _mm_store_ps(g + 56, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
604
4.08M
      unroll_for(j, 6) {
605
4.08M
        const float* const gz = g + j * 12;
606
4.08M
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
607
4.08M
        __m128 g0 = _mm_load_ps(gz);
608
4.08M
        __m128 g1 = _mm_load_ps(gz + 4);
609
4.08M
        __m128 g2 = _mm_load_ps(gz + 8);
610
4.08M
        _mm_store_ps(gwtgzu, _mm_mul_ps(g0, c1_4));
611
4.08M
        _mm_store_ps(gwtgzu + 4 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), cn1_6));
612
4.08M
        _mm_store_ps(gwtgzu + 8 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), cn1_6));
613
4.08M
        _mm_store_ps(gwtgzu + 20 * dimCx4, g2);
614
        /* g[1] * 2 */
615
4.08M
        g1 = _mm_add_ps(g1, g1);
616
        /* g[2] * 4 */
617
4.08M
        g2 = _mm_add_ps(g2, g2);
618
4.08M
        g2 = _mm_add_ps(g2, g2);
619
4.08M
        _mm_store_ps(gwtgzu + 12 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), c1_24));
620
4.08M
        _mm_store_ps(gwtgzu + 16 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), c1_24));
621
4.08M
      } unroll_endfor
622
1.69M
      gwtgz += 4;
623
1.69M
    }
624
8.23k
  } parallel_endfor
625
100
}
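
The SSE2 kernel transform above packs four output channels per __m128 lane and rounds the channel dimension up to a multiple of four (dimCx4 = (dim[3] + 3) & -4) so every vector load and store stays 16-byte aligned. A quick sanity check of that rounding expression (a standalone sketch; it assumes two's-complement int, which the code already relies on):

    #include <assert.h>

    static int round_up4(int c)
    {
        return (c + 3) & -4; /* same expression as dimCx4 */
    }

    int main(void)
    {
        assert(round_up4(1) == 4);
        assert(round_up4(4) == 4);
        assert(round_up4(5) == 8);
        return 0;
    }
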
626
627
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
628
100
{
629
100
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
630
100
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
631
100
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
632
100
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
633
100
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
634
100
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
635
100
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
636
100
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
637
100
  assert(hint.border.begin[0] <= 1);
638
100
  assert(hint.border.begin[1] <= 1);
639
100
  assert(w->info.dim[0] % 4 == 0);
640
100
  assert(w->info.dim[1] == 3);
641
100
  assert(w->info.dim[2] == 3);
642
100
  const int jump_dim = (bdim[0] + 3) / 4;
643
100
  const int dimCx4 = (adim[2] + 3) & -4;
644
  // allocating workspace memory for kernel reshaping and input reshaping.
645
100
  float* workmem = 0;
646
100
#if FOR_IS_PARALLEL
647
  // If we do parallel for, we need to allocate input reshaping for each block.
648
100
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
649
#else
650
  // Otherwise, just one block.
651
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
652
#endif
653
100
  if (!workmem)
654
0
    return CCV_NNC_EXEC_OOM;
655
  // Convert w to a 6x6 matrix by computing G.w.T(G) (T for transpose).
656
100
  float* const gwtg = workmem;
657
100
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
658
100
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
659
100
  _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(w->data.f32, w->info.dim, gwtg);
660
  // kernel weight for one dim.
661
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
662
100
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
663
100
    w->info.dim[0], 6, 6, w->info.dim[3]
664
100
  };
665
100
  const int* const tile_dim = tile_dim_s;
666
100
  if (bias)
667
99
  {
668
99
    const float* const biasval = bias->data.f32;
669
    // This block will be executed in each for-loop iteration, so you can use it to generate some temporary variables.
670
3.00k
    parallel_for(i, jump_dim) {
671
3.00k
      const int y = i * 4; // i is unsigned.
672
3.00k
      int j, x, k, c;
673
3.00k
      int n[CCV_NNC_MAX_DIM];
674
3.00k
      int m[CCV_NNC_MAX_DIM];
675
3.00k
      int z[CCV_NNC_MAX_DIM];
676
3.00k
      set_n_m_dim(y, 0, tile_dim, adim);
677
3.00k
      z[0] = ccv_min(y + 4, bdim[0]) - y;
678
3.00k
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
679
3.00k
      float* bp = b->data.f32 + y * binc[1] * binc[2];
680
45.1k
      for (x = 0; x < bdim[1]; x += 4)
681
54.6k
      {
682
54.6k
        set_n_m_dim(x, 1, tile_dim, adim);
683
54.6k
        z[1] = ccv_min(x + 4, bdim[1]) - x;
684
54.6k
#if FOR_IS_PARALLEL
685
54.6k
        float* g = btdb + i * 36 * dimCx4;
686
#else
687
        float* g = btdb;
688
#endif
689
        // zero g such that we can have zero-padding.
690
54.6k
        memset(g, 0, sizeof(float) * 36 * dimCx4);
691
54.6k
        int dx, dy;
692
54.6k
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
693
54.6k
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
694
325k
        unroll_for(dy, m[0], 6) {
695
1.92M
          unroll_for(dx, m[1], 6) {
696
1.92M
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
697
115M
            for (c = 0; c < adim[2]; c++)
698
113M
              gzu[c] = apz[dx * ainc[2] + c];
699
1.92M
          } unroll_endfor
700
325k
          apz += ainc[1] * ainc[2];
701
325k
        } unroll_endfor
702
929k
        for (c = 0; c < adim[2]; c += 4)
703
874k
        {
704
874k
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
705
          /* BT.d */
706
5.10M
          unroll_for(j, 6) {
707
            /* row 1 */
708
5.10M
            const float* const gz = g + j * dimCx4;
709
5.10M
            float* dz = d + j * 4;
710
5.10M
            __m128 g0 = _mm_load_ps(gz);
711
5.10M
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
712
5.10M
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
713
5.10M
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
714
5.10M
            g0 = _mm_add_ps(g0, g0);
715
5.10M
            g0 = _mm_add_ps(g0, g0);
716
5.10M
            __m128 g12x2 = _mm_add_ps(g12, g12);
717
5.10M
            g12x2 = _mm_add_ps(g12x2, g12x2);
718
5.10M
            g12x2 = _mm_add_ps(g12x2, g12);
719
5.10M
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
720
            /* row 2 */
721
5.10M
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
722
5.10M
            __m128 g6x12 = _mm_add_ps(g6, g12);
723
5.10M
            g6x12 = _mm_add_ps(g6x12, g6x12);
724
5.10M
            g6x12 = _mm_add_ps(g6x12, g6x12);
725
5.10M
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
726
            /* row 3 */
727
5.10M
            g6x12 = _mm_sub_ps(g6, g12);
728
5.10M
            g6x12 = _mm_add_ps(g6x12, g6x12);
729
5.10M
            g6x12 = _mm_add_ps(g6x12, g6x12);
730
5.10M
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
731
            /* row 4 */
732
5.10M
            __m128 g18x6 = _mm_sub_ps(g18, g6);
733
5.10M
            g18x6 = _mm_add_ps(g18x6, g18x6);
734
5.10M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
735
            /* row 5 */
736
5.10M
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
737
            /* row 6 */
738
5.10M
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
739
5.10M
            __m128 g18x2 = _mm_add_ps(g18, g18);
740
5.10M
            g18x2 = _mm_add_ps(g18x2, g18x2);
741
5.10M
            g18x2 = _mm_add_ps(g18, g18x2);
742
5.10M
            g6 = _mm_add_ps(g6, g6);
743
5.10M
            g6 = _mm_add_ps(g6, g6);
744
5.10M
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
745
5.10M
          } unroll_endfor
746
          /* BT.d.B */
747
5.09M
          unroll_for(j, 6) {
748
5.09M
            float* gz = g + j * 6 * dimCx4;
749
5.09M
            const float* const dz = d + j * 24;
750
5.09M
            __m128 d0 = _mm_load_ps(dz);
751
5.09M
            __m128 d1 = _mm_load_ps(dz + 4);
752
5.09M
            __m128 d2 = _mm_load_ps(dz + 8);
753
5.09M
            __m128 d3 = _mm_load_ps(dz + 12);
754
5.09M
            __m128 d4 = _mm_load_ps(dz + 16);
755
5.09M
            __m128 d5 = _mm_load_ps(dz + 20);
756
5.09M
            d0 = _mm_add_ps(d0, d0);
757
5.09M
            d0 = _mm_add_ps(d0, d0);
758
5.09M
            __m128 d2x5 = _mm_add_ps(d2, d2);
759
5.09M
            d2x5 = _mm_add_ps(d2x5, d2x5);
760
5.09M
            d2x5 = _mm_add_ps(d2, d2x5);
761
5.09M
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
762
5.09M
            __m128 d1x2 = _mm_add_ps(d1, d2);
763
5.09M
            d1x2 = _mm_add_ps(d1x2, d1x2);
764
5.09M
            d1x2 = _mm_add_ps(d1x2, d1x2);
765
5.09M
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
766
5.09M
            d1x2 = _mm_sub_ps(d1, d2);
767
5.09M
            d1x2 = _mm_add_ps(d1x2, d1x2);
768
5.09M
            d1x2 = _mm_add_ps(d1x2, d1x2);
769
5.09M
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
770
5.09M
            __m128 d3x1 = _mm_sub_ps(d3, d1);
771
5.09M
            d3x1 = _mm_add_ps(d3x1, d3x1);
772
5.09M
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
773
5.09M
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
774
5.09M
            d1 = _mm_add_ps(d1, d1);
775
5.09M
            d1 = _mm_add_ps(d1, d1);
776
5.09M
            __m128 d3x5 = _mm_add_ps(d3, d3);
777
5.09M
            d3x5 = _mm_add_ps(d3x5, d3x5);
778
5.09M
            d3x5 = _mm_add_ps(d3, d3x5);
779
5.09M
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
780
5.09M
          } unroll_endfor
781
          // move to the next channel
782
874k
          g += 4;
783
874k
        }
784
54.6k
        const float* wpz = gwtg;
785
1.41M
        for (k = 0; k < w->info.dim[0]; k += 4)
786
1.36M
        {
787
1.36M
          float q[36 * 4] __attribute__ ((__aligned__(16)));
788
1.36M
#if FOR_IS_PARALLEL
789
1.36M
          g = btdb + i * 36 * dimCx4;
790
#else
791
          g = btdb;
792
#endif
793
34.0M
          for (j = 0; j < 36; j++)
794
32.6M
          {
795
32.6M
            __m128 v40 = _mm_setzero_ps();
796
32.6M
            __m128 v41 = _mm_setzero_ps();
797
32.6M
            __m128 v42 = _mm_setzero_ps();
798
32.6M
            __m128 v43 = _mm_setzero_ps();
799
407M
            for (c = 0; c < adim[2]; c += 4)
800
375M
            {
801
375M
              __m128 g4 = _mm_load_ps(g);
802
375M
              __m128 w40 = _mm_load_ps(wpz);
803
375M
              __m128 w41 = _mm_load_ps(wpz + 4);
804
375M
              __m128 w42 = _mm_load_ps(wpz + 8);
805
375M
              __m128 w43 = _mm_load_ps(wpz + 12);
806
375M
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
807
375M
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
808
375M
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
809
375M
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
810
375M
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
811
375M
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
812
375M
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
813
375M
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
814
375M
              g += 4;
815
375M
              wpz += 16;
816
375M
            }
817
32.6M
            v40 = _mm_add_ps(v40, v41);
818
32.6M
            v42 = _mm_add_ps(v42, v43);
819
32.6M
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
820
32.6M
          }
821
1.36M
          float d[24 * 4] __attribute__ ((__aligned__(16)));
822
7.75M
          unroll_for(j, 6) {
823
7.75M
            const float* const qz = q + j * 4;
824
7.75M
            float* const dz = d + j * 4;
825
7.75M
            __m128 q0 = _mm_load_ps(qz);
826
7.75M
            __m128 q6 = _mm_load_ps(qz + 24);
827
7.75M
            __m128 q12 = _mm_load_ps(qz + 48);
828
7.75M
            __m128 q18 = _mm_load_ps(qz + 72);
829
7.75M
            __m128 q24 = _mm_load_ps(qz + 96);
830
7.75M
            __m128 qs6x12 = _mm_add_ps(q6, q12);
831
7.75M
            __m128 qs18x24 = _mm_add_ps(q18, q24);
832
7.75M
            __m128 qss = _mm_add_ps(qs6x12, q0);
833
            /* row 1 */
834
7.75M
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
835
7.75M
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
836
7.75M
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
837
7.75M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
838
            /* row 2 */
839
7.75M
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
840
7.75M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
841
7.75M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
842
            /* row 3 */
843
7.75M
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
844
7.75M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
845
7.75M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
846
7.75M
            __m128 q30 = _mm_load_ps(qz + 120);
847
            /* row 4 */
848
7.75M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
849
7.75M
          } unroll_endfor
850
1.36M
          float* bpz = bp + x * binc[2] + k;
851
1.36M
          __m128 bias4 = _mm_loadu_ps(biasval + k);
852
1.36M
          switch (z[1]) {
853
9.25k
            case 1:
854
30.0k
              unroll_for(dy, z[0], 4) {
855
30.0k
                const float* const dz = d + dy * 24;
856
30.0k
                __m128 d0 = _mm_load_ps(dz);
857
30.0k
                __m128 d1 = _mm_load_ps(dz + 4);
858
30.0k
                __m128 d2 = _mm_load_ps(dz + 8);
859
30.0k
                __m128 d3 = _mm_load_ps(dz + 12);
860
30.0k
                __m128 d4 = _mm_load_ps(dz + 16);
861
30.0k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
862
30.0k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
863
30.0k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
864
30.0k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
865
30.0k
                bpz += binc[1] * binc[2];
866
30.0k
              } unroll_endfor
867
9.25k
              break;
868
0
            case 2:
869
0
              unroll_for(dy, z[0], 4) {
870
0
                const float* const dz = d + dy * 24;
871
0
                __m128 d0 = _mm_load_ps(dz);
872
0
                __m128 d1 = _mm_load_ps(dz + 4);
873
0
                __m128 d2 = _mm_load_ps(dz + 8);
874
0
                __m128 d3 = _mm_load_ps(dz + 12);
875
0
                __m128 d4 = _mm_load_ps(dz + 16);
876
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
877
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
878
0
                ds1x2 = _mm_add_ps(ds1x2, bias4);
879
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
880
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
881
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
882
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
883
0
                dn1x2 = _mm_add_ps(dn1x2, bias4);
884
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
885
0
                bpz += binc[1] * binc[2];
886
0
              } unroll_endfor
887
0
              break;
888
54.1k
            case 3:
889
211k
              unroll_for(dy, z[0], 4) {
890
211k
                const float* const dz = d + dy * 24;
891
211k
                __m128 d0 = _mm_load_ps(dz);
892
211k
                __m128 d1 = _mm_load_ps(dz + 4);
893
211k
                __m128 d2 = _mm_load_ps(dz + 8);
894
211k
                __m128 d3 = _mm_load_ps(dz + 12);
895
211k
                __m128 d4 = _mm_load_ps(dz + 16);
896
211k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
897
211k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
898
211k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
899
211k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
900
211k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
901
211k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
902
211k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
903
211k
                dn1x2 = _mm_add_ps(dn1x2, bias4);
904
211k
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
905
211k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
906
211k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
907
211k
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
908
211k
                bpz += binc[1] * binc[2];
909
211k
              } unroll_endfor
910
54.1k
              break;
911
1.30M
            case 4:
912
4.79M
              unroll_for(dy, z[0], 4) {
913
4.79M
                const float* const dz = d + dy * 24;
914
4.79M
                __m128 d0 = _mm_load_ps(dz);
915
4.79M
                __m128 d1 = _mm_load_ps(dz + 4);
916
4.79M
                __m128 d2 = _mm_load_ps(dz + 8);
917
4.79M
                __m128 d3 = _mm_load_ps(dz + 12);
918
4.79M
                __m128 d4 = _mm_load_ps(dz + 16);
919
4.79M
                __m128 ds1x2 = _mm_add_ps(d1, d2);
920
4.79M
                __m128 ds3x4 = _mm_add_ps(d3, d4);
921
4.79M
                ds1x2 = _mm_add_ps(ds1x2, bias4);
922
4.79M
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
923
4.79M
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
924
4.79M
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
925
4.79M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
926
4.79M
                dn1x2 = _mm_add_ps(dn1x2, bias4);
927
4.79M
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
928
4.79M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
929
4.79M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
930
4.79M
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
931
4.79M
                __m128 d5 = _mm_load_ps(dz + 20);
932
4.79M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
933
4.79M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
934
4.79M
                _mm_stream_ps(bpz + 3 * binc[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
935
4.79M
                bpz += binc[1] * binc[2];
936
4.79M
              } unroll_endfor
937
1.30M
              break;
938
1.36M
          };
939
1.35M
        }
940
54.6k
      }
941
3.00k
    } parallel_endfor
942
99
  } else {
943
    // This block will be executed in each for-loop iteration, so you can use it to generate some temporary variables.
944
24
    parallel_for(i, jump_dim) {
945
24
      const int y = i * 4; // i is unsigned.
946
24
      int j, x, k, c;
947
24
      int n[CCV_NNC_MAX_DIM];
948
24
      int m[CCV_NNC_MAX_DIM];
949
24
      int z[CCV_NNC_MAX_DIM];
950
24
      set_n_m_dim(y, 0, tile_dim, adim);
951
24
      z[0] = ccv_min(y + 4, bdim[0]) - y;
952
24
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
953
24
      float* bp = b->data.f32 + y * binc[1] * binc[2];
954
204
      for (x = 0; x < bdim[1]; x += 4)
955
190
      {
956
190
        set_n_m_dim(x, 1, tile_dim, adim);
957
190
        z[1] = ccv_min(x + 4, bdim[1]) - x;
958
190
#if FOR_IS_PARALLEL
959
190
        float* g = btdb + i * 36 * dimCx4;
960
#else
961
        float* g = btdb;
962
#endif
963
        // zero g such that we can have zero-padding.
964
190
        memset(g, 0, sizeof(float) * 36 * dimCx4);
965
190
        int dx, dy;
966
190
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
967
190
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
968
1.13k
        unroll_for(dy, m[0], 6) {
969
6.66k
          unroll_for(dx, m[1], 6) {
970
6.66k
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
971
824k
            for (c = 0; c < adim[2]; c++)
972
818k
              gzu[c] = apz[dx * ainc[2] + c];
973
6.66k
          } unroll_endfor
974
1.13k
          apz += ainc[1] * ainc[2];
975
1.13k
        } unroll_endfor
976
6.42k
        for (c = 0; c < adim[2]; c += 4)
977
6.23k
        {
978
6.23k
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
979
          /* BT.d */
980
36.6k
          unroll_for(j, 6) {
981
            /* row 1 */
982
36.6k
            const float* const gz = g + j * dimCx4;
983
36.6k
            float* dz = d + j * 4;
984
36.6k
            __m128 g0 = _mm_load_ps(gz);
985
36.6k
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
986
36.6k
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
987
36.6k
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
988
36.6k
            g0 = _mm_add_ps(g0, g0);
989
36.6k
            g0 = _mm_add_ps(g0, g0);
990
36.6k
            __m128 g12x2 = _mm_add_ps(g12, g12);
991
36.6k
            g12x2 = _mm_add_ps(g12x2, g12x2);
992
36.6k
            g12x2 = _mm_add_ps(g12x2, g12);
993
36.6k
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
994
            /* row 2 */
995
36.6k
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
996
36.6k
            __m128 g6x12 = _mm_add_ps(g6, g12);
997
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
998
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
999
36.6k
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
1000
            /* row 3 */
1001
36.6k
            g6x12 = _mm_sub_ps(g6, g12);
1002
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1003
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1004
36.6k
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
1005
            /* row 4 */
1006
36.6k
            __m128 g18x6 = _mm_sub_ps(g18, g6);
1007
36.6k
            g18x6 = _mm_add_ps(g18x6, g18x6);
1008
36.6k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
1009
            /* row 5 */
1010
36.6k
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
1011
            /* row 6 */
1012
36.6k
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
1013
36.6k
            __m128 g18x2 = _mm_add_ps(g18, g18);
1014
36.6k
            g18x2 = _mm_add_ps(g18x2, g18x2);
1015
36.6k
            g18x2 = _mm_add_ps(g18, g18x2);
1016
36.6k
            g6 = _mm_add_ps(g6, g6);
1017
36.6k
            g6 = _mm_add_ps(g6, g6);
1018
36.6k
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
1019
36.6k
          } unroll_endfor
1020
          /* BT.d.B */
1021
36.6k
          unroll_for(j, 6) {
1022
36.6k
            float* gz = g + j * 6 * dimCx4;
1023
36.6k
            const float* const dz = d + j * 24;
1024
36.6k
            __m128 d0 = _mm_load_ps(dz);
1025
36.6k
            __m128 d1 = _mm_load_ps(dz + 4);
1026
36.6k
            __m128 d2 = _mm_load_ps(dz + 8);
1027
36.6k
            __m128 d3 = _mm_load_ps(dz + 12);
1028
36.6k
            __m128 d4 = _mm_load_ps(dz + 16);
1029
36.6k
            __m128 d5 = _mm_load_ps(dz + 20);
1030
36.6k
            d0 = _mm_add_ps(d0, d0);
1031
36.6k
            d0 = _mm_add_ps(d0, d0);
1032
36.6k
            __m128 d2x5 = _mm_add_ps(d2, d2);
1033
36.6k
            d2x5 = _mm_add_ps(d2x5, d2x5);
1034
36.6k
            d2x5 = _mm_add_ps(d2, d2x5);
1035
36.6k
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
1036
36.6k
            __m128 d1x2 = _mm_add_ps(d1, d2);
1037
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1038
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1039
36.6k
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
1040
36.6k
            d1x2 = _mm_sub_ps(d1, d2);
1041
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1042
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1043
36.6k
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
1044
36.6k
            __m128 d3x1 = _mm_sub_ps(d3, d1);
1045
36.6k
            d3x1 = _mm_add_ps(d3x1, d3x1);
1046
36.6k
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
1047
36.6k
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
1048
36.6k
            d1 = _mm_add_ps(d1, d1);
1049
36.6k
            d1 = _mm_add_ps(d1, d1);
1050
36.6k
            __m128 d3x5 = _mm_add_ps(d3, d3);
1051
36.6k
            d3x5 = _mm_add_ps(d3x5, d3x5);
1052
36.6k
            d3x5 = _mm_add_ps(d3, d3x5);
1053
36.6k
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
1054
36.6k
          } unroll_endfor
1055
          // move to the next channel
1056
6.23k
          g += 4;
1057
6.23k
        }
1058
190
        const float* wpz = gwtg;
1059
6.45k
        for (k = 0; k < w->info.dim[0]; k += 4)
1060
6.26k
        {
1061
6.26k
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1062
6.26k
#if FOR_IS_PARALLEL
1063
6.26k
          g = btdb + i * 36 * dimCx4;
1064
#else
1065
          g = btdb;
1066
#endif
1067
212k
          for (j = 0; j < 36; j++)
1068
206k
          {
1069
206k
            __m128 v40 = _mm_setzero_ps();
1070
206k
            __m128 v41 = _mm_setzero_ps();
1071
206k
            __m128 v42 = _mm_setzero_ps();
1072
206k
            __m128 v43 = _mm_setzero_ps();
1073
1.89M
            for (c = 0; c < adim[2]; c += 4)
1074
1.68M
            {
1075
1.68M
              __m128 g4 = _mm_load_ps(g);
1076
1.68M
              __m128 w40 = _mm_load_ps(wpz);
1077
1.68M
              __m128 w41 = _mm_load_ps(wpz + 4);
1078
1.68M
              __m128 w42 = _mm_load_ps(wpz + 8);
1079
1.68M
              __m128 w43 = _mm_load_ps(wpz + 12);
1080
1.68M
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
1081
1.68M
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
1082
1.68M
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
1083
1.68M
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
1084
1.68M
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
1085
1.68M
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
1086
1.68M
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
1087
1.68M
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
1088
1.68M
              g += 4;
1089
1.68M
              wpz += 16;
1090
1.68M
            }
1091
206k
            v40 = _mm_add_ps(v40, v41);
1092
206k
            v42 = _mm_add_ps(v42, v43);
1093
206k
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
1094
206k
          }
1095
6.26k
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1096
37.3k
          unroll_for(j, 6) {
1097
37.3k
            const float* const qz = q + j * 4;
1098
37.3k
            float* const dz = d + j * 4;
1099
37.3k
            __m128 q0 = _mm_load_ps(qz);
1100
37.3k
            __m128 q6 = _mm_load_ps(qz + 24);
1101
37.3k
            __m128 q12 = _mm_load_ps(qz + 48);
1102
37.3k
            __m128 q18 = _mm_load_ps(qz + 72);
1103
37.3k
            __m128 q24 = _mm_load_ps(qz + 96);
1104
37.3k
            __m128 qs6x12 = _mm_add_ps(q6, q12);
1105
37.3k
            __m128 qs18x24 = _mm_add_ps(q18, q24);
1106
37.3k
            __m128 qss = _mm_add_ps(qs6x12, q0);
1107
            /* row 1 */
1108
37.3k
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
1109
37.3k
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
1110
37.3k
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
1111
37.3k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1112
            /* row 2 */
1113
37.3k
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
1114
37.3k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1115
37.3k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1116
            /* row 3 */
1117
37.3k
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
1118
37.3k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1119
37.3k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1120
37.3k
            __m128 q30 = _mm_load_ps(qz + 120);
1121
            /* row 4 */
1122
37.3k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
1123
37.3k
          } unroll_endfor
1124
6.26k
          float* bpz = bp + x * binc[2] + k;
1125
6.26k
          switch (z[1]) {
1126
0
            case 1:
1127
0
              unroll_for(dy, z[0], 4) {
1128
0
                const float* const dz = d + dy * 24;
1129
0
                __m128 d0 = _mm_load_ps(dz);
1130
0
                __m128 d1 = _mm_load_ps(dz + 4);
1131
0
                __m128 d2 = _mm_load_ps(dz + 8);
1132
0
                __m128 d3 = _mm_load_ps(dz + 12);
1133
0
                __m128 d4 = _mm_load_ps(dz + 16);
1134
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1135
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1136
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1137
0
                bpz += binc[1] * binc[2];
1138
0
              } unroll_endfor
1139
0
              break;
1140
0
            case 2:
1141
0
              unroll_for(dy, z[0], 4) {
1142
0
                const float* const dz = d + dy * 24;
1143
0
                __m128 d0 = _mm_load_ps(dz);
1144
0
                __m128 d1 = _mm_load_ps(dz + 4);
1145
0
                __m128 d2 = _mm_load_ps(dz + 8);
1146
0
                __m128 d3 = _mm_load_ps(dz + 12);
1147
0
                __m128 d4 = _mm_load_ps(dz + 16);
1148
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1149
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1150
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1151
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1152
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1153
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1154
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1155
0
                bpz += binc[1] * binc[2];
1156
0
              } unroll_endfor
1157
0
              break;
1158
0
            case 3:
1159
0
              unroll_for(dy, z[0], 4) {
1160
0
                const float* const dz = d + dy * 24;
1161
0
                __m128 d0 = _mm_load_ps(dz);
1162
0
                __m128 d1 = _mm_load_ps(dz + 4);
1163
0
                __m128 d2 = _mm_load_ps(dz + 8);
1164
0
                __m128 d3 = _mm_load_ps(dz + 12);
1165
0
                __m128 d4 = _mm_load_ps(dz + 16);
1166
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1167
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1168
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1169
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1170
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1171
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1172
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1173
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1174
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1175
0
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
1176
0
                bpz += binc[1] * binc[2];
1177
0
              } unroll_endfor
1178
0
              break;
1179
6.27k
            case 4:
1180
25.0k
              unroll_for(dy, z[0], 4) {
1181
25.0k
                const float* const dz = d + dy * 24;
1182
25.0k
                __m128 d0 = _mm_load_ps(dz);
1183
25.0k
                __m128 d1 = _mm_load_ps(dz + 4);
1184
25.0k
                __m128 d2 = _mm_load_ps(dz + 8);
1185
25.0k
                __m128 d3 = _mm_load_ps(dz + 12);
1186
25.0k
                __m128 d4 = _mm_load_ps(dz + 16);
1187
25.0k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1188
25.0k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1189
25.0k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1190
25.0k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1191
25.0k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1192
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1193
25.0k
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1194
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1195
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1196
25.0k
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
1197
25.0k
                __m128 d5 = _mm_load_ps(dz + 20);
1198
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1199
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1200
25.0k
                _mm_stream_ps(bpz + 3 * binc[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
1201
25.0k
                bpz += binc[1] * binc[2];
1202
25.0k
              } unroll_endfor
1203
6.27k
              break;
1204
6.26k
          };
1205
6.26k
        }
1206
190
      }
1207
24
    } parallel_endfor
1208
1
  }
1209
100
  return CCV_NNC_EXEC_SUCCESS;
1210
100
}
1211
#endif
1212
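
For reference: the BT.d / BT.d.B arithmetic and the four output rows written to bpz in the SSE2 path above follow the coefficient pattern of the standard Winograd F(4x4, 3x3) data and output transforms. A sketch of the implied matrices, inferred from the code rather than stated explicitly in the source:

\[
B^\top =
\begin{pmatrix}
4 & 0 & -5 & 0 & 1 & 0 \\
0 & -4 & -4 & 1 & 1 & 0 \\
0 & 4 & -4 & -1 & 1 & 0 \\
0 & -2 & -1 & 2 & 1 & 0 \\
0 & 2 & -1 & -2 & 1 & 0 \\
0 & 4 & 0 & -5 & 0 & 1
\end{pmatrix},
\qquad
A^\top =
\begin{pmatrix}
1 & 1 & 1 & 1 & 1 & 0 \\
0 & 1 & -1 & 2 & -2 & 0 \\
0 & 1 & 1 & 4 & 4 & 0 \\
0 & 1 & -1 & 8 & -8 & 1
\end{pmatrix}
\]

so that each 4x4 output tile is \(Y = A^\top\bigl[\sum_c (G\,g_c\,G^\top) \odot (B^\top d_c\,B)\bigr] A\), with the channel sum carried out in the j < 36 accumulation loop above.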
1213
#ifdef HAVE_NEON
1214
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_neon(const float* const w, const int* const dim, float* const gwtg)
1215
{
1216
  const int jump_dim = dim[0] / 4;
1217
  const int dimCx4 = (dim[3] + 3) & -4;
1218
  parallel_for(k, jump_dim) {
1219
    int i, j;
1220
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
1221
    const float* wz[] = {
1222
      w + (k * 4) * 9 * dim[3],
1223
      w + (k * 4 + 1) * 9 * dim[3],
1224
      w + (k * 4 + 2) * 9 * dim[3],
1225
      w + (k * 4 + 3) * 9 * dim[3],
1226
    };
1227
    for (i = 0; i < dim[3]; i++)
1228
    {
1229
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
1230
      unroll_for(j, 9) {
1231
        x9w[j * 4] = wz[0][j * dim[3] + i];
1232
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
1233
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
1234
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
1235
      } unroll_endfor
1236
      float g[18 * 4] __attribute__ ((__aligned__(16)));
1237
      float32x4_t x9w0 = vld1q_f32(x9w);
1238
      float32x4_t x9w1 = vld1q_f32(x9w + 4);
1239
      float32x4_t x9w2 = vld1q_f32(x9w + 8);
1240
      float32x4_t x9w3 = vld1q_f32(x9w + 12);
1241
      float32x4_t x9w4 = vld1q_f32(x9w + 16);
1242
      float32x4_t x9w5 = vld1q_f32(x9w + 20);
1243
      float32x4_t x9w6 = vld1q_f32(x9w + 24);
1244
      float32x4_t x9w7 = vld1q_f32(x9w + 28);
1245
      float32x4_t x9w8 = vld1q_f32(x9w + 32);
1246
      /* row 1 */
1247
      float32x4_t c1_4 = vdupq_n_f32(1.0 / 4);
1248
      vst1q_f32(g, vmulq_f32(x9w0, c1_4));
1249
      vst1q_f32(g + 4, vmulq_f32(x9w1, c1_4));
1250
      vst1q_f32(g + 8, vmulq_f32(x9w2, c1_4));
1251
      /* row 2 */
1252
      float32x4_t cn1_6 = vdupq_n_f32(-1.0 / 6);
1253
      vst1q_f32(g + 12, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1254
      vst1q_f32(g + 16, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1255
      vst1q_f32(g + 20, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1256
      /* row 3 */
1257
      vst1q_f32(g + 24, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1258
      vst1q_f32(g + 28, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1259
      vst1q_f32(g + 32, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1260
      /* row 6 */
1261
      vst1q_f32(g + 60, x9w6);
1262
      vst1q_f32(g + 64, x9w7);
1263
      vst1q_f32(g + 68, x9w8);
1264
      /* w[x] * 2 */
1265
      x9w3 = vaddq_f32(x9w3, x9w3);
1266
      x9w4 = vaddq_f32(x9w4, x9w4);
1267
      x9w5 = vaddq_f32(x9w5, x9w5);
1268
      /* w[x] * 4 */
1269
      x9w6 = vaddq_f32(x9w6, x9w6);
1270
      x9w6 = vaddq_f32(x9w6, x9w6);
1271
      x9w7 = vaddq_f32(x9w7, x9w7);
1272
      x9w7 = vaddq_f32(x9w7, x9w7);
1273
      x9w8 = vaddq_f32(x9w8, x9w8);
1274
      x9w8 = vaddq_f32(x9w8, x9w8);
1275
      /* row 4 */
1276
      float32x4_t c1_24 = vdupq_n_f32(1.0 / 24);
1277
      vst1q_f32(g + 36, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1278
      vst1q_f32(g + 40, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1279
      vst1q_f32(g + 44, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1280
      /* row 5 */
1281
      vst1q_f32(g + 48, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1282
      vst1q_f32(g + 52, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1283
      vst1q_f32(g + 56, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1284
      unroll_for(j, 6) {
1285
        const float* const gz = g + j * 12;
1286
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
1287
        float32x4_t g0 = vld1q_f32(gz);
1288
        float32x4_t g1 = vld1q_f32(gz + 4);
1289
        float32x4_t g2 = vld1q_f32(gz + 8);
1290
        vst1q_f32(gwtgzu, vmulq_f32(g0, c1_4));
1291
        vst1q_f32(gwtgzu + 4 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1292
        vst1q_f32(gwtgzu + 8 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1293
        vst1q_f32(gwtgzu + 20 * dimCx4, g2);
1294
        /* g[1] * 2 */
1295
        g1 = vaddq_f32(g1, g1);
1296
        /* g[2] * 4 */
1297
        g2 = vaddq_f32(g2, g2);
1298
        g2 = vaddq_f32(g2, g2);
1299
        vst1q_f32(gwtgzu + 12 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), c1_24));
1300
        vst1q_f32(gwtgzu + 16 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), c1_24));
1301
      } unroll_endfor
1302
      gwtgz += 4;
1303
    }
1304
  } parallel_endfor
1305
}
1306
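
The 1/4, -1/6 and 1/24 scalings in _ccv_nnc_winograd_4x4_3x3_gwtg_neon above are applied twice, first over the three kernel rows and then over the columns of the intermediate g, which matches the standard F(4x4, 3x3) weight transform U = G g G^T. The implied G, again inferred from the coefficients rather than stated in the source:

\[
G =
\begin{pmatrix}
\tfrac{1}{4} & 0 & 0 \\
-\tfrac{1}{6} & -\tfrac{1}{6} & -\tfrac{1}{6} \\
-\tfrac{1}{6} & \tfrac{1}{6} & -\tfrac{1}{6} \\
\tfrac{1}{24} & \tfrac{1}{12} & \tfrac{1}{6} \\
\tfrac{1}{24} & -\tfrac{1}{12} & \tfrac{1}{6} \\
0 & 0 & 1
\end{pmatrix}
\]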
1307
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1308
{
1309
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
1310
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
1311
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
1312
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
1313
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
1314
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
1315
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
1316
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
1317
  assert(hint.border.begin[0] <= 1);
1318
  assert(hint.border.begin[1] <= 1);
1319
  assert(w->info.dim[0] % 4 == 0);
1320
  assert(w->info.dim[1] == 3);
1321
  assert(w->info.dim[2] == 3);
1322
  const int jump_dim = (bdim[0] + 3) / 4;
1323
  const int dimCx4 = (adim[2] + 3) & -4;
1324
  // allocating workspace memory for kernel reshaping and input reshaping.
1325
  float* workmem = 0;
1326
#if FOR_IS_PARALLEL
1327
  // If we do parallel for, we need to allocate an input reshaping buffer for each block.
1328
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1329
#else
1330
  // Otherwise, just one block.
1331
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1332
#endif
1333
  if (!workmem)
1334
    return CCV_NNC_EXEC_OOM;
1335
  // Convert w to a 6x6 matrix by computing G.w.T(G), where T denotes transpose.
1336
  float* const gwtg = workmem;
1337
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
1338
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
1339
  _ccv_nnc_winograd_4x4_3x3_gwtg_neon(w->data.f32, w->info.dim, gwtg);
1340
  // kernel weight for one dim.
1341
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
1342
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
1343
    w->info.dim[0], 6, 6, w->info.dim[3]
1344
  };
1345
  const int* const tile_dim = tile_dim_s;
1346
  if (bias)
1347
  {
1348
    const float* const biasval = bias->data.f32;
1349
    // This block will be called in each for-loop iteration; therefore, you can use it to generate some temporary variables.
1350
    parallel_for(i, jump_dim) {
1351
      const int y = i * 4; // i is unsigned.
1352
      int j, x, k, c;
1353
      int n[CCV_NNC_MAX_DIM];
1354
      int m[CCV_NNC_MAX_DIM];
1355
      int z[CCV_NNC_MAX_DIM];
1356
      set_n_m_dim(y, 0, tile_dim, adim);
1357
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1358
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
1359
      float* bp = b->data.f32 + y * binc[1] * binc[2];
1360
      for (x = 0; x < bdim[1]; x += 4)
1361
      {
1362
        set_n_m_dim(x, 1, tile_dim, adim);
1363
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1364
#if FOR_IS_PARALLEL
1365
        float* g = btdb + i * 36 * dimCx4;
1366
#else
1367
        float* g = btdb;
1368
#endif
1369
        // zero g such that we can have zero-padding.
1370
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1371
        int dx, dy;
1372
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
1373
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1374
        unroll_for(dy, m[0], 6) {
1375
          unroll_for(dx, m[1], 6) {
1376
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1377
            for (c = 0; c < adim[2]; c++)
1378
              gzu[c] = apz[dx * ainc[2] + c];
1379
          } unroll_endfor
1380
          apz += ainc[1] * ainc[2];
1381
        } unroll_endfor
1382
        for (c = 0; c < adim[2]; c += 4)
1383
        {
1384
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1385
          /* BT.d */
1386
          unroll_for(j, 6) {
1387
            /* row 1 */
1388
            const float* const gz = g + j * dimCx4;
1389
            float* dz = d + j * 4;
1390
            float32x4_t g0 = vld1q_f32(gz);
1391
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1392
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1393
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1394
            g0 = vaddq_f32(g0, g0);
1395
            g0 = vaddq_f32(g0, g0);
1396
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1397
            g12x2 = vaddq_f32(g12x2, g12x2);
1398
            g12x2 = vaddq_f32(g12x2, g12);
1399
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1400
            /* row 2 */
1401
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1402
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1403
            g6x12 = vaddq_f32(g6x12, g6x12);
1404
            g6x12 = vaddq_f32(g6x12, g6x12);
1405
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1406
            /* row 3 */
1407
            g6x12 = vsubq_f32(g6, g12);
1408
            g6x12 = vaddq_f32(g6x12, g6x12);
1409
            g6x12 = vaddq_f32(g6x12, g6x12);
1410
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1411
            /* row 4 */
1412
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1413
            g18x6 = vaddq_f32(g18x6, g18x6);
1414
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1415
            /* row 5 */
1416
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1417
            /* row 6 */
1418
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1419
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1420
            g18x2 = vaddq_f32(g18x2, g18x2);
1421
            g18x2 = vaddq_f32(g18, g18x2);
1422
            g6 = vaddq_f32(g6, g6);
1423
            g6 = vaddq_f32(g6, g6);
1424
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1425
          } unroll_endfor
1426
          /* BT.d.B */
1427
          unroll_for(j, 6) {
1428
            float* gz = g + j * 6 * dimCx4;
1429
            const float* const dz = d + j * 24;
1430
            float32x4_t d0 = vld1q_f32(dz);
1431
            float32x4_t d1 = vld1q_f32(dz + 4);
1432
            float32x4_t d2 = vld1q_f32(dz + 8);
1433
            float32x4_t d3 = vld1q_f32(dz + 12);
1434
            float32x4_t d4 = vld1q_f32(dz + 16);
1435
            float32x4_t d5 = vld1q_f32(dz + 20);
1436
            d0 = vaddq_f32(d0, d0);
1437
            d0 = vaddq_f32(d0, d0);
1438
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1439
            d2x5 = vaddq_f32(d2x5, d2x5);
1440
            d2x5 = vaddq_f32(d2, d2x5);
1441
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1442
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1443
            d1x2 = vaddq_f32(d1x2, d1x2);
1444
            d1x2 = vaddq_f32(d1x2, d1x2);
1445
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1446
            d1x2 = vsubq_f32(d1, d2);
1447
            d1x2 = vaddq_f32(d1x2, d1x2);
1448
            d1x2 = vaddq_f32(d1x2, d1x2);
1449
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1450
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1451
            d3x1 = vaddq_f32(d3x1, d3x1);
1452
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1453
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1454
            d1 = vaddq_f32(d1, d1);
1455
            d1 = vaddq_f32(d1, d1);
1456
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1457
            d3x5 = vaddq_f32(d3x5, d3x5);
1458
            d3x5 = vaddq_f32(d3, d3x5);
1459
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1460
          } unroll_endfor
1461
          // move to the next channel
1462
          g += 4;
1463
        }
1464
        const float* wpz = gwtg;
1465
        for (k = 0; k < w->info.dim[0]; k += 4)
1466
        {
1467
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1468
#if FOR_IS_PARALLEL
1469
          g = btdb + i * 36 * dimCx4;
1470
#else
1471
          g = btdb;
1472
#endif
1473
          for (j = 0; j < 36; j++)
1474
          {
1475
            float32x4_t v40 = vmovq_n_f32(0);
1476
            float32x4_t v41 = vmovq_n_f32(0);
1477
            float32x4_t v42 = vmovq_n_f32(0);
1478
            float32x4_t v43 = vmovq_n_f32(0);
1479
            for (c = 0; c < adim[2]; c += 4)
1480
            {
1481
              float32x2x2_t g4 = vld2_f32(g);
1482
              float32x4_t w40 = vld1q_f32(wpz);
1483
              float32x4_t w41 = vld1q_f32(wpz + 4);
1484
              float32x4_t w42 = vld1q_f32(wpz + 8);
1485
              float32x4_t w43 = vld1q_f32(wpz + 12);
1486
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1487
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1488
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1489
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1490
              v40 = vmlaq_f32(v40, w40, g40);
1491
              v41 = vmlaq_f32(v41, w41, g41);
1492
              v42 = vmlaq_f32(v42, w42, g42);
1493
              v43 = vmlaq_f32(v43, w43, g43);
1494
              g += 4;
1495
              wpz += 16;
1496
            }
1497
            v40 = vaddq_f32(v40, v41);
1498
            v42 = vaddq_f32(v42, v43);
1499
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1500
          }
1501
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1502
          unroll_for(j, 6) {
1503
            const float* const qz = q + j * 4;
1504
            float* const dz = d + j * 4;
1505
            float32x4_t q0 = vld1q_f32(qz);
1506
            float32x4_t q6 = vld1q_f32(qz + 24);
1507
            float32x4_t q12 = vld1q_f32(qz + 48);
1508
            float32x4_t q18 = vld1q_f32(qz + 72);
1509
            float32x4_t q24 = vld1q_f32(qz + 96);
1510
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1511
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1512
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1513
            /* row 1 */
1514
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1515
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1516
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1517
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1518
            /* row 2 */
1519
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1520
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1521
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1522
            /* row 3 */
1523
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1524
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1525
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1526
            float32x4_t q30 = vld1q_f32(qz + 120);
1527
            /* row 4 */
1528
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1529
          } unroll_endfor
1530
          float* bpz = bp + x * binc[2] + k;
1531
          float32x4_t bias4 = vld1q_f32(biasval + k);
1532
          switch (z[1]) {
1533
            case 1:
1534
              unroll_for(dy, z[0], 4) {
1535
                const float* const dz = d + dy * 24;
1536
                float32x4_t d0 = vld1q_f32(dz);
1537
                float32x4_t d1 = vld1q_f32(dz + 4);
1538
                float32x4_t d2 = vld1q_f32(dz + 8);
1539
                float32x4_t d3 = vld1q_f32(dz + 12);
1540
                float32x4_t d4 = vld1q_f32(dz + 16);
1541
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1542
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1543
                ds1x2 = vaddq_f32(ds1x2, bias4);
1544
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1545
                bpz += binc[1] * binc[2];
1546
              } unroll_endfor
1547
              break;
1548
            case 2:
1549
              unroll_for(dy, z[0], 4) {
1550
                const float* const dz = d + dy * 24;
1551
                float32x4_t d0 = vld1q_f32(dz);
1552
                float32x4_t d1 = vld1q_f32(dz + 4);
1553
                float32x4_t d2 = vld1q_f32(dz + 8);
1554
                float32x4_t d3 = vld1q_f32(dz + 12);
1555
                float32x4_t d4 = vld1q_f32(dz + 16);
1556
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1557
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1558
                ds1x2 = vaddq_f32(ds1x2, bias4);
1559
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1560
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1561
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1562
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1563
                dn1x2 = vaddq_f32(dn1x2, bias4);
1564
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1565
                bpz += binc[1] * binc[2];
1566
              } unroll_endfor
1567
              break;
1568
            case 3:
1569
              unroll_for(dy, z[0], 4) {
1570
                const float* const dz = d + dy * 24;
1571
                float32x4_t d0 = vld1q_f32(dz);
1572
                float32x4_t d1 = vld1q_f32(dz + 4);
1573
                float32x4_t d2 = vld1q_f32(dz + 8);
1574
                float32x4_t d3 = vld1q_f32(dz + 12);
1575
                float32x4_t d4 = vld1q_f32(dz + 16);
1576
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1577
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1578
                ds1x2 = vaddq_f32(ds1x2, bias4);
1579
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1580
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1581
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1582
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1583
                dn1x2 = vaddq_f32(dn1x2, bias4);
1584
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1585
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1586
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1587
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1588
                bpz += binc[1] * binc[2];
1589
              } unroll_endfor
1590
              break;
1591
            case 4:
1592
              unroll_for(dy, z[0], 4) {
1593
                const float* const dz = d + dy * 24;
1594
                float32x4_t d0 = vld1q_f32(dz);
1595
                float32x4_t d1 = vld1q_f32(dz + 4);
1596
                float32x4_t d2 = vld1q_f32(dz + 8);
1597
                float32x4_t d3 = vld1q_f32(dz + 12);
1598
                float32x4_t d4 = vld1q_f32(dz + 16);
1599
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1600
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1601
                ds1x2 = vaddq_f32(ds1x2, bias4);
1602
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1603
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1604
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1605
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1606
                dn1x2 = vaddq_f32(dn1x2, bias4);
1607
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1608
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1609
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1610
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1611
                float32x4_t d5 = vld1q_f32(dz + 20);
1612
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1613
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1614
                vst1q_f32(bpz + 3 * binc[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1615
                bpz += binc[1] * binc[2];
1616
              } unroll_endfor
1617
              break;
1618
          };
1619
        }
1620
      }
1621
    } parallel_endfor
1622
  } else {
1623
    // This block will be called in each for-loop iteration; therefore, you can use it to generate some temporary variables.
1624
    parallel_for(i, jump_dim) {
1625
      const int y = i * 4; // i is unsigned.
1626
      int j, x, k, c;
1627
      int n[CCV_NNC_MAX_DIM];
1628
      int m[CCV_NNC_MAX_DIM];
1629
      int z[CCV_NNC_MAX_DIM];
1630
      set_n_m_dim(y, 0, tile_dim, adim);
1631
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1632
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
1633
      float* bp = b->data.f32 + y * binc[1] * binc[2];
1634
      for (x = 0; x < bdim[1]; x += 4)
1635
      {
1636
        set_n_m_dim(x, 1, tile_dim, adim);
1637
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1638
#if FOR_IS_PARALLEL
1639
        float* g = btdb + i * 36 * dimCx4;
1640
#else
1641
        float* g = btdb;
1642
#endif
1643
        // zero g such that we can have zero-padding.
1644
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1645
        int dx, dy;
1646
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
1647
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1648
        unroll_for(dy, m[0], 6) {
1649
          unroll_for(dx, m[1], 6) {
1650
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1651
            for (c = 0; c < adim[2]; c++)
1652
              gzu[c] = apz[dx * ainc[2] + c];
1653
          } unroll_endfor
1654
          apz += ainc[1] * ainc[2];
1655
        } unroll_endfor
1656
        for (c = 0; c < adim[2]; c += 4)
1657
        {
1658
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1659
          /* BT.d */
1660
          unroll_for(j, 6) {
1661
            /* row 1 */
1662
            const float* const gz = g + j * dimCx4;
1663
            float* dz = d + j * 4;
1664
            float32x4_t g0 = vld1q_f32(gz);
1665
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1666
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1667
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1668
            g0 = vaddq_f32(g0, g0);
1669
            g0 = vaddq_f32(g0, g0);
1670
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1671
            g12x2 = vaddq_f32(g12x2, g12x2);
1672
            g12x2 = vaddq_f32(g12x2, g12);
1673
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1674
            /* row 2 */
1675
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1676
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1677
            g6x12 = vaddq_f32(g6x12, g6x12);
1678
            g6x12 = vaddq_f32(g6x12, g6x12);
1679
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1680
            /* row 3 */
1681
            g6x12 = vsubq_f32(g6, g12);
1682
            g6x12 = vaddq_f32(g6x12, g6x12);
1683
            g6x12 = vaddq_f32(g6x12, g6x12);
1684
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1685
            /* row 4 */
1686
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1687
            g18x6 = vaddq_f32(g18x6, g18x6);
1688
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1689
            /* row 5 */
1690
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1691
            /* row 6 */
1692
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1693
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1694
            g18x2 = vaddq_f32(g18x2, g18x2);
1695
            g18x2 = vaddq_f32(g18, g18x2);
1696
            g6 = vaddq_f32(g6, g6);
1697
            g6 = vaddq_f32(g6, g6);
1698
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1699
          } unroll_endfor
1700
          /* BT.d.B */
1701
          unroll_for(j, 6) {
1702
            float* gz = g + j * 6 * dimCx4;
1703
            const float* const dz = d + j * 24;
1704
            float32x4_t d0 = vld1q_f32(dz);
1705
            float32x4_t d1 = vld1q_f32(dz + 4);
1706
            float32x4_t d2 = vld1q_f32(dz + 8);
1707
            float32x4_t d3 = vld1q_f32(dz + 12);
1708
            float32x4_t d4 = vld1q_f32(dz + 16);
1709
            float32x4_t d5 = vld1q_f32(dz + 20);
1710
            d0 = vaddq_f32(d0, d0);
1711
            d0 = vaddq_f32(d0, d0);
1712
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1713
            d2x5 = vaddq_f32(d2x5, d2x5);
1714
            d2x5 = vaddq_f32(d2, d2x5);
1715
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1716
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1717
            d1x2 = vaddq_f32(d1x2, d1x2);
1718
            d1x2 = vaddq_f32(d1x2, d1x2);
1719
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1720
            d1x2 = vsubq_f32(d1, d2);
1721
            d1x2 = vaddq_f32(d1x2, d1x2);
1722
            d1x2 = vaddq_f32(d1x2, d1x2);
1723
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1724
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1725
            d3x1 = vaddq_f32(d3x1, d3x1);
1726
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1727
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1728
            d1 = vaddq_f32(d1, d1);
1729
            d1 = vaddq_f32(d1, d1);
1730
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1731
            d3x5 = vaddq_f32(d3x5, d3x5);
1732
            d3x5 = vaddq_f32(d3, d3x5);
1733
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1734
          } unroll_endfor
1735
          // move to the next channel
1736
          g += 4;
1737
        }
1738
        const float* wpz = gwtg;
1739
        for (k = 0; k < w->info.dim[0]; k += 4)
1740
        {
1741
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1742
#if FOR_IS_PARALLEL
1743
          g = btdb + i * 36 * dimCx4;
1744
#else
1745
          g = btdb;
1746
#endif
1747
          for (j = 0; j < 36; j++)
1748
          {
1749
            float32x4_t v40 = vmovq_n_f32(0);
1750
            float32x4_t v41 = vmovq_n_f32(0);
1751
            float32x4_t v42 = vmovq_n_f32(0);
1752
            float32x4_t v43 = vmovq_n_f32(0);
1753
            for (c = 0; c < adim[2]; c += 4)
1754
            {
1755
              float32x2x2_t g4 = vld2_f32(g);
1756
              float32x4_t w40 = vld1q_f32(wpz);
1757
              float32x4_t w41 = vld1q_f32(wpz + 4);
1758
              float32x4_t w42 = vld1q_f32(wpz + 8);
1759
              float32x4_t w43 = vld1q_f32(wpz + 12);
1760
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1761
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1762
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1763
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1764
              v40 = vmlaq_f32(v40, w40, g40);
1765
              v41 = vmlaq_f32(v41, w41, g41);
1766
              v42 = vmlaq_f32(v42, w42, g42);
1767
              v43 = vmlaq_f32(v43, w43, g43);
1768
              g += 4;
1769
              wpz += 16;
1770
            }
1771
            v40 = vaddq_f32(v40, v41);
1772
            v42 = vaddq_f32(v42, v43);
1773
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1774
          }
1775
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1776
          unroll_for(j, 6) {
1777
            const float* const qz = q + j * 4;
1778
            float* const dz = d + j * 4;
1779
            float32x4_t q0 = vld1q_f32(qz);
1780
            float32x4_t q6 = vld1q_f32(qz + 24);
1781
            float32x4_t q12 = vld1q_f32(qz + 48);
1782
            float32x4_t q18 = vld1q_f32(qz + 72);
1783
            float32x4_t q24 = vld1q_f32(qz + 96);
1784
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1785
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1786
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1787
            /* row 1 */
1788
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1789
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1790
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1791
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1792
            /* row 2 */
1793
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1794
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1795
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1796
            /* row 3 */
1797
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1798
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1799
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1800
            float32x4_t q30 = vld1q_f32(qz + 120);
1801
            /* row 4 */
1802
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1803
          } unroll_endfor
1804
          float* bpz = bp + x * binc[2] + k;
1805
          switch (z[1]) {
1806
            case 1:
1807
              unroll_for(dy, z[0], 4) {
1808
                const float* const dz = d + dy * 24;
1809
                float32x4_t d0 = vld1q_f32(dz);
1810
                float32x4_t d1 = vld1q_f32(dz + 4);
1811
                float32x4_t d2 = vld1q_f32(dz + 8);
1812
                float32x4_t d3 = vld1q_f32(dz + 12);
1813
                float32x4_t d4 = vld1q_f32(dz + 16);
1814
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1815
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1816
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1817
                bpz += binc[1] * binc[2];
1818
              } unroll_endfor
1819
              break;
1820
            case 2:
1821
              unroll_for(dy, z[0], 4) {
1822
                const float* const dz = d + dy * 24;
1823
                float32x4_t d0 = vld1q_f32(dz);
1824
                float32x4_t d1 = vld1q_f32(dz + 4);
1825
                float32x4_t d2 = vld1q_f32(dz + 8);
1826
                float32x4_t d3 = vld1q_f32(dz + 12);
1827
                float32x4_t d4 = vld1q_f32(dz + 16);
1828
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1829
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1830
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1831
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1832
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1833
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1834
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1835
                bpz += binc[1] * binc[2];
1836
              } unroll_endfor
1837
              break;
1838
            case 3:
1839
              unroll_for(dy, z[0], 4) {
1840
                const float* const dz = d + dy * 24;
1841
                float32x4_t d0 = vld1q_f32(dz);
1842
                float32x4_t d1 = vld1q_f32(dz + 4);
1843
                float32x4_t d2 = vld1q_f32(dz + 8);
1844
                float32x4_t d3 = vld1q_f32(dz + 12);
1845
                float32x4_t d4 = vld1q_f32(dz + 16);
1846
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1847
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1848
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1849
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1850
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1851
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1852
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1853
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1854
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1855
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1856
                bpz += binc[1] * binc[2];
1857
              } unroll_endfor
1858
              break;
1859
            case 4:
1860
              unroll_for(dy, z[0], 4) {
1861
                const float* const dz = d + dy * 24;
1862
                float32x4_t d0 = vld1q_f32(dz);
1863
                float32x4_t d1 = vld1q_f32(dz + 4);
1864
                float32x4_t d2 = vld1q_f32(dz + 8);
1865
                float32x4_t d3 = vld1q_f32(dz + 12);
1866
                float32x4_t d4 = vld1q_f32(dz + 16);
1867
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1868
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1869
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1870
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1871
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1872
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1873
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1874
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1875
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1876
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1877
                float32x4_t d5 = vld1q_f32(dz + 20);
1878
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1879
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1880
                vst1q_f32(bpz + 3 * binc[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1881
                bpz += binc[1] * binc[2];
1882
              } unroll_endfor
1883
              break;
1884
          };
1885
        }
1886
      }
1887
    } parallel_endfor
1888
  }
1889
  return CCV_NNC_EXEC_SUCCESS;
1890
}
1891
#endif
1892
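
Both the SSE2 and NEON transforms above avoid explicit multiplies by the small Winograd constants: a value is scaled by 4 with two self-additions and by 5 with two self-additions plus the original operand (the g12x2, d2x5 and qn18x24 sequences). A minimal scalar sketch of that idiom, with illustrative helper names that do not appear in the source:

#include <stdio.h>

/* Illustrative helpers only; the real code performs the same pattern
 * on __m128 / float32x4_t vectors with _mm_add_ps / vaddq_f32. */
static inline float scale_by_4(float x)
{
  const float x2 = x + x; /* 2x */
  return x2 + x2;         /* 4x */
}

static inline float scale_by_5(float x)
{
  const float x4 = scale_by_4(x); /* 4x */
  return x4 + x;                  /* 5x */
}

int main(void)
{
  /* 4 * g0 + g24 - 5 * g12, the first BT.d row, written with the idiom. */
  const float g0 = 1.5f, g12 = 0.25f, g24 = -2.f;
  printf("%f\n", scale_by_4(g0) + g24 - scale_by_5(g12));
  return 0;
}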
1893
int _ccv_nnc_conv_forw_4x4_3x3_winograd_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1894
100
{
1895
100
#if defined(HAVE_SSE2)
1896
100
  if (w->info.dim[0] % 4 == 0)
1897
100
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(a, w, bias, hint, b, stream_context);
1898
#elif defined(HAVE_NEON)
1899
  if (w->info.dim[0] % 4 == 0)
1900
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(a, w, bias, hint, b, stream_context);
1901
#endif
1902
0
  return _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(a, w, bias, hint, b, stream_context);
1903
100
}
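
The dispatcher above only takes the SSE2 or NEON path when the number of filters (w->info.dim[0]) is a multiple of 4, falling back to the reference implementation otherwise, and both SIMD paths pad the channel count to a multiple of 4 via the dimCx4 = (dim + 3) & -4 computation. A small, self-contained sketch of that rounding, using a hypothetical helper name:

#include <assert.h>

/* Round a channel count up to the next multiple of 4, as the dimCx4
 * expression does above; -4 is ...11111100 in two's complement, so the
 * AND clears the low two bits after adding 3. */
static inline int round_up_to_4(const int c)
{
  return (c + 3) & -4;
}

int main(void)
{
  assert(round_up_to_4(1) == 4);
  assert(round_up_to_4(4) == 4);
  assert(round_up_to_4(13) == 16);
  return 0;
}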