Coverage Report

Created: 2021-04-12 03:25

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/cmd/convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#if defined(HAVE_SSE2)
7
#include <xmmintrin.h>
8
#elif defined(HAVE_NEON)
9
#include <arm_neon.h>
10
#endif
11
#ifdef USE_OPENMP
12
#include <omp.h>
13
#endif
14
#ifdef USE_DISPATCH
15
#include <dispatch/dispatch.h>
16
#endif
17
#include "../_ccv_nnc_conv_cpu_opt.h"
18
19
#define set_n_m_dim(i, x, wd, ad) \
20
59.3k
  do { \
21
59.3k
    n[x] = ccv_max((i) * hint.stride.dim[x] - hint.border.begin[x], 0) - ((i) * hint.stride.dim[x] - hint.border.begin[x]); \
22
59.3k
    m[x] = wd[x + 1] - n[x] - ((i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1] - ccv_min(ad[x], (i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1])); \
23
59.2k
  } while (0)
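
The set_n_m_dim macro above computes, per axis, how much of a 6x6 input tile is actually backed by input data: n[x] is the number of rows/columns clipped off at the start by the border padding (and thus the offset of the first valid element inside the tile), and m[x] is how many valid rows/columns the tile covers before running past the end of the input; everything outside [n, n + m) stays zero padding. A minimal standalone sketch (assumed stride/border/input sizes, not library code) that evaluates the same arithmetic for the first tile of a pad-1 input:

#include <stdio.h>

#define ccv_max(a, b) ((a) > (b) ? (a) : (b))
#define ccv_min(a, b) ((a) < (b) ? (a) : (b))

int main(void)
{
  const int stride = 1, border = 1; /* assumed hint: stride 1, pad 1 (typical for a 3x3 kernel) */
  const int tile = 6;               /* wd[x + 1]: the 6x6 Winograd input tile */
  const int ad = 8;                 /* assumed input extent along this axis */
  const int i = 0;                  /* first tile along this axis */
  const int start = i * stride - border;                               /* -1: the tile begins in the padding */
  const int n = ccv_max(start, 0) - start;                             /* rows clipped at the start: 1 */
  const int m = tile - n - (start + tile - ccv_min(ad, start + tile)); /* valid rows inside the tile: 5 */
  printf("n = %d, m = %d\n", n, m);
  return 0;
}
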
24
25
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_ref(const float* const w, const int c, float* gwtg)
26
0
{
27
0
  int i;
28
0
  for (i = 0; i < c; i++)
29
0
  {
30
0
    float g[18];
31
0
    /*
32
0
     * a0, b1, c2
33
0
     * d3, e4, f5
34
0
     * g6, h7, i8
35
0
     * {{a/4, b/4, c/4},
36
0
     * {1/6 (-a - d - g), 1/6 (-b - e - h), 1/6 (-c - f - i)},
37
0
     * {1/6 (-a + d - g), 1/6 (-b + e - h), 1/6 (-c + f - i)},
38
0
     * {1/24 (a + 2 d + 4 g), 1/24 (b + 2 e + 4 h), 1/24 (c + 2 f + 4 i)},
39
0
     * {1/24 (a - 2 d + 4 g), 1/24 (b - 2 e + 4 h), 1/24 (c - 2 f + 4 i)},
40
0
     * {g, h, i}}
41
0
     */
42
0
    /* row 1 */
43
0
    g[0] = w[i] / 4;
44
0
    g[1] = w[c + i] / 4;
45
0
    g[2] = w[2 * c + i] / 4;
46
0
    /* row 2 */
47
0
    g[3] = -(w[i] + w[3 * c + i] + w[6 * c + i]) / 6;
48
0
    g[4] = -(w[c + i] + w[4 * c + i] + w[7 * c + i]) / 6;
49
0
    g[5] = -(w[2 * c + i] + w[5 * c + i] + w[8 * c + i]) / 6;
50
0
    /* row 3 */
51
0
    g[6] = (-w[i] + w[3 * c + i] - w[6 * c + i]) / 6;
52
0
    g[7] = (-w[c + i] + w[4 * c + i] - w[7 * c + i]) / 6;
53
0
    g[8] = (-w[2 * c + i] + w[5 * c + i] - w[8 * c + i]) / 6;
54
0
    /* row 4 */
55
0
    g[9] = (w[i] + 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
56
0
    g[10] = (w[c + i] + 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
57
0
    g[11] = (w[2 * c + i] + 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
58
0
    /* row 5 */
59
0
    g[12] = (w[i] - 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
60
0
    g[13] = (w[c + i] - 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
61
0
    g[14] = (w[2 * c + i] - 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
62
0
    /* row 6 */
63
0
    g[15] = w[6 * c + i];
64
0
    g[16] = w[7 * c + i];
65
0
    g[17] = w[8 * c + i];
66
0
    /*
67
0
     * a0, b1, c2
68
0
     * d3, e4, f5
69
0
     * g6, h7, i8
70
0
     * j9, k10,l11
71
0
     * m12,n13,o14
72
0
     * p15,q16,r17
73
0
     * {{a/4, 1/6 (-a - b - c), 1/6 (-a + b - c), 1/24 (a + 2 b + 4 c), 1/24 (a - 2 b + 4 c), c},
74
0
     * {d/4, 1/6 (-d - e - f), 1/6 (-d + e - f), 1/24 (d + 2 e + 4 f), 1/24 (d - 2 e + 4 f), f},
75
0
     * {g/4, 1/6 (-g - h - i), 1/6 (-g + h - i), 1/24 (g + 2 h + 4 i), 1/24 (g - 2 h + 4 i), i},
76
0
     * {j/4, 1/6 (-j - k - l), 1/6 (-j + k - l), 1/24 (j + 2 k + 4 l), 1/24 (j - 2 k + 4 l), l},
77
0
     * {m/4, 1/6 (-m - n - o), 1/6 (-m + n - o), 1/24 (m + 2 n + 4 o), 1/24 (m - 2 n + 4 o), o},
78
0
     * {p/4, 1/6 (-p - q - r), 1/6 (-p + q - r), 1/24 (p + 2 q + 4 r), 1/24 (p - 2 q + 4 r), r}}
79
0
     */
80
0
    /* row 1 */
81
0
    gwtg[0] = g[0] / 4;
82
0
    gwtg[c] = -(g[0] + g[1] + g[2]) / 6;
83
0
    gwtg[2 * c] = (-g[0] + g[1] - g[2]) / 6;
84
0
    gwtg[3 * c] = (g[0] + 2 * g[1] + 4 * g[2]) / 24;
85
0
    gwtg[4 * c] = (g[0] - 2 * g[1] + 4 * g[2]) / 24;
86
0
    gwtg[5 * c] = g[2];
87
0
    /* row 2 */
88
0
    gwtg[6 * c] = g[3] / 4;
89
0
    gwtg[7 * c] = -(g[3] + g[4] + g[5]) / 6;
90
0
    gwtg[8 * c] = (-g[3] + g[4] - g[5]) / 6;
91
0
    gwtg[9 * c] = (g[3] + 2 * g[4] + 4 * g[5]) / 24;
92
0
    gwtg[10 * c] = (g[3] - 2 * g[4] + 4 * g[5]) / 24;
93
0
    gwtg[11 * c] = g[5];
94
0
    /* row 3 */
95
0
    gwtg[12 * c] = g[6] / 4;
96
0
    gwtg[13 * c] = -(g[6] + g[7] + g[8]) / 6;
97
0
    gwtg[14 * c] = (-g[6] + g[7] - g[8]) / 6;
98
0
    gwtg[15 * c] = (g[6] + 2 * g[7] + 4 * g[8]) / 24;
99
0
    gwtg[16 * c] = (g[6] - 2 * g[7] + 4 * g[8]) / 24;
100
0
    gwtg[17 * c] = g[8];
101
0
    /* row 4 */
102
0
    gwtg[18 * c] = g[9] / 4;
103
0
    gwtg[19 * c] = -(g[9] + g[10] + g[11]) / 6;
104
0
    gwtg[20 * c] = (-g[9] + g[10] - g[11]) / 6;
105
0
    gwtg[21 * c] = (g[9] + 2 * g[10] + 4 * g[11]) / 24;
106
0
    gwtg[22 * c] = (g[9] - 2 * g[10] + 4 * g[11]) / 24;
107
0
    gwtg[23 * c] = g[11];
108
0
    /* row 5 */
109
0
    gwtg[24 * c] = g[12] / 4;
110
0
    gwtg[25 * c] = -(g[12] + g[13] + g[14]) / 6;
111
0
    gwtg[26 * c] = (-g[12] + g[13] - g[14]) / 6;
112
0
    gwtg[27 * c] = (g[12] + 2 * g[13] + 4 * g[14]) / 24;
113
0
    gwtg[28 * c] = (g[12] - 2 * g[13] + 4 * g[14]) / 24;
114
0
    gwtg[29 * c] = g[14];
115
0
    /* row 6 */
116
0
    gwtg[30 * c] = g[15] / 4;
117
0
    gwtg[31 * c] = -(g[15] + g[16] + g[17]) / 6;
118
0
    gwtg[32 * c] = (-g[15] + g[16] - g[17]) / 6;
119
0
    gwtg[33 * c] = (g[15] + 2 * g[16] + 4 * g[17]) / 24;
120
0
    gwtg[34 * c] = (g[15] - 2 * g[16] + 4 * g[17]) / 24;
121
0
    gwtg[35 * c] = g[17];
122
0
    ++gwtg;
123
0
  }
124
0
}
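
Read off from the per-row formulas in the comments above, this reference routine computes, for each input channel, the 6x6 kernel transform gwtg = G.w.T(G) of the standard F(4x4, 3x3) Winograd algorithm, with

G = \begin{pmatrix}
\frac{1}{4} & 0 & 0 \\
-\frac{1}{6} & -\frac{1}{6} & -\frac{1}{6} \\
-\frac{1}{6} & \frac{1}{6} & -\frac{1}{6} \\
\frac{1}{24} & \frac{1}{12} & \frac{1}{6} \\
\frac{1}{24} & -\frac{1}{12} & \frac{1}{6} \\
0 & 0 & 1
\end{pmatrix}

The first matrix in the comment is G.w (the G rows applied down the 3x3 kernel), the second applies the same coefficients again along rows to form (G.w).T(G), and the result is stored channel-interleaved so that gwtg[j * c + i] holds element j of channel i.
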
125
126
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
127
0
{
128
0
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
129
0
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
130
0
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
131
0
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
132
0
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
133
0
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
134
0
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
135
0
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
136
0
  assert(hint.border.begin[0] <= 1);
137
0
  assert(hint.border.begin[1] <= 1);
138
0
  assert(w->info.dim[1] == 3);
139
0
  assert(w->info.dim[2] == 3);
140
0
  const int jump_dim = (bdim[0] + 3) / 4;
141
0
  float* workmem;
142
0
  // allocating workspace memory for kernel reshaping and input reshaping.
143
0
#if FOR_IS_PARALLEL
144
0
  // If we do parallel for, we need to allocate input reshaping for each block.
145
0
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] * jump_dim + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
146
#else
147
  // Otherwise, just one block.
148
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * adim[2] + 36 * w->info.dim[0] * w->info.dim[3]), CCV_TENSOR_CPU_MEMORY);
149
#endif
150
0
  if (!workmem)
151
0
    return CCV_NNC_EXEC_OOM;
152
0
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
153
0
  float* const gwtg = workmem;
154
0
  float* const btdb = workmem + 36 * w->info.dim[0] * w->info.dim[3];
155
0
  parallel_for(k, w->info.dim[0]) {
156
0
    _ccv_nnc_winograd_4x4_3x3_gwtg_ref(w->data.f32 + k * w->info.dim[3] * w->info.dim[2] * w->info.dim[1], w->info.dim[3], gwtg + k * 36 * w->info.dim[3]);
157
0
  } parallel_endfor
158
0
  // kernel weight for one dim.
159
0
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
160
0
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
161
0
    w->info.dim[0], 6, 6, w->info.dim[3]
162
0
  };
163
0
  const int* const tile_dim = tile_dim_s;
164
0
  // This block will be called in each for-loop; therefore, you can use it to generate some temporary variables.
165
0
  if (bias)
166
0
  {
167
0
    const float* const biasval = bias->data.f32;
168
0
    parallel_for(i, jump_dim) {
169
0
      const int y = i * 4; // i is unsigned.
170
0
      int j, x, k, c;
171
0
      int n[CCV_NNC_MAX_DIM];
172
0
      int m[CCV_NNC_MAX_DIM];
173
0
      int z[CCV_NNC_MAX_DIM];
174
0
      set_n_m_dim(y, 0, tile_dim, adim);
175
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
176
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
177
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
178
0
      for (x = 0; x < bdim[1]; x += 4)
179
0
      {
180
0
        set_n_m_dim(x, 1, tile_dim, adim);
181
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
182
0
#if FOR_IS_PARALLEL
183
0
        float* g = btdb + i * 36 * adim[2];
184
#else
185
        float* g = btdb;
186
#endif
187
0
        // zero g such that we can have zero-padding.
188
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
189
0
        int dx, dy;
190
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
191
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
192
0
        unroll_for(dy, m[0], 6) {
193
0
          unroll_for(dx, m[1], 6) {
194
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
195
0
            for (c = 0; c < adim[2]; c++)
196
0
              gzu[c] = apz[dx * ainc[2] + c];
197
0
          } unroll_endfor
198
0
          apz += ainc[1] * ainc[2];
199
0
        } unroll_endfor
200
0
        for (c = 0; c < adim[2]; c++)
201
0
        {
202
0
          /*
203
0
           * a0, a1, a2, a3, a4, a5,
204
0
           * b6, b7, b8, b9, b10,l11,
205
0
           * c12,c13,c14,c15,c16,c17,
206
0
           * d18,d19,d20,d21,d22,d23,
207
0
           * e24,e25,e26,e27,e28,e29,
208
0
           * f30,f31,f32,f33,f34,f35
209
0
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
210
0
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
211
0
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
212
0
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
213
0
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
214
0
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
215
0
           */
216
0
          float d[36];
217
0
          /* BT.d */
218
0
          unroll_for(j, 6) {
219
0
            float g0 = g[j * adim[2]];
220
0
            float g12 = g[(12 + j) * adim[2]];
221
0
            float g24 = g[(24 + j) * adim[2]];
222
0
            /* row 1 */
223
0
            d[j] = 4 * g0 - 5 * g12 + g24;
224
0
            float g6 = g[(6 + j) * adim[2]];
225
0
            float g18 = g[(18 + j) * adim[2]];
226
0
            /* row 2 */
227
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
228
0
            /* row 3 */
229
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
230
0
            /* row 4 */
231
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
232
0
            /* row 5 */
233
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
234
0
            float g30 = g[(30 + j) * adim[2]];
235
0
            /* row 6 */
236
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
237
0
          } unroll_endfor
238
0
          /*
239
0
           * a0, a1, a2, a3, a4, a5,
240
0
           * b6, b7, b8, b9, b10,l11,
241
0
           * c12,c13,c14,c15,c16,c17,
242
0
           * d18,d19,d20,d21,d22,d23,
243
0
           * e24,e25,e26,e27,e28,e29,
244
0
           * f30,f31,f32,f33,f34,f35
245
0
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
246
0
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
247
0
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
248
0
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
249
0
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
250
0
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
251
0
           */
252
0
          /* BT.d.B */
253
0
          unroll_for(j, 6) {
254
0
            /* row 1 - 6 */
255
0
            float* const gz = g + j * 6 * adim[2];
256
0
            float* const dz = d + j * 6;
257
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
258
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
259
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
260
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
261
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
262
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
263
0
          } unroll_endfor
264
0
          // move to the next channel
265
0
          ++g;
266
0
        }
267
0
        const float* wpz = gwtg;
268
0
        for (k = 0; k < w->info.dim[0]; k++)
269
0
        {
270
0
          float q[36];
271
0
#if FOR_IS_PARALLEL
272
0
          g = btdb + i * 36 * adim[2];
273
#else
274
          g = btdb;
275
#endif
276
0
          for (j = 0; j < 36; j++)
277
0
          {
278
0
            float b = 0;
279
0
            for (c = 0; c < adim[2]; c++)
280
0
              b += g[c] * wpz[c];
281
0
            q[j] = b;
282
0
            g += adim[2];
283
0
            wpz += adim[2];
284
0
          }
285
0
          /*
286
0
           * a0, a1, a2, a3, a4, a5,
287
0
           * b6, b7, b8, b9, b10,l11,
288
0
           * c12,c13,c14,c15,c16,c17,
289
0
           * d18,d19,d20,d21,d22,d23,
290
0
           * e24,e25,e26,e27,e28,e29,
291
0
           * f30,f31,f32,f33,f34,f35
292
0
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
293
0
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
294
0
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
295
0
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
296
0
           */
297
0
          float d[24];
298
0
          /* row 1 */
299
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
300
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
301
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
302
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
303
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
304
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
305
0
          /* row 2 */
306
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
307
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
308
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
309
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
310
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
311
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
312
0
          /* row 3 */
313
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
314
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
315
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
316
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
317
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
318
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
319
0
          /* row 4 */
320
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
321
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
322
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
323
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
324
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
325
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
326
0
          /*
327
0
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
328
0
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
329
0
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
330
0
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
331
0
           */
332
0
          float* bpz = bp + x * binc[2] + k;
333
0
          unroll_for(dy, z[0], 4) {
334
0
            float r[] = {
335
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4] + biasval[k],
336
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + biasval[k],
337
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]) + biasval[k],
338
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5] + biasval[k],
339
0
            };
340
0
            unroll_for(dx, z[1], 4) {
341
0
              bpz[dx * binc[2]] = r[dx];
342
0
            } unroll_endfor
343
0
            bpz += binc[1] * binc[2];
344
0
          } unroll_endfor
345
0
        }
346
0
      }
347
0
    } parallel_endfor
348
0
  } else {
349
0
    parallel_for(i, jump_dim) {
350
0
      const int y = i * 4; // i is unsigned.
351
0
      int j, x, k, c;
352
0
      int n[CCV_NNC_MAX_DIM];
353
0
      int m[CCV_NNC_MAX_DIM];
354
0
      int z[CCV_NNC_MAX_DIM];
355
0
      set_n_m_dim(y, 0, tile_dim, adim);
356
0
      z[0] = ccv_min(y + 4, bdim[0]) - y;
357
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
358
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
359
0
      for (x = 0; x < bdim[1]; x += 4)
360
0
      {
361
0
        set_n_m_dim(x, 1, tile_dim, adim);
362
0
        z[1] = ccv_min(x + 4, bdim[1]) - x;
363
0
#if FOR_IS_PARALLEL
364
0
        float* g = btdb + i * 36 * adim[2];
365
#else
366
        float* g = btdb;
367
#endif
368
0
        // zero g such that we can have zero-padding.
369
0
        memset(g, 0, sizeof(float) * 36 * adim[2]);
370
0
        int dx, dy;
371
0
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
372
0
        float* gz = g + (n[0] * 6 + n[1]) * adim[2];
373
0
        unroll_for(dy, m[0], 6) {
374
0
          unroll_for(dx, m[1], 6) {
375
0
            float* const gzu = gz + (dy * 6 + dx) * adim[2];
376
0
            for (c = 0; c < adim[2]; c++)
377
0
              gzu[c] = apz[dx * ainc[2] + c];
378
0
          } unroll_endfor
379
0
          apz += ainc[1] * ainc[2];
380
0
        } unroll_endfor
381
0
        for (c = 0; c < adim[2]; c++)
382
0
        {
383
0
          /*
384
0
           * a0, a1, a2, a3, a4, a5,
385
0
           * b6, b7, b8, b9, b10,l11,
386
0
           * c12,c13,c14,c15,c16,c17,
387
0
           * d18,d19,d20,d21,d22,d23,
388
0
           * e24,e25,e26,e27,e28,e29,
389
0
           * f30,f31,f32,f33,f34,f35
390
0
           * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
391
0
           * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
392
0
           * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
393
0
           * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
394
0
           * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
395
0
           * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
396
0
           */
397
0
          float d[36];
398
0
          /* BT.d */
399
0
          unroll_for(j, 6) {
400
0
            float g0 = g[j * adim[2]];
401
0
            float g12 = g[(12 + j) * adim[2]];
402
0
            float g24 = g[(24 + j) * adim[2]];
403
0
            /* row 1 */
404
0
            d[j] = 4 * g0 - 5 * g12 + g24;
405
0
            float g6 = g[(6 + j) * adim[2]];
406
0
            float g18 = g[(18 + j) * adim[2]];
407
0
            /* row 2 */
408
0
            d[6 + j] = -4 * (g6 + g12) + g18 + g24;
409
0
            /* row 3 */
410
0
            d[12 + j] = 4 * (g6 - g12) - g18 + g24;
411
0
            /* row 4 */
412
0
            d[18 + j] = 2 * (g18 - g6) - g12 + g24;
413
0
            /* row 5 */
414
0
            d[24 + j] = 2 * (g6 - g18) - g12 + g24;
415
0
            float g30 = g[(30 + j) * adim[2]];
416
0
            /* row 6 */
417
0
            d[30 + j] = 4 * g6 - 5 * g18 + g30;
418
0
          } unroll_endfor
419
0
          /*
420
0
           * a0, a1, a2, a3, a4, a5,
421
0
           * b6, b7, b8, b9, b10,l11,
422
0
           * c12,c13,c14,c15,c16,c17,
423
0
           * d18,d19,d20,d21,d22,d23,
424
0
           * e24,e25,e26,e27,e28,e29,
425
0
           * f30,f31,f32,f33,f34,f35
426
0
           * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
427
0
           * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
428
0
           * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
429
0
           * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
430
0
           * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
431
0
           * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
432
0
           */
433
0
          /* BT.d.B */
434
0
          unroll_for(j, 6) {
435
0
            /* row 1 - 6 */
436
0
            float* const gz = g + j * 6 * adim[2];
437
0
            float* const dz = d + j * 6;
438
0
            gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
439
0
            gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
440
0
            gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
441
0
            gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
442
0
            gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
443
0
            gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
444
0
          } unroll_endfor
445
0
          // move to the next channel
446
0
          ++g;
447
0
        }
448
0
        const float* wpz = gwtg;
449
0
        for (k = 0; k < w->info.dim[0]; k++)
450
0
        {
451
0
          float q[36];
452
0
#if FOR_IS_PARALLEL
453
0
          g = btdb + i * 36 * adim[2];
454
#else
455
          g = btdb;
456
#endif
457
0
          for (j = 0; j < 36; j++)
458
0
          {
459
0
            float b = 0;
460
0
            for (c = 0; c < adim[2]; c++)
461
0
              b += g[c] * wpz[c];
462
0
            q[j] = b;
463
0
            g += adim[2];
464
0
            wpz += adim[2];
465
0
          }
466
0
          /*
467
0
           * a0, a1, a2, a3, a4, a5,
468
0
           * b6, b7, b8, b9, b10,l11,
469
0
           * c12,c13,c14,c15,c16,c17,
470
0
           * d18,d19,d20,d21,d22,d23,
471
0
           * e24,e25,e26,e27,e28,e29,
472
0
           * f30,f31,f32,f33,f34,f35
473
0
           * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
474
0
           * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
475
0
           * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
476
0
           * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
477
0
           */
478
0
          float d[24];
479
0
          /* row 1 */
480
0
          d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
481
0
          d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
482
0
          d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
483
0
          d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
484
0
          d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
485
0
          d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
486
0
          /* row 2 */
487
0
          d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
488
0
          d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
489
0
          d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
490
0
          d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
491
0
          d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
492
0
          d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
493
0
          /* row 3 */
494
0
          d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
495
0
          d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
496
0
          d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
497
0
          d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
498
0
          d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
499
0
          d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
500
0
          /* row 4 */
501
0
          d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
502
0
          d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
503
0
          d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
504
0
          d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
505
0
          d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
506
0
          d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
507
0
          /*
508
0
           * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
509
0
           * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
510
0
           * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
511
0
           * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
512
0
           */
513
0
          float* bpz = bp + x * binc[2] + k;
514
0
          unroll_for(dy, z[0], 4) {
515
0
            float r[] = {
516
0
              d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4],
517
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]),
518
0
              d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]),
519
0
              d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5],
520
0
            };
521
0
            unroll_for(dx, z[1], 4) {
522
0
              bpz[dx * binc[2]] = r[dx];
523
0
            } unroll_endfor
524
0
            bpz += binc[1] * binc[2];
525
0
          } unroll_endfor
526
0
        }
527
0
      }
528
0
    } parallel_endfor
529
0
  }
530
0
  return CCV_NNC_EXEC_SUCCESS;
531
0
}
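
The per-tile arithmetic spelled out in the comments of this function is the standard F(4x4, 3x3) Winograd scheme: each 6x6 input tile d is mapped to T(B).d.B, multiplied element-wise against the transformed kernel and summed over input channels, and the resulting 6x6 tile q is reduced to the 4x4 output tile T(A).q.A (with the bias added once per output element in the biased branch). Reading the row formulas back from the comments gives

B^\top = \begin{pmatrix}
4 & 0 & -5 & 0 & 1 & 0 \\
0 & -4 & -4 & 1 & 1 & 0 \\
0 & 4 & -4 & -1 & 1 & 0 \\
0 & -2 & -1 & 2 & 1 & 0 \\
0 & 2 & -1 & -2 & 1 & 0 \\
0 & 4 & 0 & -5 & 0 & 1
\end{pmatrix},
\qquad
A^\top = \begin{pmatrix}
1 & 1 & 1 & 1 & 1 & 0 \\
0 & 1 & -1 & 2 & -2 & 0 \\
0 & 1 & 1 & 4 & 4 & 0 \\
0 & 1 & -1 & 8 & -8 & 1
\end{pmatrix}
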
532
533
#ifdef HAVE_SSE2
534
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(const float* const w, const int* const dim, float* const gwtg)
535
103
{
536
103
  const int jump_dim = dim[0] / 4;
537
103
  const int dimCx4 = (dim[3] + 3) & -4;
538
103
  parallel_for(k, jump_dim) {
539
0
    int i, j;
540
0
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
541
0
    const float* wz[] = {
542
0
      w + (k * 4) * 9 * dim[3],
543
0
      w + (k * 4 + 1) * 9 * dim[3],
544
0
      w + (k * 4 + 2) * 9 * dim[3],
545
0
      w + (k * 4 + 3) * 9 * dim[3],
546
0
    };
547
1.68M
    for (i = 0; i < dim[3]; i++)
548
1.68M
    {
549
1.68M
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
550
15.0M
      unroll_for(j, 9) {
551
15.0M
        x9w[j * 4] = wz[0][j * dim[3] + i];
552
15.0M
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
553
15.0M
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
554
15.0M
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
555
15.0M
      } unroll_endfor
556
1.68M
      float g[18 * 4] __attribute__ ((__aligned__(16)));
557
1.68M
      __m128 x9w0 = _mm_load_ps(x9w);
558
1.68M
      __m128 x9w1 = _mm_load_ps(x9w + 4);
559
1.68M
      __m128 x9w2 = _mm_load_ps(x9w + 8);
560
1.68M
      __m128 x9w3 = _mm_load_ps(x9w + 12);
561
1.68M
      __m128 x9w4 = _mm_load_ps(x9w + 16);
562
1.68M
      __m128 x9w5 = _mm_load_ps(x9w + 20);
563
1.68M
      __m128 x9w6 = _mm_load_ps(x9w + 24);
564
1.68M
      __m128 x9w7 = _mm_load_ps(x9w + 28);
565
1.68M
      __m128 x9w8 = _mm_load_ps(x9w + 32);
566
1.68M
      /* row 1 */
567
1.68M
      __m128 c1_4 = _mm_set1_ps(1.0 / 4);
568
1.68M
      _mm_store_ps(g, _mm_mul_ps(x9w0, c1_4));
569
1.68M
      _mm_store_ps(g + 4, _mm_mul_ps(x9w1, c1_4));
570
1.68M
      _mm_store_ps(g + 8, _mm_mul_ps(x9w2, c1_4));
571
1.68M
      /* row 2 */
572
1.68M
      __m128 cn1_6 = _mm_set1_ps(-1.0 / 6);
573
1.68M
      _mm_store_ps(g + 12, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
574
1.68M
      _mm_store_ps(g + 16, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
575
1.68M
      _mm_store_ps(g + 20, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
576
1.68M
      /* row 3 */
577
1.68M
      _mm_store_ps(g + 24, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
578
1.68M
      _mm_store_ps(g + 28, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
579
1.68M
      _mm_store_ps(g + 32, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
580
1.68M
      /* row 6 */
581
1.68M
      _mm_store_ps(g + 60, x9w6);
582
1.68M
      _mm_store_ps(g + 64, x9w7);
583
1.68M
      _mm_store_ps(g + 68, x9w8);
584
1.68M
      /* w[x] * 2 */
585
1.68M
      x9w3 = _mm_add_ps(x9w3, x9w3);
586
1.68M
      x9w4 = _mm_add_ps(x9w4, x9w4);
587
1.68M
      x9w5 = _mm_add_ps(x9w5, x9w5);
588
1.68M
      /* w[x] * 4 */
589
1.68M
      x9w6 = _mm_add_ps(x9w6, x9w6);
590
1.68M
      x9w6 = _mm_add_ps(x9w6, x9w6);
591
1.68M
      x9w7 = _mm_add_ps(x9w7, x9w7);
592
1.68M
      x9w7 = _mm_add_ps(x9w7, x9w7);
593
1.68M
      x9w8 = _mm_add_ps(x9w8, x9w8);
594
1.68M
      x9w8 = _mm_add_ps(x9w8, x9w8);
595
1.68M
      /* row 4 */
596
1.68M
      __m128 c1_24 = _mm_set1_ps(1.0 / 24);
597
1.68M
      _mm_store_ps(g + 36, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
598
1.68M
      _mm_store_ps(g + 40, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
599
1.68M
      _mm_store_ps(g + 44, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
600
1.68M
      /* row 5 */
601
1.68M
      _mm_store_ps(g + 48, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
602
1.68M
      _mm_store_ps(g + 52, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
603
1.68M
      _mm_store_ps(g + 56, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
604
3.57M
      unroll_for(j, 6) {
605
3.57M
        const float* const gz = g + j * 12;
606
3.57M
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
607
3.57M
        __m128 g0 = _mm_load_ps(gz);
608
3.57M
        __m128 g1 = _mm_load_ps(gz + 4);
609
3.57M
        __m128 g2 = _mm_load_ps(gz + 8);
610
3.57M
        _mm_store_ps(gwtgzu, _mm_mul_ps(g0, c1_4));
611
3.57M
        _mm_store_ps(gwtgzu + 4 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), cn1_6));
612
3.57M
        _mm_store_ps(gwtgzu + 8 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), cn1_6));
613
3.57M
        _mm_store_ps(gwtgzu + 20 * dimCx4, g2);
614
3.57M
        /* g[1] * 2 */
615
3.57M
        g1 = _mm_add_ps(g1, g1);
616
3.57M
        /* g[2] * 4 */
617
3.57M
        g2 = _mm_add_ps(g2, g2);
618
3.57M
        g2 = _mm_add_ps(g2, g2);
619
3.57M
        _mm_store_ps(gwtgzu + 12 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), c1_24));
620
3.57M
        _mm_store_ps(gwtgzu + 16 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), c1_24));
621
3.57M
      } unroll_endfor
622
1.68M
      gwtgz += 4;
623
1.68M
    }
624
103
  } parallel_endfor
625
103
}
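
Unlike the reference path, this SSE2 kernel transform works on four output filters at a time (interleaved into x9w) and pads the channel count to dimCx4 = (dim[3] + 3) & -4 so that every aligned 4-float load/store above stays inside the workspace. A tiny standalone sketch (not library code) of that rounding trick:

#include <assert.h>

/* Same expression as dimCx4 above: round c up to the next multiple of 4
 * by clearing the two low bits of c + 3. */
static int round_up_to_4(const int c)
{
  return (c + 3) & -4;
}

int main(void)
{
  assert(round_up_to_4(1) == 4);
  assert(round_up_to_4(4) == 4);
  assert(round_up_to_4(5) == 8);
  assert(round_up_to_4(128) == 128);
  return 0;
}
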
626
627
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
628
103
{
629
103
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
630
103
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
631
103
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
632
103
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
633
103
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
634
103
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
635
103
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
636
103
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
637
103
  assert(hint.border.begin[0] <= 1);
638
103
  assert(hint.border.begin[1] <= 1);
639
103
  assert(w->info.dim[0] % 4 == 0);
640
103
  assert(w->info.dim[1] == 3);
641
103
  assert(w->info.dim[2] == 3);
642
103
  const int jump_dim = (bdim[0] + 3) / 4;
643
103
  const int dimCx4 = (adim[2] + 3) & -4;
644
103
  // allocating workspace memory for kernel reshaping and input reshaping.
645
103
  float* workmem = 0;
646
103
#if FOR_IS_PARALLEL
647
103
  // If we do parallel for, we need to allocate input reshaping for each block.
648
103
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
649
#else
650
  // Otherwise, just one block.
651
  workmem = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
652
#endif
653
103
  if (!workmem)
654
0
    return CCV_NNC_EXEC_OOM;
655
103
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
656
103
  float* const gwtg = workmem;
657
103
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
658
103
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
659
103
  _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(w->data.f32, w->info.dim, gwtg);
660
103
  // kernel weight for one dim.
661
103
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
662
103
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
663
103
    w->info.dim[0], 6, 6, w->info.dim[3]
664
103
  };
665
103
  const int* const tile_dim = tile_dim_s;
666
103
  if (bias)
667
102
  {
668
102
    const float* const biasval = bias->data.f32;
669
102
    // This block will be called in each for-loop; therefore, you can use it to generate some temporary variables.
670
102
    parallel_for(i, jump_dim) {
671
0
      const int y = i * 4; // i is unsigned.
672
0
      int j, x, k, c;
673
0
      int n[CCV_NNC_MAX_DIM];
674
0
      int m[CCV_NNC_MAX_DIM];
675
0
      int z[CCV_NNC_MAX_DIM];
676
102
      set_n_m_dim(y, 0, tile_dim, adim);
677
102
      z[0] = ccv_min(y + 4, bdim[0]) - y;
678
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
679
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
680
49.1k
      for (x = 0; x < bdim[1]; x += 4)
681
59.0k
      {
682
59.0k
        set_n_m_dim(x, 1, tile_dim, adim);
683
59.0k
        z[1] = ccv_min(x + 4, bdim[1]) - x;
684
59.0k
#if FOR_IS_PARALLEL
685
59.0k
        float* g = btdb + i * 36 * dimCx4;
686
#else
687
        float* g = btdb;
688
#endif
689
59.0k
        // zero g such that we can have zero-padding.
690
59.0k
        memset(g, 0, sizeof(float) * 36 * dimCx4);
691
59.0k
        int dx, dy;
692
59.0k
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
693
59.0k
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
694
353k
        unroll_for(dy, m[0], 6) {
695
2.08M
          unroll_for(dx, m[1], 6) {
696
2.08M
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
697
127M
            for (c = 0; c < adim[2]; c++)
698
125M
              gzu[c] = apz[dx * ainc[2] + c];
699
2.08M
          } unroll_endfor
700
353k
          apz += ainc[1] * ainc[2];
701
353k
        } unroll_endfor
702
1.02M
        for (c = 0; c < adim[2]; c += 4)
703
963k
        {
704
963k
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
705
963k
          /* BT.d */
706
5.68M
          unroll_for(j, 6) {
707
5.68M
            /* row 1 */
708
5.68M
            const float* const gz = g + j * dimCx4;
709
5.68M
            float* dz = d + j * 4;
710
5.68M
            __m128 g0 = _mm_load_ps(gz);
711
5.68M
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
712
5.68M
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
713
5.68M
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
714
5.68M
            g0 = _mm_add_ps(g0, g0);
715
5.68M
            g0 = _mm_add_ps(g0, g0);
716
5.68M
            __m128 g12x2 = _mm_add_ps(g12, g12);
717
5.68M
            g12x2 = _mm_add_ps(g12x2, g12x2);
718
5.68M
            g12x2 = _mm_add_ps(g12x2, g12);
719
5.68M
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
720
5.68M
            /* row 2 */
721
5.68M
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
722
5.68M
            __m128 g6x12 = _mm_add_ps(g6, g12);
723
5.68M
            g6x12 = _mm_add_ps(g6x12, g6x12);
724
5.68M
            g6x12 = _mm_add_ps(g6x12, g6x12);
725
5.68M
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
726
5.68M
            /* row 3 */
727
5.68M
            g6x12 = _mm_sub_ps(g6, g12);
728
5.68M
            g6x12 = _mm_add_ps(g6x12, g6x12);
729
5.68M
            g6x12 = _mm_add_ps(g6x12, g6x12);
730
5.68M
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
731
5.68M
            /* row 4 */
732
5.68M
            __m128 g18x6 = _mm_sub_ps(g18, g6);
733
5.68M
            g18x6 = _mm_add_ps(g18x6, g18x6);
734
5.68M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
735
5.68M
            /* row 5 */
736
5.68M
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
737
5.68M
            /* row 6 */
738
5.68M
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
739
5.68M
            __m128 g18x2 = _mm_add_ps(g18, g18);
740
5.68M
            g18x2 = _mm_add_ps(g18x2, g18x2);
741
5.68M
            g18x2 = _mm_add_ps(g18, g18x2);
742
5.68M
            g6 = _mm_add_ps(g6, g6);
743
5.68M
            g6 = _mm_add_ps(g6, g6);
744
5.68M
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
745
5.68M
          } unroll_endfor
746
963k
          /* BT.d.B */
747
5.68M
          unroll_for(j, 6) {
748
5.68M
            float* gz = g + j * 6 * dimCx4;
749
5.68M
            const float* const dz = d + j * 24;
750
5.68M
            __m128 d0 = _mm_load_ps(dz);
751
5.68M
            __m128 d1 = _mm_load_ps(dz + 4);
752
5.68M
            __m128 d2 = _mm_load_ps(dz + 8);
753
5.68M
            __m128 d3 = _mm_load_ps(dz + 12);
754
5.68M
            __m128 d4 = _mm_load_ps(dz + 16);
755
5.68M
            __m128 d5 = _mm_load_ps(dz + 20);
756
5.68M
            d0 = _mm_add_ps(d0, d0);
757
5.68M
            d0 = _mm_add_ps(d0, d0);
758
5.68M
            __m128 d2x5 = _mm_add_ps(d2, d2);
759
5.68M
            d2x5 = _mm_add_ps(d2x5, d2x5);
760
5.68M
            d2x5 = _mm_add_ps(d2, d2x5);
761
5.68M
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
762
5.68M
            __m128 d1x2 = _mm_add_ps(d1, d2);
763
5.68M
            d1x2 = _mm_add_ps(d1x2, d1x2);
764
5.68M
            d1x2 = _mm_add_ps(d1x2, d1x2);
765
5.68M
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
766
5.68M
            d1x2 = _mm_sub_ps(d1, d2);
767
5.68M
            d1x2 = _mm_add_ps(d1x2, d1x2);
768
5.68M
            d1x2 = _mm_add_ps(d1x2, d1x2);
769
5.68M
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
770
5.68M
            __m128 d3x1 = _mm_sub_ps(d3, d1);
771
5.68M
            d3x1 = _mm_add_ps(d3x1, d3x1);
772
5.68M
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
773
5.68M
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
774
5.68M
            d1 = _mm_add_ps(d1, d1);
775
5.68M
            d1 = _mm_add_ps(d1, d1);
776
5.68M
            __m128 d3x5 = _mm_add_ps(d3, d3);
777
5.68M
            d3x5 = _mm_add_ps(d3x5, d3x5);
778
5.68M
            d3x5 = _mm_add_ps(d3, d3x5);
779
5.68M
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
780
5.68M
          } unroll_endfor
781
963k
          // move to the next channel
782
963k
          g += 4;
783
963k
        }
784
59.0k
        const float* wpz = gwtg;
785
1.51M
        for (k = 0; k < w->info.dim[0]; k += 4)
786
1.46M
        {
787
1.46M
          float q[36 * 4] __attribute__ ((__aligned__(16)));
788
1.46M
#if FOR_IS_PARALLEL
789
1.46M
          g = btdb + i * 36 * dimCx4;
790
#else
791
          g = btdb;
792
#endif
793
36.3M
          for (j = 0; j < 36; j++)
794
34.8M
          {
795
34.8M
            __m128 v40 = _mm_setzero_ps();
796
34.8M
            __m128 v41 = _mm_setzero_ps();
797
34.8M
            __m128 v42 = _mm_setzero_ps();
798
34.8M
            __m128 v43 = _mm_setzero_ps();
799
399M
            for (c = 0; c < adim[2]; c += 4)
800
364M
            {
801
364M
              __m128 g4 = _mm_load_ps(g);
802
364M
              __m128 w40 = _mm_load_ps(wpz);
803
364M
              __m128 w41 = _mm_load_ps(wpz + 4);
804
364M
              __m128 w42 = _mm_load_ps(wpz + 8);
805
364M
              __m128 w43 = _mm_load_ps(wpz + 12);
806
364M
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
807
364M
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
808
364M
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
809
364M
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
810
364M
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
811
364M
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
812
364M
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
813
364M
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
814
364M
              g += 4;
815
364M
              wpz += 16;
816
364M
            }
817
34.8M
            v40 = _mm_add_ps(v40, v41);
818
34.8M
            v42 = _mm_add_ps(v42, v43);
819
34.8M
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
820
34.8M
          }
821
1.46M
          float d[24 * 4] __attribute__ ((__aligned__(16)));
822
8.36M
          unroll_for(j, 6) {
823
8.36M
            const float* const qz = q + j * 4;
824
8.36M
            float* const dz = d + j * 4;
825
8.36M
            __m128 q0 = _mm_load_ps(qz);
826
8.36M
            __m128 q6 = _mm_load_ps(qz + 24);
827
8.36M
            __m128 q12 = _mm_load_ps(qz + 48);
828
8.36M
            __m128 q18 = _mm_load_ps(qz + 72);
829
8.36M
            __m128 q24 = _mm_load_ps(qz + 96);
830
8.36M
            __m128 qs6x12 = _mm_add_ps(q6, q12);
831
8.36M
            __m128 qs18x24 = _mm_add_ps(q18, q24);
832
8.36M
            __m128 qss = _mm_add_ps(qs6x12, q0);
833
8.36M
            /* row 1 */
834
8.36M
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
835
8.36M
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
836
8.36M
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
837
8.36M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
838
8.36M
            /* row 2 */
839
8.36M
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
840
8.36M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
841
8.36M
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
842
8.36M
            /* row 3 */
843
8.36M
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
844
8.36M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
845
8.36M
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
846
8.36M
            __m128 q30 = _mm_load_ps(qz + 120);
847
8.36M
            /* row 4 */
848
8.36M
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
849
8.36M
          } unroll_endfor
850
1.46M
          float* bpz = bp + x * binc[2] + k;
851
1.46M
          __m128 bias4 = _mm_loadu_ps(biasval + k);
852
1.46M
          switch (z[1]) {
853
9.25k
            case 1:
854
30.0k
              unroll_for(dy, z[0], 4) {
855
30.0k
                const float* const dz = d + dy * 24;
856
30.0k
                __m128 d0 = _mm_load_ps(dz);
857
30.0k
                __m128 d1 = _mm_load_ps(dz + 4);
858
30.0k
                __m128 d2 = _mm_load_ps(dz + 8);
859
30.0k
                __m128 d3 = _mm_load_ps(dz + 12);
860
30.0k
                __m128 d4 = _mm_load_ps(dz + 16);
861
30.0k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
862
30.0k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
863
30.0k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
864
30.0k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
865
30.0k
                bpz += binc[1] * binc[2];
866
30.0k
              } unroll_endfor
867
9.25k
              break;
868
0
            case 2:
869
0
              unroll_for(dy, z[0], 4) {
870
0
                const float* const dz = d + dy * 24;
871
0
                __m128 d0 = _mm_load_ps(dz);
872
0
                __m128 d1 = _mm_load_ps(dz + 4);
873
0
                __m128 d2 = _mm_load_ps(dz + 8);
874
0
                __m128 d3 = _mm_load_ps(dz + 12);
875
0
                __m128 d4 = _mm_load_ps(dz + 16);
876
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
877
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
878
0
                ds1x2 = _mm_add_ps(ds1x2, bias4);
879
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
880
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
881
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
882
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
883
0
                dn1x2 = _mm_add_ps(dn1x2, bias4);
884
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
885
0
                bpz += binc[1] * binc[2];
886
0
              } unroll_endfor
887
0
              break;
888
56.8k
            case 3:
889
222k
              unroll_for(dy, z[0], 4) {
890
222k
                const float* const dz = d + dy * 24;
891
222k
                __m128 d0 = _mm_load_ps(dz);
892
222k
                __m128 d1 = _mm_load_ps(dz + 4);
893
222k
                __m128 d2 = _mm_load_ps(dz + 8);
894
222k
                __m128 d3 = _mm_load_ps(dz + 12);
895
222k
                __m128 d4 = _mm_load_ps(dz + 16);
896
222k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
897
222k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
898
222k
                ds1x2 = _mm_add_ps(ds1x2, bias4);
899
222k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
900
222k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
901
222k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
902
222k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
903
222k
                dn1x2 = _mm_add_ps(dn1x2, bias4);
904
222k
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
905
222k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
906
222k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
907
222k
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
908
222k
                bpz += binc[1] * binc[2];
909
222k
              } unroll_endfor
910
56.8k
              break;
911
1.39M
            case 4:
912
5.18M
              unroll_for(dy, z[0], 4) {
913
5.18M
                const float* const dz = d + dy * 24;
914
5.18M
                __m128 d0 = _mm_load_ps(dz);
915
5.18M
                __m128 d1 = _mm_load_ps(dz + 4);
916
5.18M
                __m128 d2 = _mm_load_ps(dz + 8);
917
5.18M
                __m128 d3 = _mm_load_ps(dz + 12);
918
5.18M
                __m128 d4 = _mm_load_ps(dz + 16);
919
5.18M
                __m128 ds1x2 = _mm_add_ps(d1, d2);
920
5.18M
                __m128 ds3x4 = _mm_add_ps(d3, d4);
921
5.18M
                ds1x2 = _mm_add_ps(ds1x2, bias4);
922
5.18M
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
923
5.18M
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
924
5.18M
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
925
5.18M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
926
5.18M
                dn1x2 = _mm_add_ps(dn1x2, bias4);
927
5.18M
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
928
5.18M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
929
5.18M
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
930
5.18M
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
931
5.18M
                __m128 d5 = _mm_load_ps(dz + 20);
932
5.18M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
933
5.18M
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
934
5.18M
                _mm_stream_ps(bpz + 3 * binc[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
935
5.18M
                bpz += binc[1] * binc[2];
936
5.18M
              } unroll_endfor
937
1.39M
              break;
938
1.45M
          };
939
1.45M
        }
940
59.0k
      }
941
102
    } parallel_endfor
942
102
  } else {
943
1
    // This block will be called in each for-loop; therefore, you can use it to generate some temporary variables.
944
1
    parallel_for(i, jump_dim) {
945
0
      const int y = i * 4; // i is unsigned.
946
0
      int j, x, k, c;
947
0
      int n[CCV_NNC_MAX_DIM];
948
0
      int m[CCV_NNC_MAX_DIM];
949
0
      int z[CCV_NNC_MAX_DIM];
950
1
      set_n_m_dim(y, 0, tile_dim, adim);
951
1
      z[0] = ccv_min(y + 4, bdim[0]) - y;
952
0
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
953
0
      float* bp = b->data.f32 + y * binc[1] * binc[2];
954
184
      for (x = 0; x < bdim[1]; x += 4)
955
186
      {
956
186
        set_n_m_dim(x, 1, tile_dim, adim);
957
186
        z[1] = ccv_min(x + 4, bdim[1]) - x;
958
186
#if FOR_IS_PARALLEL
959
186
        float* g = btdb + i * 36 * dimCx4;
960
#else
961
        float* g = btdb;
962
#endif
963
186
        // zero g such that we can have zero-padding.
964
186
        memset(g, 0, sizeof(float) * 36 * dimCx4);
965
186
        int dx, dy;
966
186
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
967
186
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
968
1.14k
        unroll_for(dy, m[0], 6) {
969
6.69k
          unroll_for(dx, m[1], 6) {
970
6.69k
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
971
822k
            for (c = 0; c < adim[2]; c++)
972
815k
              gzu[c] = apz[dx * ainc[2] + c];
973
6.69k
          } unroll_endfor
974
1.14k
          apz += ainc[1] * ainc[2];
975
1.14k
        } unroll_endfor
976
6.42k
        for (c = 0; c < adim[2]; c += 4)
977
6.23k
        {
978
6.23k
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
979
6.23k
          /* BT.d */
980
36.6k
          unroll_for(j, 6) {
981
36.6k
            /* row 1 */
982
36.6k
            const float* const gz = g + j * dimCx4;
983
36.6k
            float* dz = d + j * 4;
984
36.6k
            __m128 g0 = _mm_load_ps(gz);
985
36.6k
            __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
986
36.6k
            __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
987
36.6k
            __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
988
36.6k
            g0 = _mm_add_ps(g0, g0);
989
36.6k
            g0 = _mm_add_ps(g0, g0);
990
36.6k
            __m128 g12x2 = _mm_add_ps(g12, g12);
991
36.6k
            g12x2 = _mm_add_ps(g12x2, g12x2);
992
36.6k
            g12x2 = _mm_add_ps(g12x2, g12);
993
36.6k
            _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
994
36.6k
            /* row 2 */
995
36.6k
            __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
996
36.6k
            __m128 g6x12 = _mm_add_ps(g6, g12);
997
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
998
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
999
36.6k
            _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
1000
36.6k
            /* row 3 */
1001
36.6k
            g6x12 = _mm_sub_ps(g6, g12);
1002
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1003
36.6k
            g6x12 = _mm_add_ps(g6x12, g6x12);
1004
36.6k
            _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
1005
36.6k
            /* row 4 */
1006
36.6k
            __m128 g18x6 = _mm_sub_ps(g18, g6);
1007
36.6k
            g18x6 = _mm_add_ps(g18x6, g18x6);
1008
36.6k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
1009
36.6k
            /* row 5 */
1010
36.6k
            _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
1011
36.6k
            /* row 6 */
1012
36.6k
            __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
1013
36.6k
            __m128 g18x2 = _mm_add_ps(g18, g18);
1014
36.6k
            g18x2 = _mm_add_ps(g18x2, g18x2);
1015
36.6k
            g18x2 = _mm_add_ps(g18, g18x2);
1016
36.6k
            g6 = _mm_add_ps(g6, g6);
1017
36.6k
            g6 = _mm_add_ps(g6, g6);
1018
36.6k
            _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
1019
36.6k
          } unroll_endfor
1020
6.23k
          /* BT.d.B */
1021
36.6k
          unroll_for(j, 6) {
1022
36.6k
            float* gz = g + j * 6 * dimCx4;
1023
36.6k
            const float* const dz = d + j * 24;
1024
36.6k
            __m128 d0 = _mm_load_ps(dz);
1025
36.6k
            __m128 d1 = _mm_load_ps(dz + 4);
1026
36.6k
            __m128 d2 = _mm_load_ps(dz + 8);
1027
36.6k
            __m128 d3 = _mm_load_ps(dz + 12);
1028
36.6k
            __m128 d4 = _mm_load_ps(dz + 16);
1029
36.6k
            __m128 d5 = _mm_load_ps(dz + 20);
1030
36.6k
            d0 = _mm_add_ps(d0, d0);
1031
36.6k
            d0 = _mm_add_ps(d0, d0);
1032
36.6k
            __m128 d2x5 = _mm_add_ps(d2, d2);
1033
36.6k
            d2x5 = _mm_add_ps(d2x5, d2x5);
1034
36.6k
            d2x5 = _mm_add_ps(d2, d2x5);
1035
36.6k
            _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
1036
36.6k
            __m128 d1x2 = _mm_add_ps(d1, d2);
1037
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1038
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1039
36.6k
            _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
1040
36.6k
            d1x2 = _mm_sub_ps(d1, d2);
1041
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1042
36.6k
            d1x2 = _mm_add_ps(d1x2, d1x2);
1043
36.6k
            _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
1044
36.6k
            __m128 d3x1 = _mm_sub_ps(d3, d1);
1045
36.6k
            d3x1 = _mm_add_ps(d3x1, d3x1);
1046
36.6k
            _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
1047
36.6k
            _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
1048
36.6k
            d1 = _mm_add_ps(d1, d1);
1049
36.6k
            d1 = _mm_add_ps(d1, d1);
1050
36.6k
            __m128 d3x5 = _mm_add_ps(d3, d3);
1051
36.6k
            d3x5 = _mm_add_ps(d3x5, d3x5);
1052
36.6k
            d3x5 = _mm_add_ps(d3, d3x5);
1053
36.6k
            _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
1054
36.6k
          } unroll_endfor
1055
6.23k
          // move to the next channel
1056
6.23k
          g += 4;
1057
6.23k
        }
1058
186
        const float* wpz = gwtg;
1059
6.44k
        for (k = 0; k < w->info.dim[0]; k += 4)
1060
6.26k
        {
1061
6.26k
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1062
6.26k
#if FOR_IS_PARALLEL
1063
6.26k
          g = btdb + i * 36 * dimCx4;
1064
#else
1065
          g = btdb;
1066
#endif
1067
216k
          for (j = 0; j < 36; j++)
1068
210k
          {
1069
210k
            __m128 v40 = _mm_setzero_ps();
1070
210k
            __m128 v41 = _mm_setzero_ps();
1071
210k
            __m128 v42 = _mm_setzero_ps();
1072
210k
            __m128 v43 = _mm_setzero_ps();
1073
2.28M
            for (c = 0; c < adim[2]; c += 4)
1074
2.07M
            {
1075
2.07M
              __m128 g4 = _mm_load_ps(g);
1076
2.07M
              __m128 w40 = _mm_load_ps(wpz);
1077
2.07M
              __m128 w41 = _mm_load_ps(wpz + 4);
1078
2.07M
              __m128 w42 = _mm_load_ps(wpz + 8);
1079
2.07M
              __m128 w43 = _mm_load_ps(wpz + 12);
1080
2.07M
              __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
1081
2.07M
              __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
1082
2.07M
              __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
1083
2.07M
              __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
1084
2.07M
              v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
1085
2.07M
              v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
1086
2.07M
              v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
1087
2.07M
              v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
1088
2.07M
              g += 4;
1089
2.07M
              wpz += 16;
1090
2.07M
            }
1091
210k
            v40 = _mm_add_ps(v40, v41);
1092
210k
            v42 = _mm_add_ps(v42, v43);
1093
210k
            _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
1094
210k
          }
1095
6.26k
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1096
37.4k
          unroll_for(j, 6) {
1097
37.4k
            const float* const qz = q + j * 4;
1098
37.4k
            float* const dz = d + j * 4;
1099
37.4k
            __m128 q0 = _mm_load_ps(qz);
1100
37.4k
            __m128 q6 = _mm_load_ps(qz + 24);
1101
37.4k
            __m128 q12 = _mm_load_ps(qz + 48);
1102
37.4k
            __m128 q18 = _mm_load_ps(qz + 72);
1103
37.4k
            __m128 q24 = _mm_load_ps(qz + 96);
1104
37.4k
            __m128 qs6x12 = _mm_add_ps(q6, q12);
1105
37.4k
            __m128 qs18x24 = _mm_add_ps(q18, q24);
1106
37.4k
            __m128 qss = _mm_add_ps(qs6x12, q0);
1107
37.4k
            /* row 1 */
1108
37.4k
            _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
1109
37.4k
            __m128 qn6x12 = _mm_sub_ps(q6, q12);
1110
37.4k
            __m128 qn18x24 = _mm_sub_ps(q18, q24);
1111
37.4k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1112
37.4k
            /* row 2 */
1113
37.4k
            _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
1114
37.4k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1115
37.4k
            qs18x24 = _mm_add_ps(qs18x24, qs18x24);
1116
37.4k
            /* row 3 */
1117
37.4k
            _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
1118
37.4k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1119
37.4k
            qn18x24 = _mm_add_ps(qn18x24, qn18x24);
1120
37.4k
            __m128 q30 = _mm_load_ps(qz + 120);
1121
37.4k
            /* row 4 */
1122
37.4k
            _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
1123
37.4k
          } unroll_endfor
1124
6.26k
          float* bpz = bp + x * binc[2] + k;
1125
6.26k
          switch (z[1]) {
1126
0
            case 1:
1127
0
              unroll_for(dy, z[0], 4) {
1128
0
                const float* const dz = d + dy * 24;
1129
0
                __m128 d0 = _mm_load_ps(dz);
1130
0
                __m128 d1 = _mm_load_ps(dz + 4);
1131
0
                __m128 d2 = _mm_load_ps(dz + 8);
1132
0
                __m128 d3 = _mm_load_ps(dz + 12);
1133
0
                __m128 d4 = _mm_load_ps(dz + 16);
1134
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1135
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1136
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1137
0
                bpz += binc[1] * binc[2];
1138
0
              } unroll_endfor
1139
0
              break;
1140
0
            case 2:
1141
0
              unroll_for(dy, z[0], 4) {
1142
0
                const float* const dz = d + dy * 24;
1143
0
                __m128 d0 = _mm_load_ps(dz);
1144
0
                __m128 d1 = _mm_load_ps(dz + 4);
1145
0
                __m128 d2 = _mm_load_ps(dz + 8);
1146
0
                __m128 d3 = _mm_load_ps(dz + 12);
1147
0
                __m128 d4 = _mm_load_ps(dz + 16);
1148
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1149
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1150
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1151
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1152
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1153
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1154
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1155
0
                bpz += binc[1] * binc[2];
1156
0
              } unroll_endfor
1157
0
              break;
1158
0
            case 3:
1159
0
              unroll_for(dy, z[0], 4) {
1160
0
                const float* const dz = d + dy * 24;
1161
0
                __m128 d0 = _mm_load_ps(dz);
1162
0
                __m128 d1 = _mm_load_ps(dz + 4);
1163
0
                __m128 d2 = _mm_load_ps(dz + 8);
1164
0
                __m128 d3 = _mm_load_ps(dz + 12);
1165
0
                __m128 d4 = _mm_load_ps(dz + 16);
1166
0
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1167
0
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1168
0
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1169
0
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1170
0
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1171
0
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1172
0
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1173
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1174
0
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1175
0
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
1176
0
                bpz += binc[1] * binc[2];
1177
0
              } unroll_endfor
1178
0
              break;
1179
6.27k
            case 4:
1180
25.0k
              unroll_for(dy, z[0], 4) {
1181
25.0k
                const float* const dz = d + dy * 24;
1182
25.0k
                __m128 d0 = _mm_load_ps(dz);
1183
25.0k
                __m128 d1 = _mm_load_ps(dz + 4);
1184
25.0k
                __m128 d2 = _mm_load_ps(dz + 8);
1185
25.0k
                __m128 d3 = _mm_load_ps(dz + 12);
1186
25.0k
                __m128 d4 = _mm_load_ps(dz + 16);
1187
25.0k
                __m128 ds1x2 = _mm_add_ps(d1, d2);
1188
25.0k
                __m128 ds3x4 = _mm_add_ps(d3, d4);
1189
25.0k
                _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
1190
25.0k
                __m128 dn1x2 = _mm_sub_ps(d1, d2);
1191
25.0k
                __m128 dn3x4 = _mm_sub_ps(d3, d4);
1192
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1193
25.0k
                _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
1194
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1195
25.0k
                ds3x4 = _mm_add_ps(ds3x4, ds3x4);
1196
25.0k
                _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
1197
25.0k
                __m128 d5 = _mm_load_ps(dz + 20);
1198
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1199
25.0k
                dn3x4 = _mm_add_ps(dn3x4, dn3x4);
1200
25.0k
                _mm_stream_ps(bpz + 3 * binc[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
1201
25.0k
                bpz += binc[1] * binc[2];
1202
25.0k
              } unroll_endfor
1203
6.27k
              break;
1204
6.26k
          };
1205
6.26k
        }
1206
186
      }
1207
1
    } parallel_endfor
1208
1
  }
1209
103
  return CCV_NNC_EXEC_SUCCESS;
1210
103
}
1211
#endif
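
The SSE2 path above applies the 6x6 input transform BT.d.B in two passes (the BT.d and BT.d.B blocks), with the multiplications by 2, 4 and 5 strength-reduced to chains of _mm_add_ps. Below is a minimal scalar sketch of the same transform for a single channel, derived from those stores; the function name and array shapes are illustrative and not part of the library.

static void winograd_4x4_3x3_input_ref(const float d[6][6], float out[6][6])
{
  /* First pass: t = BT.d (combine rows within each column). */
  float t[6][6];
  int i, j;
  for (j = 0; j < 6; j++)
  {
    t[0][j] = 4 * d[0][j] - 5 * d[2][j] + d[4][j];
    t[1][j] = (d[3][j] + d[4][j]) - 4 * (d[1][j] + d[2][j]);
    t[2][j] = (d[4][j] - d[3][j]) + 4 * (d[1][j] - d[2][j]);
    t[3][j] = (d[4][j] - d[2][j]) + 2 * (d[3][j] - d[1][j]);
    t[4][j] = (d[4][j] - d[2][j]) - 2 * (d[3][j] - d[1][j]);
    t[5][j] = 4 * d[1][j] - 5 * d[3][j] + d[5][j];
  }
  /* Second pass: out = (BT.d).B (combine columns within each row). */
  for (i = 0; i < 6; i++)
  {
    out[i][0] = 4 * t[i][0] - 5 * t[i][2] + t[i][4];
    out[i][1] = (t[i][3] + t[i][4]) - 4 * (t[i][1] + t[i][2]);
    out[i][2] = (t[i][4] - t[i][3]) + 4 * (t[i][1] - t[i][2]);
    out[i][3] = (t[i][4] - t[i][2]) + 2 * (t[i][3] - t[i][1]);
    out[i][4] = (t[i][4] - t[i][2]) - 2 * (t[i][3] - t[i][1]);
    out[i][5] = 4 * t[i][1] - 5 * t[i][3] + t[i][5];
  }
}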
1212
1213
#ifdef HAVE_NEON
1214
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_neon(const float* const w, const int* const dim, float* const gwtg)
1215
{
1216
  const int jump_dim = dim[0] / 4;
1217
  const int dimCx4 = (dim[3] + 3) & -4;
1218
  parallel_for(k, jump_dim) {
1219
    int i, j;
1220
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
1221
    const float* wz[] = {
1222
      w + (k * 4) * 9 * dim[3],
1223
      w + (k * 4 + 1) * 9 * dim[3],
1224
      w + (k * 4 + 2) * 9 * dim[3],
1225
      w + (k * 4 + 3) * 9 * dim[3],
1226
    };
1227
    for (i = 0; i < dim[3]; i++)
1228
    {
1229
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
1230
      unroll_for(j, 9) {
1231
        x9w[j * 4] = wz[0][j * dim[3] + i];
1232
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
1233
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
1234
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
1235
      } unroll_endfor
1236
      float g[18 * 4] __attribute__ ((__aligned__(16)));
1237
      float32x4_t x9w0 = vld1q_f32(x9w);
1238
      float32x4_t x9w1 = vld1q_f32(x9w + 4);
1239
      float32x4_t x9w2 = vld1q_f32(x9w + 8);
1240
      float32x4_t x9w3 = vld1q_f32(x9w + 12);
1241
      float32x4_t x9w4 = vld1q_f32(x9w + 16);
1242
      float32x4_t x9w5 = vld1q_f32(x9w + 20);
1243
      float32x4_t x9w6 = vld1q_f32(x9w + 24);
1244
      float32x4_t x9w7 = vld1q_f32(x9w + 28);
1245
      float32x4_t x9w8 = vld1q_f32(x9w + 32);
1246
      /* row 1 */
1247
      float32x4_t c1_4 = vdupq_n_f32(1.0 / 4);
1248
      vst1q_f32(g, vmulq_f32(x9w0, c1_4));
1249
      vst1q_f32(g + 4, vmulq_f32(x9w1, c1_4));
1250
      vst1q_f32(g + 8, vmulq_f32(x9w2, c1_4));
1251
      /* row 2 */
1252
      float32x4_t cn1_6 = vdupq_n_f32(-1.0 / 6);
1253
      vst1q_f32(g + 12, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1254
      vst1q_f32(g + 16, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1255
      vst1q_f32(g + 20, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1256
      /* row 3 */
1257
      vst1q_f32(g + 24, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
1258
      vst1q_f32(g + 28, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
1259
      vst1q_f32(g + 32, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
1260
      /* row 6 */
1261
      vst1q_f32(g + 60, x9w6);
1262
      vst1q_f32(g + 64, x9w7);
1263
      vst1q_f32(g + 68, x9w8);
1264
      /* w[x] * 2 */
1265
      x9w3 = vaddq_f32(x9w3, x9w3);
1266
      x9w4 = vaddq_f32(x9w4, x9w4);
1267
      x9w5 = vaddq_f32(x9w5, x9w5);
1268
      /* w[x] * 4 */
1269
      x9w6 = vaddq_f32(x9w6, x9w6);
1270
      x9w6 = vaddq_f32(x9w6, x9w6);
1271
      x9w7 = vaddq_f32(x9w7, x9w7);
1272
      x9w7 = vaddq_f32(x9w7, x9w7);
1273
      x9w8 = vaddq_f32(x9w8, x9w8);
1274
      x9w8 = vaddq_f32(x9w8, x9w8);
1275
      /* row 4 */
1276
      float32x4_t c1_24 = vdupq_n_f32(1.0 / 24);
1277
      vst1q_f32(g + 36, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1278
      vst1q_f32(g + 40, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1279
      vst1q_f32(g + 44, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1280
      /* row 5 */
1281
      vst1q_f32(g + 48, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
1282
      vst1q_f32(g + 52, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
1283
      vst1q_f32(g + 56, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
1284
      unroll_for(j, 6) {
1285
        const float* const gz = g + j * 12;
1286
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
1287
        float32x4_t g0 = vld1q_f32(gz);
1288
        float32x4_t g1 = vld1q_f32(gz + 4);
1289
        float32x4_t g2 = vld1q_f32(gz + 8);
1290
        vst1q_f32(gwtgzu, vmulq_f32(g0, c1_4));
1291
        vst1q_f32(gwtgzu + 4 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1292
        vst1q_f32(gwtgzu + 8 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), cn1_6));
1293
        vst1q_f32(gwtgzu + 20 * dimCx4, g2);
1294
        /* g[1] * 2 */
1295
        g1 = vaddq_f32(g1, g1);
1296
        /* g[2] * 4 */
1297
        g2 = vaddq_f32(g2, g2);
1298
        g2 = vaddq_f32(g2, g2);
1299
        vst1q_f32(gwtgzu + 12 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), c1_24));
1300
        vst1q_f32(gwtgzu + 16 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), c1_24));
1301
      } unroll_endfor
1302
      gwtgz += 4;
1303
    }
1304
  } parallel_endfor
1305
}
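
A detail worth noting in the transforms above: multiplications by the small integer factors 2, 4 and 5 (the "w[x] * 2" / "w[x] * 4" sequences here, and the d2x5 / d3x5 / g12x2 chains in the tile transforms) are strength-reduced to repeated vector additions, so the only genuine multiplies are by the 1/4, 1/6 and 1/24 constants. Below is a standalone sketch of the idiom, with helper names of ours rather than the library's.

#include <arm_neon.h>

/* x * 4 computed as two doublings, as in the "w[x] * 4" sequences above. */
static inline float32x4_t scale_by_4(float32x4_t x)
{
  float32x4_t x2 = vaddq_f32(x, x); /* 2 * x */
  return vaddq_f32(x2, x2);         /* 4 * x */
}

/* x * 5 computed as 4 * x + x, as in the d2x5 / d3x5 chains. */
static inline float32x4_t scale_by_5(float32x4_t x)
{
  return vaddq_f32(scale_by_4(x), x);
}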
1306
1307
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1308
{
1309
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
1310
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
1311
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
1312
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
1313
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
1314
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
1315
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
1316
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
1317
  assert(hint.border.begin[0] <= 1);
1318
  assert(hint.border.begin[1] <= 1);
1319
  assert(w->info.dim[0] % 4 == 0);
1320
  assert(w->info.dim[1] == 3);
1321
  assert(w->info.dim[2] == 3);
1322
  const int jump_dim = (bdim[0] + 3) / 4;
1323
  const int dimCx4 = (adim[2] + 3) & -4;
1324
  // allocating workspace memory for kernel reshaping and input reshaping.
1325
  float* workmem = 0;
1326
#if FOR_IS_PARALLEL
1327
  // If we do parallel for, we need to allocate an input reshaping buffer for each block.
1328
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1329
#else
1330
  // Otherwise, just one block.
1331
  workmem = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]), CCV_TENSOR_CPU_MEMORY);
1332
#endif
1333
  if (!workmem)
1334
    return CCV_NNC_EXEC_OOM;
1335
  // Convert w to a 6x6 matrix by computing G.w.T(G), where T denotes transpose.
1336
  float* const gwtg = workmem;
1337
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
1338
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
1339
  _ccv_nnc_winograd_4x4_3x3_gwtg_neon(w->data.f32, w->info.dim, gwtg);
1340
  // kernel weight for one dim.
1341
  // Workaround for issues with dispatch_apply (cannot reference an on-stack array)
1342
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
1343
    w->info.dim[0], 6, 6, w->info.dim[3]
1344
  };
1345
  const int* const tile_dim = tile_dim_s;
1346
  if (bias)
1347
  {
1348
    const float* const biasval = bias->data.f32;
1349
    // This block will be called in each for-loop, therefore, you can use it to generate some temporary variables.
1350
    parallel_for(i, jump_dim) {
1351
      const int y = i * 4; // i is unsigned.
1352
      int j, x, k, c;
1353
      int n[CCV_NNC_MAX_DIM];
1354
      int m[CCV_NNC_MAX_DIM];
1355
      int z[CCV_NNC_MAX_DIM];
1356
      set_n_m_dim(y, 0, tile_dim, adim);
1357
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1358
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
1359
      float* bp = b->data.f32 + y * binc[1] * binc[2];
1360
      for (x = 0; x < bdim[1]; x += 4)
1361
      {
1362
        set_n_m_dim(x, 1, tile_dim, adim);
1363
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1364
#if FOR_IS_PARALLEL
1365
        float* g = btdb + i * 36 * dimCx4;
1366
#else
1367
        float* g = btdb;
1368
#endif
1369
        // zero g such that we can have zero-padding.
1370
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1371
        int dx, dy;
1372
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
1373
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1374
        unroll_for(dy, m[0], 6) {
1375
          unroll_for(dx, m[1], 6) {
1376
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1377
            for (c = 0; c < adim[2]; c++)
1378
              gzu[c] = apz[dx * ainc[2] + c];
1379
          } unroll_endfor
1380
          apz += ainc[1] * ainc[2];
1381
        } unroll_endfor
1382
        for (c = 0; c < adim[2]; c += 4)
1383
        {
1384
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1385
          /* BT.d */
1386
          unroll_for(j, 6) {
1387
            /* row 1 */
1388
            const float* const gz = g + j * dimCx4;
1389
            float* dz = d + j * 4;
1390
            float32x4_t g0 = vld1q_f32(gz);
1391
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1392
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1393
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1394
            g0 = vaddq_f32(g0, g0);
1395
            g0 = vaddq_f32(g0, g0);
1396
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1397
            g12x2 = vaddq_f32(g12x2, g12x2);
1398
            g12x2 = vaddq_f32(g12x2, g12);
1399
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1400
            /* row 2 */
1401
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1402
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1403
            g6x12 = vaddq_f32(g6x12, g6x12);
1404
            g6x12 = vaddq_f32(g6x12, g6x12);
1405
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1406
            /* row 3 */
1407
            g6x12 = vsubq_f32(g6, g12);
1408
            g6x12 = vaddq_f32(g6x12, g6x12);
1409
            g6x12 = vaddq_f32(g6x12, g6x12);
1410
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1411
            /* row 4 */
1412
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1413
            g18x6 = vaddq_f32(g18x6, g18x6);
1414
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1415
            /* row 5 */
1416
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1417
            /* row 6 */
1418
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1419
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1420
            g18x2 = vaddq_f32(g18x2, g18x2);
1421
            g18x2 = vaddq_f32(g18, g18x2);
1422
            g6 = vaddq_f32(g6, g6);
1423
            g6 = vaddq_f32(g6, g6);
1424
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1425
          } unroll_endfor
1426
          /* BT.d.B */
1427
          unroll_for(j, 6) {
1428
            float* gz = g + j * 6 * dimCx4;
1429
            const float* const dz = d + j * 24;
1430
            float32x4_t d0 = vld1q_f32(dz);
1431
            float32x4_t d1 = vld1q_f32(dz + 4);
1432
            float32x4_t d2 = vld1q_f32(dz + 8);
1433
            float32x4_t d3 = vld1q_f32(dz + 12);
1434
            float32x4_t d4 = vld1q_f32(dz + 16);
1435
            float32x4_t d5 = vld1q_f32(dz + 20);
1436
            d0 = vaddq_f32(d0, d0);
1437
            d0 = vaddq_f32(d0, d0);
1438
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1439
            d2x5 = vaddq_f32(d2x5, d2x5);
1440
            d2x5 = vaddq_f32(d2, d2x5);
1441
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1442
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1443
            d1x2 = vaddq_f32(d1x2, d1x2);
1444
            d1x2 = vaddq_f32(d1x2, d1x2);
1445
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1446
            d1x2 = vsubq_f32(d1, d2);
1447
            d1x2 = vaddq_f32(d1x2, d1x2);
1448
            d1x2 = vaddq_f32(d1x2, d1x2);
1449
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1450
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1451
            d3x1 = vaddq_f32(d3x1, d3x1);
1452
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1453
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1454
            d1 = vaddq_f32(d1, d1);
1455
            d1 = vaddq_f32(d1, d1);
1456
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1457
            d3x5 = vaddq_f32(d3x5, d3x5);
1458
            d3x5 = vaddq_f32(d3, d3x5);
1459
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1460
          } unroll_endfor
1461
          // move to the next channel
1462
          g += 4;
1463
        }
1464
        const float* wpz = gwtg;
1465
        for (k = 0; k < w->info.dim[0]; k += 4)
1466
        {
1467
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1468
#if FOR_IS_PARALLEL
1469
          g = btdb + i * 36 * dimCx4;
1470
#else
1471
          g = btdb;
1472
#endif
1473
          for (j = 0; j < 36; j++)
1474
          {
1475
            float32x4_t v40 = vmovq_n_f32(0);
1476
            float32x4_t v41 = vmovq_n_f32(0);
1477
            float32x4_t v42 = vmovq_n_f32(0);
1478
            float32x4_t v43 = vmovq_n_f32(0);
1479
            for (c = 0; c < adim[2]; c += 4)
1480
            {
1481
              float32x2x2_t g4 = vld2_f32(g);
1482
              float32x4_t w40 = vld1q_f32(wpz);
1483
              float32x4_t w41 = vld1q_f32(wpz + 4);
1484
              float32x4_t w42 = vld1q_f32(wpz + 8);
1485
              float32x4_t w43 = vld1q_f32(wpz + 12);
1486
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1487
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1488
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1489
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1490
              v40 = vmlaq_f32(v40, w40, g40);
1491
              v41 = vmlaq_f32(v41, w41, g41);
1492
              v42 = vmlaq_f32(v42, w42, g42);
1493
              v43 = vmlaq_f32(v43, w43, g43);
1494
              g += 4;
1495
              wpz += 16;
1496
            }
1497
            v40 = vaddq_f32(v40, v41);
1498
            v42 = vaddq_f32(v42, v43);
1499
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1500
          }
1501
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1502
          unroll_for(j, 6) {
1503
            const float* const qz = q + j * 4;
1504
            float* const dz = d + j * 4;
1505
            float32x4_t q0 = vld1q_f32(qz);
1506
            float32x4_t q6 = vld1q_f32(qz + 24);
1507
            float32x4_t q12 = vld1q_f32(qz + 48);
1508
            float32x4_t q18 = vld1q_f32(qz + 72);
1509
            float32x4_t q24 = vld1q_f32(qz + 96);
1510
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1511
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1512
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1513
            /* row 1 */
1514
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1515
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1516
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1517
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1518
            /* row 2 */
1519
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1520
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1521
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1522
            /* row 3 */
1523
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1524
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1525
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1526
            float32x4_t q30 = vld1q_f32(qz + 120);
1527
            /* row 4 */
1528
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1529
          } unroll_endfor
1530
          float* bpz = bp + x * binc[2] + k;
1531
          float32x4_t bias4 = vld1q_f32(biasval + k);
1532
          switch (z[1]) {
1533
            case 1:
1534
              unroll_for(dy, z[0], 4) {
1535
                const float* const dz = d + dy * 24;
1536
                float32x4_t d0 = vld1q_f32(dz);
1537
                float32x4_t d1 = vld1q_f32(dz + 4);
1538
                float32x4_t d2 = vld1q_f32(dz + 8);
1539
                float32x4_t d3 = vld1q_f32(dz + 12);
1540
                float32x4_t d4 = vld1q_f32(dz + 16);
1541
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1542
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1543
                ds1x2 = vaddq_f32(ds1x2, bias4);
1544
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1545
                bpz += binc[1] * binc[2];
1546
              } unroll_endfor
1547
              break;
1548
            case 2:
1549
              unroll_for(dy, z[0], 4) {
1550
                const float* const dz = d + dy * 24;
1551
                float32x4_t d0 = vld1q_f32(dz);
1552
                float32x4_t d1 = vld1q_f32(dz + 4);
1553
                float32x4_t d2 = vld1q_f32(dz + 8);
1554
                float32x4_t d3 = vld1q_f32(dz + 12);
1555
                float32x4_t d4 = vld1q_f32(dz + 16);
1556
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1557
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1558
                ds1x2 = vaddq_f32(ds1x2, bias4);
1559
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1560
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1561
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1562
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1563
                dn1x2 = vaddq_f32(dn1x2, bias4);
1564
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1565
                bpz += binc[1] * binc[2];
1566
              } unroll_endfor
1567
              break;
1568
            case 3:
1569
              unroll_for(dy, z[0], 4) {
1570
                const float* const dz = d + dy * 24;
1571
                float32x4_t d0 = vld1q_f32(dz);
1572
                float32x4_t d1 = vld1q_f32(dz + 4);
1573
                float32x4_t d2 = vld1q_f32(dz + 8);
1574
                float32x4_t d3 = vld1q_f32(dz + 12);
1575
                float32x4_t d4 = vld1q_f32(dz + 16);
1576
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1577
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1578
                ds1x2 = vaddq_f32(ds1x2, bias4);
1579
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1580
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1581
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1582
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1583
                dn1x2 = vaddq_f32(dn1x2, bias4);
1584
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1585
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1586
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1587
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1588
                bpz += binc[1] * binc[2];
1589
              } unroll_endfor
1590
              break;
1591
            case 4:
1592
              unroll_for(dy, z[0], 4) {
1593
                const float* const dz = d + dy * 24;
1594
                float32x4_t d0 = vld1q_f32(dz);
1595
                float32x4_t d1 = vld1q_f32(dz + 4);
1596
                float32x4_t d2 = vld1q_f32(dz + 8);
1597
                float32x4_t d3 = vld1q_f32(dz + 12);
1598
                float32x4_t d4 = vld1q_f32(dz + 16);
1599
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1600
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1601
                ds1x2 = vaddq_f32(ds1x2, bias4);
1602
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1603
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1604
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1605
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1606
                dn1x2 = vaddq_f32(dn1x2, bias4);
1607
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1608
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1609
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1610
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1611
                float32x4_t d5 = vld1q_f32(dz + 20);
1612
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1613
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1614
                vst1q_f32(bpz + 3 * binc[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1615
                bpz += binc[1] * binc[2];
1616
              } unroll_endfor
1617
              break;
1618
          };
1619
        }
1620
      }
1621
    } parallel_endfor
1622
  } else {
1623
    // This block will be called in each for-loop, therefore, you can use it to generate some temporary variables.
1624
    parallel_for(i, jump_dim) {
1625
      const int y = i * 4; // i is unsigned.
1626
      int j, x, k, c;
1627
      int n[CCV_NNC_MAX_DIM];
1628
      int m[CCV_NNC_MAX_DIM];
1629
      int z[CCV_NNC_MAX_DIM];
1630
      set_n_m_dim(y, 0, tile_dim, adim);
1631
      z[0] = ccv_min(y + 4, bdim[0]) - y;
1632
      const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
1633
      float* bp = b->data.f32 + y * binc[1] * binc[2];
1634
      for (x = 0; x < bdim[1]; x += 4)
1635
      {
1636
        set_n_m_dim(x, 1, tile_dim, adim);
1637
        z[1] = ccv_min(x + 4, bdim[1]) - x;
1638
#if FOR_IS_PARALLEL
1639
        float* g = btdb + i * 36 * dimCx4;
1640
#else
1641
        float* g = btdb;
1642
#endif
1643
        // zero g such that we can have zero-padding.
1644
        memset(g, 0, sizeof(float) * 36 * dimCx4);
1645
        int dx, dy;
1646
        const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
1647
        float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
1648
        unroll_for(dy, m[0], 6) {
1649
          unroll_for(dx, m[1], 6) {
1650
            float* const gzu = gz + (dy * 6 + dx) * dimCx4;
1651
            for (c = 0; c < adim[2]; c++)
1652
              gzu[c] = apz[dx * ainc[2] + c];
1653
          } unroll_endfor
1654
          apz += ainc[1] * ainc[2];
1655
        } unroll_endfor
1656
        for (c = 0; c < adim[2]; c += 4)
1657
        {
1658
          float d[36 * 4]  __attribute__ ((__aligned__(16)));
1659
          /* BT.d */
1660
          unroll_for(j, 6) {
1661
            /* row 1 */
1662
            const float* const gz = g + j * dimCx4;
1663
            float* dz = d + j * 4;
1664
            float32x4_t g0 = vld1q_f32(gz);
1665
            float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
1666
            float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
1667
            float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
1668
            g0 = vaddq_f32(g0, g0);
1669
            g0 = vaddq_f32(g0, g0);
1670
            float32x4_t g12x2 = vaddq_f32(g12, g12);
1671
            g12x2 = vaddq_f32(g12x2, g12x2);
1672
            g12x2 = vaddq_f32(g12x2, g12);
1673
            vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
1674
            /* row 2 */
1675
            float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
1676
            float32x4_t g6x12 = vaddq_f32(g6, g12);
1677
            g6x12 = vaddq_f32(g6x12, g6x12);
1678
            g6x12 = vaddq_f32(g6x12, g6x12);
1679
            vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
1680
            /* row 3 */
1681
            g6x12 = vsubq_f32(g6, g12);
1682
            g6x12 = vaddq_f32(g6x12, g6x12);
1683
            g6x12 = vaddq_f32(g6x12, g6x12);
1684
            vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
1685
            /* row 4 */
1686
            float32x4_t g18x6 = vsubq_f32(g18, g6);
1687
            g18x6 = vaddq_f32(g18x6, g18x6);
1688
            vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
1689
            /* row 5 */
1690
            vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
1691
            /* row 6 */
1692
            float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
1693
            float32x4_t g18x2 = vaddq_f32(g18, g18);
1694
            g18x2 = vaddq_f32(g18x2, g18x2);
1695
            g18x2 = vaddq_f32(g18, g18x2);
1696
            g6 = vaddq_f32(g6, g6);
1697
            g6 = vaddq_f32(g6, g6);
1698
            vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
1699
          } unroll_endfor
1700
          /* BT.d.B */
1701
          unroll_for(j, 6) {
1702
            float* gz = g + j * 6 * dimCx4;
1703
            const float* const dz = d + j * 24;
1704
            float32x4_t d0 = vld1q_f32(dz);
1705
            float32x4_t d1 = vld1q_f32(dz + 4);
1706
            float32x4_t d2 = vld1q_f32(dz + 8);
1707
            float32x4_t d3 = vld1q_f32(dz + 12);
1708
            float32x4_t d4 = vld1q_f32(dz + 16);
1709
            float32x4_t d5 = vld1q_f32(dz + 20);
1710
            d0 = vaddq_f32(d0, d0);
1711
            d0 = vaddq_f32(d0, d0);
1712
            float32x4_t d2x5 = vaddq_f32(d2, d2);
1713
            d2x5 = vaddq_f32(d2x5, d2x5);
1714
            d2x5 = vaddq_f32(d2, d2x5);
1715
            vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
1716
            float32x4_t d1x2 = vaddq_f32(d1, d2);
1717
            d1x2 = vaddq_f32(d1x2, d1x2);
1718
            d1x2 = vaddq_f32(d1x2, d1x2);
1719
            vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
1720
            d1x2 = vsubq_f32(d1, d2);
1721
            d1x2 = vaddq_f32(d1x2, d1x2);
1722
            d1x2 = vaddq_f32(d1x2, d1x2);
1723
            vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
1724
            float32x4_t d3x1 = vsubq_f32(d3, d1);
1725
            d3x1 = vaddq_f32(d3x1, d3x1);
1726
            vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
1727
            vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1728
            d1 = vaddq_f32(d1, d1);
1729
            d1 = vaddq_f32(d1, d1);
1730
            float32x4_t d3x5 = vaddq_f32(d3, d3);
1731
            d3x5 = vaddq_f32(d3x5, d3x5);
1732
            d3x5 = vaddq_f32(d3, d3x5);
1733
            vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1734
          } unroll_endfor
1735
          // move to the next channel
1736
          g += 4;
1737
        }
1738
        const float* wpz = gwtg;
1739
        for (k = 0; k < w->info.dim[0]; k += 4)
1740
        {
1741
          float q[36 * 4] __attribute__ ((__aligned__(16)));
1742
#if FOR_IS_PARALLEL
1743
          g = btdb + i * 36 * dimCx4;
1744
#else
1745
          g = btdb;
1746
#endif
1747
          for (j = 0; j < 36; j++)
1748
          {
1749
            float32x4_t v40 = vmovq_n_f32(0);
1750
            float32x4_t v41 = vmovq_n_f32(0);
1751
            float32x4_t v42 = vmovq_n_f32(0);
1752
            float32x4_t v43 = vmovq_n_f32(0);
1753
            for (c = 0; c < adim[2]; c += 4)
1754
            {
1755
              float32x2x2_t g4 = vld2_f32(g);
1756
              float32x4_t w40 = vld1q_f32(wpz);
1757
              float32x4_t w41 = vld1q_f32(wpz + 4);
1758
              float32x4_t w42 = vld1q_f32(wpz + 8);
1759
              float32x4_t w43 = vld1q_f32(wpz + 12);
1760
              float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1761
              float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1762
              float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1763
              float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1764
              v40 = vmlaq_f32(v40, w40, g40);
1765
              v41 = vmlaq_f32(v41, w41, g41);
1766
              v42 = vmlaq_f32(v42, w42, g42);
1767
              v43 = vmlaq_f32(v43, w43, g43);
1768
              g += 4;
1769
              wpz += 16;
1770
            }
1771
            v40 = vaddq_f32(v40, v41);
1772
            v42 = vaddq_f32(v42, v43);
1773
            vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1774
          }
1775
          float d[24 * 4] __attribute__ ((__aligned__(16)));
1776
          unroll_for(j, 6) {
1777
            const float* const qz = q + j * 4;
1778
            float* const dz = d + j * 4;
1779
            float32x4_t q0 = vld1q_f32(qz);
1780
            float32x4_t q6 = vld1q_f32(qz + 24);
1781
            float32x4_t q12 = vld1q_f32(qz + 48);
1782
            float32x4_t q18 = vld1q_f32(qz + 72);
1783
            float32x4_t q24 = vld1q_f32(qz + 96);
1784
            float32x4_t qs6x12 = vaddq_f32(q6, q12);
1785
            float32x4_t qs18x24 = vaddq_f32(q18, q24);
1786
            float32x4_t qss = vaddq_f32(qs6x12, q0);
1787
            /* row 1 */
1788
            vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1789
            float32x4_t qn6x12 = vsubq_f32(q6, q12);
1790
            float32x4_t qn18x24 = vsubq_f32(q18, q24);
1791
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1792
            /* row 2 */
1793
            vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1794
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1795
            qs18x24 = vaddq_f32(qs18x24, qs18x24);
1796
            /* row 3 */
1797
            vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1798
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1799
            qn18x24 = vaddq_f32(qn18x24, qn18x24);
1800
            float32x4_t q30 = vld1q_f32(qz + 120);
1801
            /* row 4 */
1802
            vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1803
          } unroll_endfor
1804
          float* bpz = bp + x * binc[2] + k;
1805
          switch (z[1]) {
1806
            case 1:
1807
              unroll_for(dy, z[0], 4) {
1808
                const float* const dz = d + dy * 24;
1809
                float32x4_t d0 = vld1q_f32(dz);
1810
                float32x4_t d1 = vld1q_f32(dz + 4);
1811
                float32x4_t d2 = vld1q_f32(dz + 8);
1812
                float32x4_t d3 = vld1q_f32(dz + 12);
1813
                float32x4_t d4 = vld1q_f32(dz + 16);
1814
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1815
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1816
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1817
                bpz += binc[1] * binc[2];
1818
              } unroll_endfor
1819
              break;
1820
            case 2:
1821
              unroll_for(dy, z[0], 4) {
1822
                const float* const dz = d + dy * 24;
1823
                float32x4_t d0 = vld1q_f32(dz);
1824
                float32x4_t d1 = vld1q_f32(dz + 4);
1825
                float32x4_t d2 = vld1q_f32(dz + 8);
1826
                float32x4_t d3 = vld1q_f32(dz + 12);
1827
                float32x4_t d4 = vld1q_f32(dz + 16);
1828
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1829
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1830
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1831
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1832
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1833
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1834
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1835
                bpz += binc[1] * binc[2];
1836
              } unroll_endfor
1837
              break;
1838
            case 3:
1839
              unroll_for(dy, z[0], 4) {
1840
                const float* const dz = d + dy * 24;
1841
                float32x4_t d0 = vld1q_f32(dz);
1842
                float32x4_t d1 = vld1q_f32(dz + 4);
1843
                float32x4_t d2 = vld1q_f32(dz + 8);
1844
                float32x4_t d3 = vld1q_f32(dz + 12);
1845
                float32x4_t d4 = vld1q_f32(dz + 16);
1846
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1847
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1848
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1849
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1850
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1851
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1852
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1853
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1854
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1855
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1856
                bpz += binc[1] * binc[2];
1857
              } unroll_endfor
1858
              break;
1859
            case 4:
1860
              unroll_for(dy, z[0], 4) {
1861
                const float* const dz = d + dy * 24;
1862
                float32x4_t d0 = vld1q_f32(dz);
1863
                float32x4_t d1 = vld1q_f32(dz + 4);
1864
                float32x4_t d2 = vld1q_f32(dz + 8);
1865
                float32x4_t d3 = vld1q_f32(dz + 12);
1866
                float32x4_t d4 = vld1q_f32(dz + 16);
1867
                float32x4_t ds1x2 = vaddq_f32(d1, d2);
1868
                float32x4_t ds3x4 = vaddq_f32(d3, d4);
1869
                vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1870
                float32x4_t dn1x2 = vsubq_f32(d1, d2);
1871
                float32x4_t dn3x4 = vsubq_f32(d3, d4);
1872
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1873
                vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1874
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1875
                ds3x4 = vaddq_f32(ds3x4, ds3x4);
1876
                vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1877
                float32x4_t d5 = vld1q_f32(dz + 20);
1878
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1879
                dn3x4 = vaddq_f32(dn3x4, dn3x4);
1880
                vst1q_f32(bpz + 3 * binc[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1881
                bpz += binc[1] * binc[2];
1882
              } unroll_endfor
1883
              break;
1884
          };
1885
        }
1886
      }
1887
    } parallel_endfor
1888
  }
1889
  return CCV_NNC_EXEC_SUCCESS;
1890
}
1891
#endif
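
The final stage in both SIMD paths is the 6x6 -> 4x4 output transform AT.q.A: the unroll_for(j, 6) pass over q produces up to four transformed rows in d, and the switch (z[1]) cases write one to four columns per row (adding the bias first in the biased path), so partial tiles at the right and bottom borders store only the valid outputs. Below is a minimal scalar sketch of the full-tile case, again with names of ours rather than the library's.

static void winograd_4x4_3x3_output_ref(const float m[6][6], float out[4][4])
{
  /* First pass: t = AT.m (combine rows within each column). */
  float t[4][6];
  int i, j;
  for (j = 0; j < 6; j++)
  {
    t[0][j] = m[0][j] + m[1][j] + m[2][j] + m[3][j] + m[4][j];
    t[1][j] = (m[1][j] - m[2][j]) + 2 * (m[3][j] - m[4][j]);
    t[2][j] = (m[1][j] + m[2][j]) + 4 * (m[3][j] + m[4][j]);
    t[3][j] = (m[1][j] - m[2][j]) + 8 * (m[3][j] - m[4][j]) + m[5][j];
  }
  /* Second pass: out = (AT.m).A (combine columns within each row). */
  for (i = 0; i < 4; i++)
  {
    out[i][0] = t[i][0] + t[i][1] + t[i][2] + t[i][3] + t[i][4];
    out[i][1] = (t[i][1] - t[i][2]) + 2 * (t[i][3] - t[i][4]);
    out[i][2] = (t[i][1] + t[i][2]) + 4 * (t[i][3] + t[i][4]);
    out[i][3] = (t[i][1] - t[i][2]) + 8 * (t[i][3] - t[i][4]) + t[i][5];
  }
}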
1892
1893
int _ccv_nnc_conv_forw_4x4_3x3_winograd_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b, ccv_nnc_stream_context_t* const stream_context)
1894
103
{
1895
103
#if defined(HAVE_SSE2)
1896
103
  if (w->info.dim[0] % 4 == 0)
1897
103
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(a, w, bias, hint, b, stream_context);
1898
#elif defined(HAVE_NEON)
1899
  if (w->info.dim[0] % 4 == 0)
1900
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(a, w, bias, hint, b, stream_context);
1901
#endif
1902
0
  return _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(a, w, bias, hint, b, stream_context);
1903
0
}
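
For reference, the work both SIMD paths perform between the input and output transforms (the j < 36 loops stepping c by 4) is an independent accumulation over input channels at each of the 36 tile positions, computed four output filters at a time. Below is a scalar sketch with illustrative, flat array layouts; the real buffers pad the channel count to dimCx4 and interleave four filters per 16 weights.

static void winograd_domain_accumulate_ref(const float* gwtg_tile, /* [36][channels][4] */
                                           const float* btdb_tile, /* [36][channels] */
                                           const int channels, float q[36][4])
{
  int t, f, c;
  for (t = 0; t < 36; t++)
    for (f = 0; f < 4; f++)
    {
      float acc = 0;
      for (c = 0; c < channels; c++)
        acc += gwtg_tile[(t * channels + c) * 4 + f] * btdb_tile[t * channels + c];
      q[t][f] = acc;
    }
}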