Coverage Report

Created: 2017-11-12 13:27

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/cmd/convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c
Line
Count
Source
1
#include <ccv.h>
2
#include <ccv_internal.h>
3
#include <nnc/ccv_nnc.h>
4
#include <nnc/ccv_nnc_easy.h>
5
#include <nnc/ccv_nnc_internal.h>
6
#if defined(HAVE_SSE2)
7
#include <xmmintrin.h>
8
#elif defined(HAVE_NEON)
9
#include <arm_neon.h>
10
#endif
11
#ifdef USE_OPENMP
12
#include <omp.h>
13
#endif
14
#ifdef USE_DISPATCH
15
#include <dispatch/dispatch.h>
16
#endif
17
#include "../_ccv_nnc_conv_cpu_opt.h"
18
19
#define set_n_m_dim(i, x, wd, ad) \
20
72.0k
  do { \
21
72.0k
    n[x] = ccv_max((i) * hint.stride.dim[x] - hint.border.begin[x], 0) - ((i) * hint.stride.dim[x] - hint.border.begin[x]); \
22
72.0k
    m[x] = wd[x + 1] - n[x] - ((i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1] - ccv_min(ad[x], (i) * hint.stride.dim[x] - hint.border.begin[x] + wd[x + 1])); \
23
71.9k
  } while (0)
24
25
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_ref(const float* const w, const int c, float* gwtg)
26
0
{
27
0
  int i;
28
0
  for (i = 0; i < c; i++)
29
0
  {
30
0
    float g[18];
31
0
    /*
32
0
     * a0, b1, c2
33
0
     * d3, e4, f5
34
0
     * g6, h7, i8
35
0
     * {{a/4, b/4, c/4},
36
0
     * {1/6 (-a - d - g), 1/6 (-b - e - h), 1/6 (-c - f - i)},
37
0
     * {1/6 (-a + d - g), 1/6 (-b + e - h), 1/6 (-c + f - i)},
38
0
     * {1/24 (a + 2 d + 4 g), 1/24 (b + 2 e + 4 h), 1/24 (c + 2 f + 4 i)},
39
0
     * {1/24 (a - 2 d + 4 g), 1/24 (b - 2 e + 4 h), 1/24 (c - 2 f + 4 i)},
40
0
     * {g, h, i}}
41
0
     */
42
0
    /* row 1 */
43
0
    g[0] = w[i] / 4;
44
0
    g[1] = w[c + i] / 4;
45
0
    g[2] = w[2 * c + i] / 4;
46
0
    /* row 2 */
47
0
    g[3] = -(w[i] + w[3 * c + i] + w[6 * c + i]) / 6;
48
0
    g[4] = -(w[c + i] + w[4 * c + i] + w[7 * c + i]) / 6;
49
0
    g[5] = -(w[2 * c + i] + w[5 * c + i] + w[8 * c + i]) / 6;
50
0
    /* row 3 */
51
0
    g[6] = (-w[i] + w[3 * c + i] - w[6 * c + i]) / 6;
52
0
    g[7] = (-w[c + i] + w[4 * c + i] - w[7 * c + i]) / 6;
53
0
    g[8] = (-w[2 * c + i] + w[5 * c + i] - w[8 * c + i]) / 6;
54
0
    /* row 4 */
55
0
    g[9] = (w[i] + 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
56
0
    g[10] = (w[c + i] + 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
57
0
    g[11] = (w[2 * c + i] + 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
58
0
    /* row 5 */
59
0
    g[12] = (w[i] - 2 * w[3 * c + i] + 4 * w[6 * c + i]) / 24;
60
0
    g[13] = (w[c + i] - 2 * w[4 * c + i] + 4 * w[7 * c + i]) / 24;
61
0
    g[14] = (w[2 * c + i] - 2 * w[5 * c + i] + 4 * w[8 * c + i]) / 24;
62
0
    /* row 6 */
63
0
    g[15] = w[6 * c + i];
64
0
    g[16] = w[7 * c + i];
65
0
    g[17] = w[8 * c + i];
66
0
    /*
67
0
     * a0, b1, c2
68
0
     * d3, e4, f5
69
0
     * g6, h7, i8
70
0
     * j9, k10,l11
71
0
     * m12,n13,o14
72
0
     * p15,q16,r17
73
0
     * {{a/4, 1/6 (-a - b - c), 1/6 (-a + b - c), 1/24 (a + 2 b + 4 c), 1/24 (a - 2 b + 4 c), c},
74
0
     * {d/4, 1/6 (-d - e - f), 1/6 (-d + e - f), 1/24 (d + 2 e + 4 f), 1/24 (d - 2 e + 4 f), f},
75
0
     * {g/4, 1/6 (-g - h - i), 1/6 (-g + h - i), 1/24 (g + 2 h + 4 i), 1/24 (g - 2 h + 4 i), i},
76
0
     * {j/4, 1/6 (-j - k - l), 1/6 (-j + k - l), 1/24 (j + 2 k + 4 l), 1/24 (j - 2 k + 4 l), l},
77
0
     * {m/4, 1/6 (-m - n - o), 1/6 (-m + n - o), 1/24 (m + 2 n + 4 o), 1/24 (m - 2 n + 4 o), o},
78
0
     * {p/4, 1/6 (-p - q - r), 1/6 (-p + q - r), 1/24 (p + 2 q + 4 r), 1/24 (p - 2 q + 4 r), r}}
79
0
     */
80
0
    /* row 1 */
81
0
    gwtg[0] = g[0] / 4;
82
0
    gwtg[c] = -(g[0] + g[1] + g[2]) / 6;
83
0
    gwtg[2 * c] = (-g[0] + g[1] - g[2]) / 6;
84
0
    gwtg[3 * c] = (g[0] + 2 * g[1] + 4 * g[2]) / 24;
85
0
    gwtg[4 * c] = (g[0] - 2 * g[1] + 4 * g[2]) / 24;
86
0
    gwtg[5 * c] = g[2];
87
0
    /* row 2 */
88
0
    gwtg[6 * c] = g[3] / 4;
89
0
    gwtg[7 * c] = -(g[3] + g[4] + g[5]) / 6;
90
0
    gwtg[8 * c] = (-g[3] + g[4] - g[5]) / 6;
91
0
    gwtg[9 * c] = (g[3] + 2 * g[4] + 4 * g[5]) / 24;
92
0
    gwtg[10 * c] = (g[3] - 2 * g[4] + 4 * g[5]) / 24;
93
0
    gwtg[11 * c] = g[5];
94
0
    /* row 3 */
95
0
    gwtg[12 * c] = g[6] / 4;
96
0
    gwtg[13 * c] = -(g[6] + g[7] + g[8]) / 6;
97
0
    gwtg[14 * c] = (-g[6] + g[7] - g[8]) / 6;
98
0
    gwtg[15 * c] = (g[6] + 2 * g[7] + 4 * g[8]) / 24;
99
0
    gwtg[16 * c] = (g[6] - 2 * g[7] + 4 * g[8]) / 24;
100
0
    gwtg[17 * c] = g[8];
101
0
    /* row 4 */
102
0
    gwtg[18 * c] = g[9] / 4;
103
0
    gwtg[19 * c] = -(g[9] + g[10] + g[11]) / 6;
104
0
    gwtg[20 * c] = (-g[9] + g[10] - g[11]) / 6;
105
0
    gwtg[21 * c] = (g[9] + 2 * g[10] + 4 * g[11]) / 24;
106
0
    gwtg[22 * c] = (g[9] - 2 * g[10] + 4 * g[11]) / 24;
107
0
    gwtg[23 * c] = g[11];
108
0
    /* row 5 */
109
0
    gwtg[24 * c] = g[12] / 4;
110
0
    gwtg[25 * c] = -(g[12] + g[13] + g[14]) / 6;
111
0
    gwtg[26 * c] = (-g[12] + g[13] - g[14]) / 6;
112
0
    gwtg[27 * c] = (g[12] + 2 * g[13] + 4 * g[14]) / 24;
113
0
    gwtg[28 * c] = (g[12] - 2 * g[13] + 4 * g[14]) / 24;
114
0
    gwtg[29 * c] = g[14];
115
0
    /* row 6 */
116
0
    gwtg[30 * c] = g[15] / 4;
117
0
    gwtg[31 * c] = -(g[15] + g[16] + g[17]) / 6;
118
0
    gwtg[32 * c] = (-g[15] + g[16] - g[17]) / 6;
119
0
    gwtg[33 * c] = (g[15] + 2 * g[16] + 4 * g[17]) / 24;
120
0
    gwtg[34 * c] = (g[15] - 2 * g[16] + 4 * g[17]) / 24;
121
0
    gwtg[35 * c] = g[17];
122
0
    ++gwtg;
123
0
  }
124
0
}
125
126
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b)
127
0
{
128
0
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
129
0
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
130
0
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
131
0
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
132
0
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
133
0
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
134
0
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
135
0
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
136
0
  assert(hint.border.begin[0] <= 1);
137
0
  assert(hint.border.begin[1] <= 1);
138
0
  assert(w->info.dim[1] == 3);
139
0
  assert(w->info.dim[2] == 3);
140
0
  const int jump_dim = (bdim[0] + 3) / 4;
141
0
  // allocating workspace memory for kernel reshaping and input reshaping.
142
0
#if FOR_IS_PARALLEL
143
0
  // If we do parallel for, we need to allocate input reshaping for each block.
144
0
  float* const workmem = (float*)ccmalloc(sizeof(float) * (36 * adim[2] * jump_dim + 36 * w->info.dim[0] * w->info.dim[3]));
145
0
#else
146
  // Otherwise, just one block.
147
  float* const workmem = (float*)ccmalloc(sizeof(float) * (36 * adim[2] + 36 * w->info.dim[0] * w->info.dim[3]));
148
#endif
149
0
  if (!workmem)
150
0
    return CCV_NNC_EXEC_OOM;
151
0
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
152
0
  float* const gwtg = workmem;
153
0
  float* const btdb = workmem + 36 * w->info.dim[0] * w->info.dim[3];
154
0
  parallel_for(k, w->info.dim[0]) {
155
0
    _ccv_nnc_winograd_4x4_3x3_gwtg_ref(w->data.f32 + k * w->info.dim[3] * w->info.dim[2] * w->info.dim[1], w->info.dim[3], gwtg + k * 36 * w->info.dim[3]);
156
0
  } parallel_endfor
157
0
  // kernel weight for one dim.
158
0
  const float* const biasval = bias->data.f32;
159
0
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
160
0
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
161
0
    w->info.dim[0], 6, 6, w->info.dim[3]
162
0
  };
163
0
  const int* const tile_dim = tile_dim_s;
164
0
  // This block will be called in each for-loop iteration, so you can use it to generate some temporary variables.
165
0
  parallel_for(i, jump_dim) {
166
0
    const int y = i * 4; // i is unsigned.
167
0
    int j, x, k, c;
168
0
    int n[CCV_NNC_MAX_DIM];
169
0
    int m[CCV_NNC_MAX_DIM];
170
0
    int z[CCV_NNC_MAX_DIM];
171
0
    set_n_m_dim(y, 0, tile_dim, adim);
172
0
    z[0] = ccv_min(y + 4, bdim[0]) - y;
173
0
    const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
174
0
    float* bp = b->data.f32 + y * binc[1] * binc[2];
175
0
    for (x = 0; x < bdim[1]; x += 4)
176
0
    {
177
0
      set_n_m_dim(x, 1, tile_dim, adim);
178
0
      z[1] = ccv_min(x + 4, bdim[1]) - x;
179
0
#if FOR_IS_PARALLEL
180
0
      float* g = btdb + i * 36 * adim[2];
181
0
#else
182
      float* g = btdb;
183
#endif
184
0
      // zero g such that we can have zero-padding.
185
0
      memset(g, 0, sizeof(float) * 36 * adim[2]);
186
0
      int dx, dy;
187
0
      const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
188
0
      float* gz = g + (n[0] * 6 + n[1]) * adim[2];
189
0
      unroll_for(dy, m[0], 6) {
190
0
        unroll_for(dx, m[1], 6) {
191
0
          float* const gzu = gz + (dy * 6 + dx) * adim[2];
192
0
          for (c = 0; c < adim[2]; c++)
193
0
            gzu[c] = apz[dx * ainc[2] + c];
194
0
        } unroll_endfor
195
0
        apz += ainc[1] * ainc[2];
196
0
      } unroll_endfor
197
0
      for (c = 0; c < adim[2]; c++)
198
0
      {
199
0
        /*
200
0
         * a0, a1, a2, a3, a4, a5,
201
0
         * b6, b7, b8, b9, b10,l11,
202
0
         * c12,c13,c14,c15,c16,c17,
203
0
         * d18,d19,d20,d21,d22,d23,
204
0
         * e24,e25,e26,e27,e28,e29,
205
0
         * f30,f31,f32,f33,f34,f35
206
0
         * {{4 a0 - 5 c12 + e24, 4 a1 - 5 c13 + e25, 4 a2 - 5 c14 + e26, 4 a3 - 5 c15 + e27, 4 a4 - 5 c16 + e28, 4 a5 - 5 c17 + e29},
207
0
         * {-4 b6 - 4 c12 + d18 + e24, -4 b7 - 4 c13 + d19 + e25, -4 b8 - 4 c14 + d20 + e26, -4 b9 - 4 c15 + d21 + e27, -4 b10 - 4 c16 + d22 + e28, -4 b11 - 4 c17 + d23 + e29},
208
0
         * {4 b6 - 4 c12 - d18 + e24, 4 b7 - 4 c13 - d19 + e25, 4 b8 - 4 c14 - d20 + e26, 4 b9 - 4 c15 - d21 + e27, 4 b10 - 4 c16 - d22 + e28, 4 b11 - 4 c17 - d23 + e29},
209
0
         * {-2 b6 - c12 + 2 d18 + e24, -2 b7 - c13 + 2 d19 + e25, -2 b8 - c14 + 2 d20 + e26, -2 b9 - c15 + 2 d21 + e27, -2 b10 - c16 + 2 d22 + e28, -2 b11 - c17 + 2 d23 + e29},
210
0
         * {2 b6 - c12 - 2 d18 + e24, 2 b7 - c13 - 2 d19 + e25, 2 b8 - c14 - 2 d20 + e26, 2 b9 - c15 - 2 d21 + e27, 2 b10 - c16 - 2 d22 + e28, 2 b11 - c17 - 2 d23 + e29},
211
0
         * {4 b6 - 5 d18 + f30, 4 b7 - 5 d19 + f31, 4 b8 - 5 d20 + f32, 4 b9 - 5 d21 + f33, 4 b10 - 5 d22 + f34, 4 b11 - 5 d23 + f35}}
212
0
         */
213
0
        float d[36];
214
0
        /* BT.d */
215
0
        unroll_for(j, 6) {
216
0
          float g0 = g[j * adim[2]];
217
0
          float g12 = g[(12 + j) * adim[2]];
218
0
          float g24 = g[(24 + j) * adim[2]];
219
0
          /* row 1 */
220
0
          d[j] = 4 * g0 - 5 * g12 + g24;
221
0
          float g6 = g[(6 + j) * adim[2]];
222
0
          float g18 = g[(18 + j) * adim[2]];
223
0
          /* row 2 */
224
0
          d[6 + j] = -4 * (g6 + g12) + g18 + g24;
225
0
          /* row 3 */
226
0
          d[12 + j] = 4 * (g6 - g12) - g18 + g24;
227
0
          /* row 4 */
228
0
          d[18 + j] = 2 * (g18 - g6) - g12 + g24;
229
0
          /* row 5 */
230
0
          d[24 + j] = 2 * (g6 - g18) - g12 + g24;
231
0
          float g30 = g[(30 + j) * adim[2]];
232
0
          /* row 6 */
233
0
          d[30 + j] = 4 * g6 - 5 * g18 + g30;
234
0
        } unroll_endfor
235
0
        /*
236
0
         * a0, a1, a2, a3, a4, a5,
237
0
         * b6, b7, b8, b9, b10,l11,
238
0
         * c12,c13,c14,c15,c16,c17,
239
0
         * d18,d19,d20,d21,d22,d23,
240
0
         * e24,e25,e26,e27,e28,e29,
241
0
         * f30,f31,f32,f33,f34,f35
242
0
         * {{4 a0 - 5 a2 + a4, -4 a1 - 4 a2 + a3 + a4, 4 a1 - 4 a2 - a3 + a4, -2 a1 - a2 + 2 a3 + a4, 2 a1 - a2 - 2 a3 + a4, 4 a1 - 5 a3 + a5},
243
0
         * {b10 + 4 b6 - 5 b8, b10 - 4 b7 - 4 b8 + b9, b10 + 4 b7 - 4 b8 - b9, b10 - 2 b7 - b8 + 2 b9, b10 + 2 b7 - b8 - 2 b9, b11 + 4 b7 - 5 b9},
244
0
         * {4 c12 - 5 c14 + c16, -4 c13 - 4 c14 + c15 + c16, 4 c13 - 4 c14 - c15 + c16, -2 c13 - c14 + 2 c15 + c16, 2 c13 - c14 - 2 c15 + c16, 4 c13 - 5 c15 + c17},
245
0
         * {4 d18 - 5 d20 + d22, -4 d19 - 4 d20 + d21 + d22, 4 d19 - 4 d20 - d21 + d22, -2 d19 - d20 + 2 d21 + d22, 2 d19 - d20 - 2 d21 + d22, 4 d19 - 5 d21 + d23},
246
0
         * {4 e24 - 5 e26 + e28, -4 e25 - 4 e26 + e27 + e28, 4 e25 - 4 e26 - e27 + e28, -2 e25 - e26 + 2 e27 + e28, 2 e25 - e26 - 2 e27 + e28, 4 e25 - 5 e27 + e29},
247
0
         * {4 f30 - 5 f32 + f34, -4 f31 - 4 f32 + f33 + f34, 4 f31 - 4 f32 - f33 + f34, -2 f31 - f32 + 2 f33 + f34, 2 f31 - f32 - 2 f33 + f34, 4 f31 - 5 f33 + f35}}
248
0
         */
249
0
        /* BT.d.B */
250
0
        unroll_for(j, 6) {
251
0
          /* row 1 - 6 */
252
0
          float* const gz = g + j * 6 * adim[2];
253
0
          float* const dz = d + j * 6;
254
0
          gz[0] = 4 * dz[0] - 5 * dz[2] + dz[4];
255
0
          gz[adim[2]] = -4 * (dz[1] + dz[2]) + dz[3] + dz[4];
256
0
          gz[2 * adim[2]] = 4 * (dz[1] - dz[2]) - dz[3] + dz[4];
257
0
          gz[3 * adim[2]] = 2 * (dz[3] - dz[1]) - dz[2] + dz[4];
258
0
          gz[4 * adim[2]] = 2 * (dz[1] - dz[3]) - dz[2] + dz[4];
259
0
          gz[5 * adim[2]] = 4 * dz[1] - 5 * dz[3] + dz[5];
260
0
        } unroll_endfor
261
0
        // move to the next channel
262
0
        ++g;
263
0
      }
264
0
      const float* wpz = gwtg;
265
0
      for (k = 0; k < w->info.dim[0]; k++)
266
0
      {
267
0
        float q[36];
268
0
#if FOR_IS_PARALLEL
269
0
        g = btdb + i * 36 * adim[2];
270
0
#else
271
        g = btdb;
272
#endif
273
0
        for (j = 0; j < 36; j++)
274
0
        {
275
0
          float b = 0;
276
0
          for (c = 0; c < adim[2]; c++)
277
0
            b += g[c] * wpz[c];
278
0
          q[j] = b;
279
0
          g += adim[2];
280
0
          wpz += adim[2];
281
0
        }
282
0
        /*
283
0
         * a0, a1, a2, a3, a4, a5,
284
0
         * b6, b7, b8, b9, b10,l11,
285
0
         * c12,c13,c14,c15,c16,c17,
286
0
         * d18,d19,d20,d21,d22,d23,
287
0
         * e24,e25,e26,e27,e28,e29,
288
0
         * f30,f31,f32,f33,f34,f35
289
0
         * {{a0 + b6 + c12 + d18 + e24, a1 + b7 + c13 + d19 + e25, a2 + b8 + c14 + d20 + e26, a3 + b9 + c15 + d21 + e27, a4 + b10 + c16 + d22 + e28, a5 + b11 + c17 + d23 + e29},
290
0
         * {b6 - c12 + 2 d18 - 2 e24, b7 - c13 + 2 d19 - 2 e25, b8 - c14 + 2 d20 - 2 e26, b9 - c15 + 2 d21 - 2 e27, b10 - c16 + 2 d22 - 2 e28, b11 - c17 + 2 d23 - 2 e29},
291
0
         * {b6 + c12 + 4 (d18 + e24), b7 + c13 + 4 (d19 + e25), b8 + c14 + 4 (d20 + e26), b9 + c15 + 4 (d21 + e27), b10 + c16 + 4 (d22 + e28), b11 + c17 + 4 (d23 + e29)},
292
0
         * {b6 - c12 + 8 d18 - 8 e24 + f30, b7 - c13 + 8 d19 - 8 e25 + f31, b8 - c14 + 8 d20 - 8 e26 + f32, b9 - c15 + 8 d21 - 8 e27 + f33, b10 - c16 + 8 d22 - 8 e28 + f34, b11 - c17 + 8 d23 - 8 e29 + f35}}
293
0
         */
294
0
        float d[24];
295
0
        /* row 1 */
296
0
        d[0] = q[0] + q[6] + q[12] + q[18] + q[24];
297
0
        d[1] = q[1] + q[7] + q[13] + q[19] + q[25];
298
0
        d[2] = q[2] + q[8] + q[14] + q[20] + q[26];
299
0
        d[3] = q[3] + q[9] + q[15] + q[21] + q[27];
300
0
        d[4] = q[4] + q[10] + q[16] + q[22] + q[28];
301
0
        d[5] = q[5] + q[11] + q[17] + q[23] + q[29];
302
0
        /* row 2 */
303
0
        d[6] = q[6] - q[12] + 2 * (q[18] - q[24]);
304
0
        d[7] = q[7] - q[13] + 2 * (q[19] - q[25]);
305
0
        d[8] = q[8] - q[14] + 2 * (q[20] - q[26]);
306
0
        d[9] = q[9] - q[15] + 2 * (q[21] - q[27]);
307
0
        d[10] = q[10] - q[16] + 2 * (q[22] - q[28]);
308
0
        d[11] = q[11] - q[17] + 2 * (q[23] - q[29]);
309
0
        /* row 3 */
310
0
        d[12] = q[6] + q[12] + 4 * (q[18] + q[24]);
311
0
        d[13] = q[7] + q[13] + 4 * (q[19] + q[25]);
312
0
        d[14] = q[8] + q[14] + 4 * (q[20] + q[26]);
313
0
        d[15] = q[9] + q[15] + 4 * (q[21] + q[27]);
314
0
        d[16] = q[10] + q[16] + 4 * (q[22] + q[28]);
315
0
        d[17] = q[11] + q[17] + 4 * (q[23] + q[29]);
316
0
        /* row 4 */
317
0
        d[18] = q[6] - q[12] + 8 * (q[18] - q[24]) + q[30];
318
0
        d[19] = q[7] - q[13] + 8 * (q[19] - q[25]) + q[31];
319
0
        d[20] = q[8] - q[14] + 8 * (q[20] - q[26]) + q[32];
320
0
        d[21] = q[9] - q[15] + 8 * (q[21] - q[27]) + q[33];
321
0
        d[22] = q[10] - q[16] + 8 * (q[22] - q[28]) + q[34];
322
0
        d[23] = q[11] - q[17] + 8 * (q[23] - q[29]) + q[35];
323
0
        /*
324
0
         * {{a0 + a1 + a2 + a3 + a4, a1 - a2 + 2 a3 - 2 a4, a1 + a2 + 4 (a3 + a4), a1 - a2 + 8 a3 - 8 a4 + a5},
325
0
         * {b10 + b6 + b7 + b8 + b9, -2 b10 + b7 - b8 + 2 b9, 4 b10 + b7 + b8 + 4 b9, -8 b10 + b11 + b7 - b8 + 8 b9},
326
0
         * {c12 + c13 + c14 + c15 + c16, c13 - c14 + 2 c15 - 2 c16, c13 + c14 + 4 (c15 + c16), c13 - c14 + 8 c15 - 8 c16 + c17},
327
0
         * {d18 + d19 + d20 + d21 + d22, d19 - d20 + 2 d21 - 2 d22, d19 + d20 + 4 (d21 + d22), d19 - d20 + 8 d21 - 8 d22 + d23}}
328
0
         */
329
0
        float* bpz = bp + x * binc[2] + k;
330
0
        unroll_for(dy, z[0], 4) {
331
0
          float r[] = {
332
0
            d[dy * 6 + 0] + d[dy * 6 + 1] + d[dy * 6 + 2] + d[dy * 6 + 3] + d[dy * 6 + 4] + biasval[k],
333
0
            d[dy * 6 + 1] - d[dy * 6 + 2] + 2 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + biasval[k],
334
0
            d[dy * 6 + 1] + d[dy * 6 + 2] + 4 * (d[dy * 6 + 3] + d[dy * 6 + 4]) + biasval[k],
335
0
            d[dy * 6 + 1] - d[dy * 6 + 2] + 8 * (d[dy * 6 + 3] - d[dy * 6 + 4]) + d[dy * 6 + 5] + biasval[k],
336
0
          };
337
0
          unroll_for(dx, z[1], 4) {
338
0
            bpz[dx * binc[2]] = r[dx];
339
0
          } unroll_endfor
340
0
          bpz += binc[1] * binc[2];
341
0
        } unroll_endfor
342
0
      }
343
0
    }
344
0
  } parallel_endfor
345
0
  ccfree(workmem);
346
0
  return CCV_NNC_EXEC_SUCCESS;
347
0
}
348
349
#ifdef HAVE_SSE2
350
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(const float* const w, const int* const dim, float* const gwtg)
351
106
{
352
106
  const int jump_dim = dim[0] / 4;
353
106
  const int dimCx4 = (dim[3] + 3) & -4;
354
106
  parallel_for(k, jump_dim) {
355
0
    int i, j;
356
0
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
357
0
    const float* wz[] = {
358
0
      w + (k * 4) * 9 * dim[3],
359
0
      w + (k * 4 + 1) * 9 * dim[3],
360
0
      w + (k * 4 + 2) * 9 * dim[3],
361
0
      w + (k * 4 + 3) * 9 * dim[3],
362
0
    };
363
2.79M
    for (i = 0; i < dim[3]; i++)
364
2.79M
    {
365
2.79M
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
366
24.6M
      unroll_for(j, 9) {
367
24.6M
        x9w[j * 4] = wz[0][j * dim[3] + i];
368
24.6M
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
369
24.6M
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
370
24.6M
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
371
24.6M
      } unroll_endfor
372
2.79M
      float g[18 * 4] __attribute__ ((__aligned__(16)));
373
2.79M
      __m128 x9w0 = _mm_load_ps(x9w);
374
2.79M
      __m128 x9w1 = _mm_load_ps(x9w + 4);
375
2.79M
      __m128 x9w2 = _mm_load_ps(x9w + 8);
376
2.79M
      __m128 x9w3 = _mm_load_ps(x9w + 12);
377
2.79M
      __m128 x9w4 = _mm_load_ps(x9w + 16);
378
2.79M
      __m128 x9w5 = _mm_load_ps(x9w + 20);
379
2.79M
      __m128 x9w6 = _mm_load_ps(x9w + 24);
380
2.79M
      __m128 x9w7 = _mm_load_ps(x9w + 28);
381
2.79M
      __m128 x9w8 = _mm_load_ps(x9w + 32);
382
2.79M
      /* row 1 */
383
2.79M
      __m128 c1_4 = _mm_set1_ps(1.0 / 4);
384
2.79M
      _mm_store_ps(g, _mm_mul_ps(x9w0, c1_4));
385
2.79M
      _mm_store_ps(g + 4, _mm_mul_ps(x9w1, c1_4));
386
2.79M
      _mm_store_ps(g + 8, _mm_mul_ps(x9w2, c1_4));
387
2.79M
      /* row 2 */
388
2.79M
      __m128 cn1_6 = _mm_set1_ps(-1.0 / 6);
389
2.79M
      _mm_store_ps(g + 12, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
390
2.79M
      _mm_store_ps(g + 16, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
391
2.79M
      _mm_store_ps(g + 20, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
392
2.79M
      /* row 3 */
393
2.79M
      _mm_store_ps(g + 24, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), cn1_6));
394
2.79M
      _mm_store_ps(g + 28, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), cn1_6));
395
2.79M
      _mm_store_ps(g + 32, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), cn1_6));
396
2.79M
      /* row 6 */
397
2.79M
      _mm_store_ps(g + 60, x9w6);
398
2.79M
      _mm_store_ps(g + 64, x9w7);
399
2.79M
      _mm_store_ps(g + 68, x9w8);
400
2.79M
      /* w[x] * 2 */
401
2.79M
      x9w3 = _mm_add_ps(x9w3, x9w3);
402
2.79M
      x9w4 = _mm_add_ps(x9w4, x9w4);
403
2.79M
      x9w5 = _mm_add_ps(x9w5, x9w5);
404
2.79M
      /* w[x] * 4 */
405
2.79M
      x9w6 = _mm_add_ps(x9w6, x9w6);
406
2.79M
      x9w6 = _mm_add_ps(x9w6, x9w6);
407
2.79M
      x9w7 = _mm_add_ps(x9w7, x9w7);
408
2.79M
      x9w7 = _mm_add_ps(x9w7, x9w7);
409
2.79M
      x9w8 = _mm_add_ps(x9w8, x9w8);
410
2.79M
      x9w8 = _mm_add_ps(x9w8, x9w8);
411
2.79M
      /* row 4 */
412
2.79M
      __m128 c1_24 = _mm_set1_ps(1.0 / 24);
413
2.79M
      _mm_store_ps(g + 36, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
414
2.79M
      _mm_store_ps(g + 40, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
415
2.79M
      _mm_store_ps(g + 44, _mm_mul_ps(_mm_add_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
416
2.79M
      /* row 5 */
417
2.79M
      _mm_store_ps(g + 48, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w0, x9w6), x9w3), c1_24));
418
2.79M
      _mm_store_ps(g + 52, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w1, x9w7), x9w4), c1_24));
419
2.79M
      _mm_store_ps(g + 56, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(x9w2, x9w8), x9w5), c1_24));
420
9.05M
      unroll_for(j, 6) {
421
9.05M
        const float* const gz = g + j * 12;
422
9.05M
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
423
9.05M
        __m128 g0 = _mm_load_ps(gz);
424
9.05M
        __m128 g1 = _mm_load_ps(gz + 4);
425
9.05M
        __m128 g2 = _mm_load_ps(gz + 8);
426
9.05M
        _mm_store_ps(gwtgzu, _mm_mul_ps(g0, c1_4));
427
9.05M
        _mm_store_ps(gwtgzu + 4 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), cn1_6));
428
9.05M
        _mm_store_ps(gwtgzu + 8 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), cn1_6));
429
9.05M
        _mm_store_ps(gwtgzu + 20 * dimCx4, g2);
430
9.05M
        /* g[1] * 2 */
431
9.05M
        g1 = _mm_add_ps(g1, g1);
432
9.05M
        /* g[2] * 4 */
433
9.05M
        g2 = _mm_add_ps(g2, g2);
434
9.05M
        g2 = _mm_add_ps(g2, g2);
435
9.05M
        _mm_store_ps(gwtgzu + 12 * dimCx4, _mm_mul_ps(_mm_add_ps(_mm_add_ps(g0, g2), g1), c1_24));
436
9.05M
        _mm_store_ps(gwtgzu + 16 * dimCx4, _mm_mul_ps(_mm_sub_ps(_mm_add_ps(g0, g2), g1), c1_24));
437
9.05M
      } unroll_endfor
438
2.79M
      gwtgz += 4;
439
2.79M
    }
440
106
  } parallel_endfor
441
106
}
442
443
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b)
444
106
{
445
106
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
446
106
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
447
106
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
448
106
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
449
106
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
450
106
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
451
106
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
452
106
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
453
106
  assert(hint.border.begin[0] <= 1);
454
106
  assert(hint.border.begin[1] <= 1);
455
106
  assert(w->info.dim[0] % 4 == 0);
456
106
  assert(w->info.dim[1] == 3);
457
106
  assert(w->info.dim[2] == 3);
458
106
  const int jump_dim = (bdim[0] + 3) / 4;
459
106
  const int dimCx4 = (adim[2] + 3) & -4;
460
106
  // allocating workspace memory for kernel reshaping and input reshaping.
461
106
  float* workmem = 0;
462
106
#if FOR_IS_PARALLEL
463
106
  // If we do parallel for, we need to allocate input reshaping for each block.
464
106
  ccmemalign((void **)&workmem, 16, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]));
465
106
#else
466
  // Otherwise, just one block.
467
  ccmemalign((void **)&workmem, 16, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]));
468
#endif
469
106
  if (!workmem)
470
0
    return CCV_NNC_EXEC_OOM;
471
106
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
472
106
  float* const gwtg = workmem;
473
106
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
474
106
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
475
106
  _ccv_nnc_winograd_4x4_3x3_gwtg_sse2(w->data.f32, w->info.dim, gwtg);
476
106
  // kernel weight for one dim.
477
106
  const float* const biasval = bias->data.f32;
478
106
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
479
106
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
480
106
    w->info.dim[0], 6, 6, w->info.dim[3]
481
106
  };
482
106
  const int* const tile_dim = tile_dim_s;
483
106
  // This block will be called in each for-loop iteration, so you can use it to generate some temporary variables.
484
106
  parallel_for(i, jump_dim) {
485
0
    const int y = i * 4; // i is unsigned.
486
0
    int j, x, k, c;
487
0
    int n[CCV_NNC_MAX_DIM];
488
0
    int m[CCV_NNC_MAX_DIM];
489
0
    int z[CCV_NNC_MAX_DIM];
490
106
    set_n_m_dim(y, 0, tile_dim, adim);
491
106
    z[0] = ccv_min(y + 4, bdim[0]) - y;
492
0
    const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
493
0
    float* bp = b->data.f32 + y * binc[1] * binc[2];
494
70.2k
    for (x = 0; x < bdim[1]; x += 4)
495
71.9k
    {
496
71.9k
      set_n_m_dim(x, 1, tile_dim, adim);
497
71.9k
      z[1] = ccv_min(x + 4, bdim[1]) - x;
498
71.9k
#if FOR_IS_PARALLEL
499
71.9k
      float* g = btdb + i * 36 * dimCx4;
500
71.9k
#else
501
      float* g = btdb;
502
#endif
503
71.9k
      // zero g such that we can have zero-padding.
504
71.9k
      memset(g, 0, sizeof(float) * 36 * dimCx4);
505
71.9k
      int dx, dy;
506
71.9k
      const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
507
71.9k
      float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
508
428k
      unroll_for(dy, m[0], 6) {
509
2.53M
        unroll_for(dx, m[1], 6) {
510
2.53M
          float* const gzu = gz + (dy * 6 + dx) * dimCx4;
511
154M
          for (c = 0; c < adim[2]; c++)
512
151M
            gzu[c] = apz[dx * ainc[2] + c];
513
2.53M
        } unroll_endfor
514
428k
        apz += ainc[1] * ainc[2];
515
428k
      } unroll_endfor
516
1.23M
      for (c = 0; c < adim[2]; c += 4)
517
1.16M
      {
518
1.16M
        float d[36 * 4]  __attribute__ ((__aligned__(16)));
519
1.16M
        /* BT.d */
520
6.91M
        unroll_for(j, 6) {
521
6.91M
          /* row 1 */
522
6.91M
          const float* const gz = g + j * dimCx4;
523
6.91M
          float* dz = d + j * 4;
524
6.91M
          __m128 g0 = _mm_load_ps(gz);
525
6.91M
          __m128 g12 = _mm_load_ps(gz + 12 * dimCx4);
526
6.91M
          __m128 g18 = _mm_load_ps(gz + 18 * dimCx4);
527
6.91M
          __m128 g24 = _mm_load_ps(gz + 24 * dimCx4);
528
6.91M
          g0 = _mm_add_ps(g0, g0);
529
6.91M
          g0 = _mm_add_ps(g0, g0);
530
6.91M
          __m128 g12x2 = _mm_add_ps(g12, g12);
531
6.91M
          g12x2 = _mm_add_ps(g12x2, g12x2);
532
6.91M
          g12x2 = _mm_add_ps(g12x2, g12);
533
6.91M
          _mm_store_ps(dz, _mm_sub_ps(_mm_add_ps(g0, g24), g12x2));
534
6.91M
          /* row 2 */
535
6.91M
          __m128 g6 = _mm_load_ps(gz + 6 * dimCx4);
536
6.91M
          __m128 g6x12 = _mm_add_ps(g6, g12);
537
6.91M
          g6x12 = _mm_add_ps(g6x12, g6x12);
538
6.91M
          g6x12 = _mm_add_ps(g6x12, g6x12);
539
6.91M
          _mm_store_ps(dz + 24, _mm_sub_ps(_mm_add_ps(g18, g24), g6x12));
540
6.91M
          /* row 3 */
541
6.91M
          g6x12 = _mm_sub_ps(g6, g12);
542
6.91M
          g6x12 = _mm_add_ps(g6x12, g6x12);
543
6.91M
          g6x12 = _mm_add_ps(g6x12, g6x12);
544
6.91M
          _mm_store_ps(dz + 48, _mm_add_ps(_mm_sub_ps(g24, g18), g6x12));
545
6.91M
          /* row 4 */
546
6.91M
          __m128 g18x6 = _mm_sub_ps(g18, g6);
547
6.91M
          g18x6 = _mm_add_ps(g18x6, g18x6);
548
6.91M
          _mm_store_ps(dz + 72, _mm_add_ps(_mm_sub_ps(g24, g12), g18x6));
549
6.91M
          /* row 5 */
550
6.91M
          _mm_store_ps(dz + 96, _mm_sub_ps(_mm_sub_ps(g24, g12), g18x6));
551
6.91M
          /* row 6 */
552
6.91M
          __m128 g30 = _mm_load_ps(gz + 30 * dimCx4);
553
6.91M
          __m128 g18x2 = _mm_add_ps(g18, g18);
554
6.91M
          g18x2 = _mm_add_ps(g18x2, g18x2);
555
6.91M
          g18x2 = _mm_add_ps(g18, g18x2);
556
6.91M
          g6 = _mm_add_ps(g6, g6);
557
6.91M
          g6 = _mm_add_ps(g6, g6);
558
6.91M
          _mm_store_ps(dz + 120, _mm_sub_ps(_mm_add_ps(g6, g30), g18x2));
559
6.91M
        } unroll_endfor
560
1.16M
        /* BT.d.B */
561
6.92M
        unroll_for(j, 6) {
562
6.92M
          float* gz = g + j * 6 * dimCx4;
563
6.92M
          const float* const dz = d + j * 24;
564
6.92M
          __m128 d0 = _mm_load_ps(dz);
565
6.92M
          __m128 d1 = _mm_load_ps(dz + 4);
566
6.92M
          __m128 d2 = _mm_load_ps(dz + 8);
567
6.92M
          __m128 d3 = _mm_load_ps(dz + 12);
568
6.92M
          __m128 d4 = _mm_load_ps(dz + 16);
569
6.92M
          __m128 d5 = _mm_load_ps(dz + 20);
570
6.92M
          d0 = _mm_add_ps(d0, d0);
571
6.92M
          d0 = _mm_add_ps(d0, d0);
572
6.92M
          __m128 d2x5 = _mm_add_ps(d2, d2);
573
6.92M
          d2x5 = _mm_add_ps(d2x5, d2x5);
574
6.92M
          d2x5 = _mm_add_ps(d2, d2x5);
575
6.92M
          _mm_store_ps(gz, _mm_sub_ps(_mm_add_ps(d0, d4), d2x5));
576
6.92M
          __m128 d1x2 = _mm_add_ps(d1, d2);
577
6.92M
          d1x2 = _mm_add_ps(d1x2, d1x2);
578
6.92M
          d1x2 = _mm_add_ps(d1x2, d1x2);
579
6.92M
          _mm_store_ps(gz + dimCx4, _mm_sub_ps(_mm_add_ps(d3, d4), d1x2));
580
6.92M
          d1x2 = _mm_sub_ps(d1, d2);
581
6.92M
          d1x2 = _mm_add_ps(d1x2, d1x2);
582
6.92M
          d1x2 = _mm_add_ps(d1x2, d1x2);
583
6.92M
          _mm_store_ps(gz + 2 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d3), d1x2));
584
6.92M
          __m128 d3x1 = _mm_sub_ps(d3, d1);
585
6.92M
          d3x1 = _mm_add_ps(d3x1, d3x1);
586
6.92M
          _mm_store_ps(gz + 3 * dimCx4, _mm_add_ps(_mm_sub_ps(d4, d2), d3x1));
587
6.92M
          _mm_store_ps(gz + 4 * dimCx4, _mm_sub_ps(_mm_sub_ps(d4, d2), d3x1));
588
6.92M
          d1 = _mm_add_ps(d1, d1);
589
6.92M
          d1 = _mm_add_ps(d1, d1);
590
6.92M
          __m128 d3x5 = _mm_add_ps(d3, d3);
591
6.92M
          d3x5 = _mm_add_ps(d3x5, d3x5);
592
6.92M
          d3x5 = _mm_add_ps(d3, d3x5);
593
6.92M
          _mm_store_ps(gz + 5 * dimCx4, _mm_sub_ps(_mm_add_ps(d1, d5), d3x5));
594
6.92M
        } unroll_endfor
595
1.16M
        // move to the next channel
596
1.16M
        g += 4;
597
1.16M
      }
598
71.9k
      const float* wpz = gwtg;
599
1.87M
      for (k = 0; k < w->info.dim[0]; k += 4)
600
1.80M
      {
601
1.80M
        float q[36 * 4] __attribute__ ((__aligned__(16)));
602
1.80M
#if FOR_IS_PARALLEL
603
1.80M
        g = btdb + i * 36 * dimCx4;
604
1.80M
#else
605
        g = btdb;
606
#endif
607
53.1M
        for (j = 0; j < 36; j++)
608
51.3M
        {
609
51.3M
          __m128 v40 = _mm_setzero_ps();
610
51.3M
          __m128 v41 = _mm_setzero_ps();
611
51.3M
          __m128 v42 = _mm_setzero_ps();
612
51.3M
          __m128 v43 = _mm_setzero_ps();
613
539M
          for (c = 0; c < adim[2]; c += 4)
614
488M
          {
615
488M
            __m128 g4 = _mm_load_ps(g);
616
488M
            __m128 w40 = _mm_load_ps(wpz);
617
488M
            __m128 w41 = _mm_load_ps(wpz + 4);
618
488M
            __m128 w42 = _mm_load_ps(wpz + 8);
619
488M
            __m128 w43 = _mm_load_ps(wpz + 12);
620
488M
            __m128 g40 = _mm_shuffle_ps(g4, g4, 0x00);
621
488M
            __m128 g41 = _mm_shuffle_ps(g4, g4, 0x55);
622
488M
            __m128 g42 = _mm_shuffle_ps(g4, g4, 0xAA);
623
488M
            __m128 g43 = _mm_shuffle_ps(g4, g4, 0xFF);
624
488M
            v40 = _mm_add_ps(_mm_mul_ps(w40, g40), v40);
625
488M
            v41 = _mm_add_ps(_mm_mul_ps(w41, g41), v41);
626
488M
            v42 = _mm_add_ps(_mm_mul_ps(w42, g42), v42);
627
488M
            v43 = _mm_add_ps(_mm_mul_ps(w43, g43), v43);
628
488M
            g += 4;
629
488M
            wpz += 16;
630
488M
          }
631
51.3M
          v40 = _mm_add_ps(v40, v41);
632
51.3M
          v42 = _mm_add_ps(v42, v43);
633
51.3M
          _mm_store_ps(q + j * 4, _mm_add_ps(v40, v42));
634
51.3M
        }
635
1.80M
        float d[24 * 4] __attribute__ ((__aligned__(16)));
636
10.3M
        unroll_for(j, 6) {
637
10.3M
          const float* const qz = q + j * 4;
638
10.3M
          float* const dz = d + j * 4;
639
10.3M
          __m128 q0 = _mm_load_ps(qz);
640
10.3M
          __m128 q6 = _mm_load_ps(qz + 24);
641
10.3M
          __m128 q12 = _mm_load_ps(qz + 48);
642
10.3M
          __m128 q18 = _mm_load_ps(qz + 72);
643
10.3M
          __m128 q24 = _mm_load_ps(qz + 96);
644
10.3M
          __m128 qs6x12 = _mm_add_ps(q6, q12);
645
10.3M
          __m128 qs18x24 = _mm_add_ps(q18, q24);
646
10.3M
          __m128 qss = _mm_add_ps(qs6x12, q0);
647
10.3M
          /* row 1 */
648
10.3M
          _mm_store_ps(dz, _mm_add_ps(qss, qs18x24));
649
10.3M
          __m128 qn6x12 = _mm_sub_ps(q6, q12);
650
10.3M
          __m128 qn18x24 = _mm_sub_ps(q18, q24);
651
10.3M
          qn18x24 = _mm_add_ps(qn18x24, qn18x24);
652
10.3M
          /* row 2 */
653
10.3M
          _mm_store_ps(dz + 24, _mm_add_ps(qn6x12, qn18x24));
654
10.3M
          qs18x24 = _mm_add_ps(qs18x24, qs18x24);
655
10.3M
          qs18x24 = _mm_add_ps(qs18x24, qs18x24);
656
10.3M
          /* row 3 */
657
10.3M
          _mm_store_ps(dz + 48, _mm_add_ps(qs6x12, qs18x24));
658
10.3M
          qn18x24 = _mm_add_ps(qn18x24, qn18x24);
659
10.3M
          qn18x24 = _mm_add_ps(qn18x24, qn18x24);
660
10.3M
          __m128 q30 = _mm_load_ps(qz + 120);
661
10.3M
          /* row 4 */
662
10.3M
          _mm_store_ps(dz + 72, _mm_add_ps(_mm_add_ps(qn6x12, q30), qn18x24));
663
10.3M
        } unroll_endfor
664
1.80M
        float* bpz = bp + x * binc[2] + k;
665
1.80M
        __m128 bias4 = _mm_loadu_ps(biasval + k);
666
1.80M
        switch (z[1]) {
667
11.7k
          case 1:
668
38.2k
            unroll_for(dy, z[0], 4) {
669
38.2k
              const float* const dz = d + dy * 24;
670
38.2k
              __m128 d0 = _mm_load_ps(dz);
671
38.2k
              __m128 d1 = _mm_load_ps(dz + 4);
672
38.2k
              __m128 d2 = _mm_load_ps(dz + 8);
673
38.2k
              __m128 d3 = _mm_load_ps(dz + 12);
674
38.2k
              __m128 d4 = _mm_load_ps(dz + 16);
675
38.2k
              __m128 ds1x2 = _mm_add_ps(d1, d2);
676
38.2k
              __m128 ds3x4 = _mm_add_ps(d3, d4);
677
38.2k
              ds1x2 = _mm_add_ps(ds1x2, bias4);
678
38.2k
              _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
679
38.2k
              bpz += binc[1] * binc[2];
680
38.2k
            } unroll_endfor
681
11.7k
            break;
682
0
          case 2:
683
0
            unroll_for(dy, z[0], 4) {
684
0
              const float* const dz = d + dy * 24;
685
0
              __m128 d0 = _mm_load_ps(dz);
686
0
              __m128 d1 = _mm_load_ps(dz + 4);
687
0
              __m128 d2 = _mm_load_ps(dz + 8);
688
0
              __m128 d3 = _mm_load_ps(dz + 12);
689
0
              __m128 d4 = _mm_load_ps(dz + 16);
690
0
              __m128 ds1x2 = _mm_add_ps(d1, d2);
691
0
              __m128 ds3x4 = _mm_add_ps(d3, d4);
692
0
              ds1x2 = _mm_add_ps(ds1x2, bias4);
693
0
              _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
694
0
              __m128 dn1x2 = _mm_sub_ps(d1, d2);
695
0
              __m128 dn3x4 = _mm_sub_ps(d3, d4);
696
0
              dn3x4 = _mm_add_ps(dn3x4, dn3x4);
697
0
              dn1x2 = _mm_add_ps(dn1x2, bias4);
698
0
              _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
699
0
              bpz += binc[1] * binc[2];
700
0
            } unroll_endfor
701
0
            break;
702
72.0k
          case 3:
703
281k
            unroll_for(dy, z[0], 4) {
704
281k
              const float* const dz = d + dy * 24;
705
281k
              __m128 d0 = _mm_load_ps(dz);
706
281k
              __m128 d1 = _mm_load_ps(dz + 4);
707
281k
              __m128 d2 = _mm_load_ps(dz + 8);
708
281k
              __m128 d3 = _mm_load_ps(dz + 12);
709
281k
              __m128 d4 = _mm_load_ps(dz + 16);
710
281k
              __m128 ds1x2 = _mm_add_ps(d1, d2);
711
281k
              __m128 ds3x4 = _mm_add_ps(d3, d4);
712
281k
              ds1x2 = _mm_add_ps(ds1x2, bias4);
713
281k
              _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
714
281k
              __m128 dn1x2 = _mm_sub_ps(d1, d2);
715
281k
              __m128 dn3x4 = _mm_sub_ps(d3, d4);
716
281k
              dn3x4 = _mm_add_ps(dn3x4, dn3x4);
717
281k
              dn1x2 = _mm_add_ps(dn1x2, bias4);
718
281k
              _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
719
281k
              ds3x4 = _mm_add_ps(ds3x4, ds3x4);
720
281k
              ds3x4 = _mm_add_ps(ds3x4, ds3x4);
721
281k
              _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
722
281k
              bpz += binc[1] * binc[2];
723
281k
            } unroll_endfor
724
72.0k
            break;
725
1.72M
          case 4:
726
6.59M
            unroll_for(dy, z[0], 4) {
727
6.59M
              const float* const dz = d + dy * 24;
728
6.59M
              __m128 d0 = _mm_load_ps(dz);
729
6.59M
              __m128 d1 = _mm_load_ps(dz + 4);
730
6.59M
              __m128 d2 = _mm_load_ps(dz + 8);
731
6.59M
              __m128 d3 = _mm_load_ps(dz + 12);
732
6.59M
              __m128 d4 = _mm_load_ps(dz + 16);
733
6.59M
              __m128 ds1x2 = _mm_add_ps(d1, d2);
734
6.59M
              __m128 ds3x4 = _mm_add_ps(d3, d4);
735
6.59M
              ds1x2 = _mm_add_ps(ds1x2, bias4);
736
6.59M
              _mm_stream_ps(bpz, _mm_add_ps(ds1x2, _mm_add_ps(d0, ds3x4)));
737
6.59M
              __m128 dn1x2 = _mm_sub_ps(d1, d2);
738
6.59M
              __m128 dn3x4 = _mm_sub_ps(d3, d4);
739
6.59M
              dn3x4 = _mm_add_ps(dn3x4, dn3x4);
740
6.59M
              dn1x2 = _mm_add_ps(dn1x2, bias4);
741
6.59M
              _mm_stream_ps(bpz + binc[2], _mm_add_ps(dn1x2, dn3x4));
742
6.59M
              ds3x4 = _mm_add_ps(ds3x4, ds3x4);
743
6.59M
              ds3x4 = _mm_add_ps(ds3x4, ds3x4);
744
6.59M
              _mm_stream_ps(bpz + 2 * binc[2], _mm_add_ps(ds1x2, ds3x4));
745
6.59M
              __m128 d5 = _mm_load_ps(dz + 20);
746
6.59M
              dn3x4 = _mm_add_ps(dn3x4, dn3x4);
747
6.59M
              dn3x4 = _mm_add_ps(dn3x4, dn3x4);
748
6.59M
              _mm_stream_ps(bpz + 3 * binc[2], _mm_add_ps(_mm_add_ps(dn1x2, d5), dn3x4));
749
6.59M
              bpz += binc[1] * binc[2];
750
6.59M
            } unroll_endfor
751
1.72M
            break;
752
1.80M
        };
753
1.80M
      }
754
71.9k
    }
755
106
  } parallel_endfor
756
106
  ccfree(workmem);
757
106
  return CCV_NNC_EXEC_SUCCESS;
758
106
}
759
#endif
760
761
#ifdef HAVE_NEON
762
inline static void _ccv_nnc_winograd_4x4_3x3_gwtg_neon(const float* const w, const int* const dim, float* const gwtg)
763
{
764
  const int jump_dim = dim[0] / 4;
765
  const int dimCx4 = (dim[3] + 3) & -4;
766
  parallel_for(k, jump_dim) {
767
    int i, j;
768
    float* gwtgz = gwtg + k * 4 * 36 * dimCx4;
769
    const float* wz[] = {
770
      w + (k * 4) * 9 * dim[3],
771
      w + (k * 4 + 1) * 9 * dim[3],
772
      w + (k * 4 + 2) * 9 * dim[3],
773
      w + (k * 4 + 3) * 9 * dim[3],
774
    };
775
    for (i = 0; i < dim[3]; i++)
776
    {
777
      float x9w[9 * 4] __attribute__ ((__aligned__(16)));
778
      unroll_for(j, 9) {
779
        x9w[j * 4] = wz[0][j * dim[3] + i];
780
        x9w[j * 4 + 1] = wz[1][j * dim[3] + i];
781
        x9w[j * 4 + 2] = wz[2][j * dim[3] + i];
782
        x9w[j * 4 + 3] = wz[3][j * dim[3] + i];
783
      } unroll_endfor
784
      float g[18 * 4] __attribute__ ((__aligned__(16)));
785
      float32x4_t x9w0 = vld1q_f32(x9w);
786
      float32x4_t x9w1 = vld1q_f32(x9w + 4);
787
      float32x4_t x9w2 = vld1q_f32(x9w + 8);
788
      float32x4_t x9w3 = vld1q_f32(x9w + 12);
789
      float32x4_t x9w4 = vld1q_f32(x9w + 16);
790
      float32x4_t x9w5 = vld1q_f32(x9w + 20);
791
      float32x4_t x9w6 = vld1q_f32(x9w + 24);
792
      float32x4_t x9w7 = vld1q_f32(x9w + 28);
793
      float32x4_t x9w8 = vld1q_f32(x9w + 32);
794
      /* row 1 */
795
      float32x4_t c1_4 = vdupq_n_f32(1.0 / 4);
796
      vst1q_f32(g, vmulq_f32(x9w0, c1_4));
797
      vst1q_f32(g + 4, vmulq_f32(x9w1, c1_4));
798
      vst1q_f32(g + 8, vmulq_f32(x9w2, c1_4));
799
      /* row 2 */
800
      float32x4_t cn1_6 = vdupq_n_f32(-1.0 / 6);
801
      vst1q_f32(g + 12, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
802
      vst1q_f32(g + 16, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
803
      vst1q_f32(g + 20, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
804
      /* row 3 */
805
      vst1q_f32(g + 24, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), cn1_6));
806
      vst1q_f32(g + 28, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), cn1_6));
807
      vst1q_f32(g + 32, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), cn1_6));
808
      /* row 6 */
809
      vst1q_f32(g + 60, x9w6);
810
      vst1q_f32(g + 64, x9w7);
811
      vst1q_f32(g + 68, x9w8);
812
      /* w[x] * 2 */
813
      x9w3 = vaddq_f32(x9w3, x9w3);
814
      x9w4 = vaddq_f32(x9w4, x9w4);
815
      x9w5 = vaddq_f32(x9w5, x9w5);
816
      /* w[x] * 4 */
817
      x9w6 = vaddq_f32(x9w6, x9w6);
818
      x9w6 = vaddq_f32(x9w6, x9w6);
819
      x9w7 = vaddq_f32(x9w7, x9w7);
820
      x9w7 = vaddq_f32(x9w7, x9w7);
821
      x9w8 = vaddq_f32(x9w8, x9w8);
822
      x9w8 = vaddq_f32(x9w8, x9w8);
823
      /* row 4 */
824
      float32x4_t c1_24 = vdupq_n_f32(1.0 / 24);
825
      vst1q_f32(g + 36, vmulq_f32(vaddq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
826
      vst1q_f32(g + 40, vmulq_f32(vaddq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
827
      vst1q_f32(g + 44, vmulq_f32(vaddq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
828
      /* row 5 */
829
      vst1q_f32(g + 48, vmulq_f32(vsubq_f32(vaddq_f32(x9w0, x9w6), x9w3), c1_24));
830
      vst1q_f32(g + 52, vmulq_f32(vsubq_f32(vaddq_f32(x9w1, x9w7), x9w4), c1_24));
831
      vst1q_f32(g + 56, vmulq_f32(vsubq_f32(vaddq_f32(x9w2, x9w8), x9w5), c1_24));
832
      unroll_for(j, 6) {
833
        const float* const gz = g + j * 12;
834
        float* const gwtgzu = gwtgz + j * 24 * dimCx4;
835
        float32x4_t g0 = vld1q_f32(gz);
836
        float32x4_t g1 = vld1q_f32(gz + 4);
837
        float32x4_t g2 = vld1q_f32(gz + 8);
838
        vst1q_f32(gwtgzu, vmulq_f32(g0, c1_4));
839
        vst1q_f32(gwtgzu + 4 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), cn1_6));
840
        vst1q_f32(gwtgzu + 8 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), cn1_6));
841
        vst1q_f32(gwtgzu + 20 * dimCx4, g2);
842
        /* g[1] * 2 */
843
        g1 = vaddq_f32(g1, g1);
844
        /* g[2] * 4 */
845
        g2 = vaddq_f32(g2, g2);
846
        g2 = vaddq_f32(g2, g2);
847
        vst1q_f32(gwtgzu + 12 * dimCx4, vmulq_f32(vaddq_f32(vaddq_f32(g0, g2), g1), c1_24));
848
        vst1q_f32(gwtgzu + 16 * dimCx4, vmulq_f32(vsubq_f32(vaddq_f32(g0, g2), g1), c1_24));
849
      } unroll_endfor
850
      gwtgz += 4;
851
    }
852
  } parallel_endfor
853
}
854
855
static int _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b)
856
{
857
  const int a_nd = ccv_nnc_tensor_nd(a->info.dim);
858
  assert(a_nd == CCV_NNC_MAX_DIM + 1 || a_nd == CCV_NNC_MAX_DIM + 2);
859
  const int* adim = (a_nd == CCV_NNC_MAX_DIM + 1) ? a->info.dim : a->info.dim + 1;
860
  const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
861
  assert(b_nd == CCV_NNC_MAX_DIM + 1 || b_nd == CCV_NNC_MAX_DIM + 2);
862
  const int* bdim = (b_nd == CCV_NNC_MAX_DIM + 1) ? b->info.dim : b->info.dim + 1;
863
  const int* ainc = CCV_IS_TENSOR_VIEW(a) ? ((a_nd == CCV_NNC_MAX_DIM + 1) ? a->inc : a->inc + 1) : adim;
864
  const int* binc = CCV_IS_TENSOR_VIEW(b) ? ((b_nd == CCV_NNC_MAX_DIM + 1) ? b->inc : b->inc + 1) : bdim;
865
  assert(hint.border.begin[0] <= 1);
866
  assert(hint.border.begin[1] <= 1);
867
  assert(w->info.dim[0] % 4 == 0);
868
  assert(w->info.dim[1] == 3);
869
  assert(w->info.dim[2] == 3);
870
  const int jump_dim = (bdim[0] + 3) / 4;
871
  const int dimCx4 = (adim[2] + 3) & -4;
872
  // allocating workspace memory for kernel reshaping and input reshaping.
873
  float* workmem = 0;
874
#if FOR_IS_PARALLEL
875
  // If we do parallel for, we need to allocate input reshaping for each block.
876
  ccmemalign((void **)&workmem, 16, sizeof(float) * (36 * dimCx4 * jump_dim + 36 * dimCx4 * w->info.dim[0]));
877
#else
878
  // Otherwise, just one block.
879
  ccmemalign((void **)&workmem, 16, sizeof(float) * (36 * dimCx4 + 36 * dimCx4 * w->info.dim[0]));
880
#endif
881
  if (!workmem)
882
    return CCV_NNC_EXEC_OOM;
883
  // Convert w to a 6x6 matrix, by computing G.w.T(G) // T for transpose.
884
  float* const gwtg = workmem;
885
  float* const btdb = workmem + 36 * dimCx4 * w->info.dim[0];
886
  memset(gwtg, 0, sizeof(float) * 36 * dimCx4 * w->info.dim[0]);
887
  _ccv_nnc_winograd_4x4_3x3_gwtg_neon(w->data.f32, w->info.dim, gwtg);
888
  // kernel weight for one dim.
889
  const float* const biasval = bias->data.f32;
890
  // Workaround issues of dispatch_apply (cannot reference an on-stack array)
891
  const int tile_dim_s[CCV_NNC_MAX_DIM_ALLOC] = {
892
    w->info.dim[0], 6, 6, w->info.dim[3]
893
  };
894
  const int* const tile_dim = tile_dim_s;
895
  // This block will be called in each for-loop iteration, so you can use it to generate some temporary variables.
896
  parallel_for(i, jump_dim) {
897
    const int y = i * 4; // i is unsigned.
898
    int j, x, k, c;
899
    int n[CCV_NNC_MAX_DIM];
900
    int m[CCV_NNC_MAX_DIM];
901
    int z[CCV_NNC_MAX_DIM];
902
    set_n_m_dim(y, 0, tile_dim, adim);
903
    z[0] = ccv_min(y + 4, bdim[0]) - y;
904
    const float* ap = a->data.f32 + ccv_max(y - hint.border.begin[0], 0) * ainc[1] * ainc[2];
905
    float* bp = b->data.f32 + y * binc[1] * binc[2];
906
    for (x = 0; x < bdim[1]; x += 4)
907
    {
908
      set_n_m_dim(x, 1, tile_dim, adim);
909
      z[1] = ccv_min(x + 4, bdim[1]) - x;
910
#if FOR_IS_PARALLEL
911
      float* g = btdb + i * 36 * dimCx4;
912
#else
913
      float* g = btdb;
914
#endif
915
      // zero g such that we can have zero-padding.
916
      memset(g, 0, sizeof(float) * 36 * dimCx4);
917
      int dx, dy;
918
      const float* apz = ap + ccv_max(x - hint.border.begin[1], 0) * ainc[2];
919
      float* gz = g + (n[0] * 6 + n[1]) * dimCx4;
920
      unroll_for(dy, m[0], 6) {
921
        unroll_for(dx, m[1], 6) {
922
          float* const gzu = gz + (dy * 6 + dx) * dimCx4;
923
          for (c = 0; c < adim[2]; c++)
924
            gzu[c] = apz[dx * ainc[2] + c];
925
        } unroll_endfor
926
        apz += ainc[1] * ainc[2];
927
      } unroll_endfor
928
      for (c = 0; c < adim[2]; c += 4)
929
      {
930
        float d[36 * 4]  __attribute__ ((__aligned__(16)));
931
        /* BT.d */
932
        unroll_for(j, 6) {
933
          /* row 1 */
934
          const float* const gz = g + j * dimCx4;
935
          float* dz = d + j * 4;
936
          float32x4_t g0 = vld1q_f32(gz);
937
          float32x4_t g12 = vld1q_f32(gz + 12 * dimCx4);
938
          float32x4_t g18 = vld1q_f32(gz + 18 * dimCx4);
939
          float32x4_t g24 = vld1q_f32(gz + 24 * dimCx4);
940
          g0 = vaddq_f32(g0, g0);
941
          g0 = vaddq_f32(g0, g0);
942
          float32x4_t g12x2 = vaddq_f32(g12, g12);
943
          g12x2 = vaddq_f32(g12x2, g12x2);
944
          g12x2 = vaddq_f32(g12x2, g12);
945
          vst1q_f32(dz, vsubq_f32(vaddq_f32(g0, g24), g12x2));
946
          /* row 2 */
947
          float32x4_t g6 = vld1q_f32(gz + 6 * dimCx4);
948
          float32x4_t g6x12 = vaddq_f32(g6, g12);
949
          g6x12 = vaddq_f32(g6x12, g6x12);
950
          g6x12 = vaddq_f32(g6x12, g6x12);
951
          vst1q_f32(dz + 24, vsubq_f32(vaddq_f32(g18, g24), g6x12));
952
          /* row 3 */
953
          g6x12 = vsubq_f32(g6, g12);
954
          g6x12 = vaddq_f32(g6x12, g6x12);
955
          g6x12 = vaddq_f32(g6x12, g6x12);
956
          vst1q_f32(dz + 48, vaddq_f32(vsubq_f32(g24, g18), g6x12));
957
          /* row 4 */
958
          float32x4_t g18x6 = vsubq_f32(g18, g6);
959
          g18x6 = vaddq_f32(g18x6, g18x6);
960
          vst1q_f32(dz + 72, vaddq_f32(vsubq_f32(g24, g12), g18x6));
961
          /* row 5 */
962
          vst1q_f32(dz + 96, vsubq_f32(vsubq_f32(g24, g12), g18x6));
963
          /* row 6 */
964
          float32x4_t g30 = vld1q_f32(gz + 30 * dimCx4);
965
          float32x4_t g18x2 = vaddq_f32(g18, g18);
966
          g18x2 = vaddq_f32(g18x2, g18x2);
967
          g18x2 = vaddq_f32(g18, g18x2);
968
          g6 = vaddq_f32(g6, g6);
969
          g6 = vaddq_f32(g6, g6);
970
          vst1q_f32(dz + 120, vsubq_f32(vaddq_f32(g6, g30), g18x2));
971
        } unroll_endfor
972
        /* BT.d.B */
973
        unroll_for(j, 6) {
974
          float* gz = g + j * 6 * dimCx4;
975
          const float* const dz = d + j * 24;
976
          float32x4_t d0 = vld1q_f32(dz);
977
          float32x4_t d1 = vld1q_f32(dz + 4);
978
          float32x4_t d2 = vld1q_f32(dz + 8);
979
          float32x4_t d3 = vld1q_f32(dz + 12);
980
          float32x4_t d4 = vld1q_f32(dz + 16);
981
          float32x4_t d5 = vld1q_f32(dz + 20);
982
          d0 = vaddq_f32(d0, d0);
983
          d0 = vaddq_f32(d0, d0);
984
          float32x4_t d2x5 = vaddq_f32(d2, d2);
985
          d2x5 = vaddq_f32(d2x5, d2x5);
986
          d2x5 = vaddq_f32(d2, d2x5);
987
          vst1q_f32(gz, vsubq_f32(vaddq_f32(d0, d4), d2x5));
988
          float32x4_t d1x2 = vaddq_f32(d1, d2);
989
          d1x2 = vaddq_f32(d1x2, d1x2);
990
          d1x2 = vaddq_f32(d1x2, d1x2);
991
          vst1q_f32(gz + dimCx4, vsubq_f32(vaddq_f32(d3, d4), d1x2));
992
          d1x2 = vsubq_f32(d1, d2);
993
          d1x2 = vaddq_f32(d1x2, d1x2);
994
          d1x2 = vaddq_f32(d1x2, d1x2);
995
          vst1q_f32(gz + 2 * dimCx4, vaddq_f32(vsubq_f32(d4, d3), d1x2));
996
          float32x4_t d3x1 = vsubq_f32(d3, d1);
997
          d3x1 = vaddq_f32(d3x1, d3x1);
998
          vst1q_f32(gz + 3 * dimCx4, vaddq_f32(vsubq_f32(d4, d2), d3x1));
999
          vst1q_f32(gz + 4 * dimCx4, vsubq_f32(vsubq_f32(d4, d2), d3x1));
1000
          d1 = vaddq_f32(d1, d1);
1001
          d1 = vaddq_f32(d1, d1);
1002
          float32x4_t d3x5 = vaddq_f32(d3, d3);
1003
          d3x5 = vaddq_f32(d3x5, d3x5);
1004
          d3x5 = vaddq_f32(d3, d3x5);
1005
          vst1q_f32(gz + 5 * dimCx4, vsubq_f32(vaddq_f32(d1, d5), d3x5));
1006
        } unroll_endfor
1007
        // move to the next channel
1008
        g += 4;
1009
      }
1010
      const float* wpz = gwtg;
1011
      for (k = 0; k < w->info.dim[0]; k += 4)
1012
      {
1013
        float q[36 * 4] __attribute__ ((__aligned__(16)));
1014
#if FOR_IS_PARALLEL
1015
        g = btdb + i * 36 * dimCx4;
1016
#else
1017
        g = btdb;
1018
#endif
1019
        for (j = 0; j < 36; j++)
1020
        {
1021
          float32x4_t v40 = vmovq_n_f32(0);
1022
          float32x4_t v41 = vmovq_n_f32(0);
1023
          float32x4_t v42 = vmovq_n_f32(0);
1024
          float32x4_t v43 = vmovq_n_f32(0);
1025
          for (c = 0; c < adim[2]; c += 4)
1026
          {
1027
            float32x2x2_t g4 = vld2_f32(g);
1028
            float32x4_t w40 = vld1q_f32(wpz);
1029
            float32x4_t w41 = vld1q_f32(wpz + 4);
1030
            float32x4_t w42 = vld1q_f32(wpz + 8);
1031
            float32x4_t w43 = vld1q_f32(wpz + 12);
1032
            float32x4_t g40 = vdupq_lane_f32(g4.val[0], 0);
1033
            float32x4_t g41 = vdupq_lane_f32(g4.val[1], 0);
1034
            float32x4_t g42 = vdupq_lane_f32(g4.val[0], 1);
1035
            float32x4_t g43 = vdupq_lane_f32(g4.val[1], 1);
1036
            v40 = vmlaq_f32(v40, w40, g40);
1037
            v41 = vmlaq_f32(v41, w41, g41);
1038
            v42 = vmlaq_f32(v42, w42, g42);
1039
            v43 = vmlaq_f32(v43, w43, g43);
1040
            g += 4;
1041
            wpz += 16;
1042
          }
1043
          v40 = vaddq_f32(v40, v41);
1044
          v42 = vaddq_f32(v42, v43);
1045
          vst1q_f32(q + j * 4, vaddq_f32(v40, v42));
1046
        }
1047
        float d[24 * 4] __attribute__ ((__aligned__(16)));
1048
        unroll_for(j, 6) {
1049
          const float* const qz = q + j * 4;
1050
          float* const dz = d + j * 4;
1051
          float32x4_t q0 = vld1q_f32(qz);
1052
          float32x4_t q6 = vld1q_f32(qz + 24);
1053
          float32x4_t q12 = vld1q_f32(qz + 48);
1054
          float32x4_t q18 = vld1q_f32(qz + 72);
1055
          float32x4_t q24 = vld1q_f32(qz + 96);
1056
          float32x4_t qs6x12 = vaddq_f32(q6, q12);
1057
          float32x4_t qs18x24 = vaddq_f32(q18, q24);
1058
          float32x4_t qss = vaddq_f32(qs6x12, q0);
1059
          /* row 1 */
1060
          vst1q_f32(dz, vaddq_f32(qss, qs18x24));
1061
          float32x4_t qn6x12 = vsubq_f32(q6, q12);
1062
          float32x4_t qn18x24 = vsubq_f32(q18, q24);
1063
          qn18x24 = vaddq_f32(qn18x24, qn18x24);
1064
          /* row 2 */
1065
          vst1q_f32(dz + 24, vaddq_f32(qn6x12, qn18x24));
1066
          qs18x24 = vaddq_f32(qs18x24, qs18x24);
1067
          qs18x24 = vaddq_f32(qs18x24, qs18x24);
1068
          /* row 3 */
1069
          vst1q_f32(dz + 48, vaddq_f32(qs6x12, qs18x24));
1070
          qn18x24 = vaddq_f32(qn18x24, qn18x24);
1071
          qn18x24 = vaddq_f32(qn18x24, qn18x24);
1072
          float32x4_t q30 = vld1q_f32(qz + 120);
1073
          /* row 4 */
1074
          vst1q_f32(dz + 72, vaddq_f32(vaddq_f32(qn6x12, q30), qn18x24));
1075
        } unroll_endfor
1076
        float* bpz = bp + x * binc[2] + k;
1077
        float32x4_t bias4 = vld1q_f32(biasval + k);
1078
        switch (z[1]) {
1079
          case 1:
1080
            unroll_for(dy, z[0], 4) {
1081
              const float* const dz = d + dy * 24;
1082
              float32x4_t d0 = vld1q_f32(dz);
1083
              float32x4_t d1 = vld1q_f32(dz + 4);
1084
              float32x4_t d2 = vld1q_f32(dz + 8);
1085
              float32x4_t d3 = vld1q_f32(dz + 12);
1086
              float32x4_t d4 = vld1q_f32(dz + 16);
1087
              float32x4_t ds1x2 = vaddq_f32(d1, d2);
1088
              float32x4_t ds3x4 = vaddq_f32(d3, d4);
1089
              ds1x2 = vaddq_f32(ds1x2, bias4);
1090
              vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1091
              bpz += binc[1] * binc[2];
1092
            } unroll_endfor
1093
            break;
1094
          case 2:
1095
            unroll_for(dy, z[0], 4) {
1096
              const float* const dz = d + dy * 24;
1097
              float32x4_t d0 = vld1q_f32(dz);
1098
              float32x4_t d1 = vld1q_f32(dz + 4);
1099
              float32x4_t d2 = vld1q_f32(dz + 8);
1100
              float32x4_t d3 = vld1q_f32(dz + 12);
1101
              float32x4_t d4 = vld1q_f32(dz + 16);
1102
              float32x4_t ds1x2 = vaddq_f32(d1, d2);
1103
              float32x4_t ds3x4 = vaddq_f32(d3, d4);
1104
              ds1x2 = vaddq_f32(ds1x2, bias4);
1105
              vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1106
              float32x4_t dn1x2 = vsubq_f32(d1, d2);
1107
              float32x4_t dn3x4 = vsubq_f32(d3, d4);
1108
              dn3x4 = vaddq_f32(dn3x4, dn3x4);
1109
              dn1x2 = vaddq_f32(dn1x2, bias4);
1110
              vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1111
              bpz += binc[1] * binc[2];
1112
            } unroll_endfor
1113
            break;
1114
          case 3:
1115
            unroll_for(dy, z[0], 4) {
1116
              const float* const dz = d + dy * 24;
1117
              float32x4_t d0 = vld1q_f32(dz);
1118
              float32x4_t d1 = vld1q_f32(dz + 4);
1119
              float32x4_t d2 = vld1q_f32(dz + 8);
1120
              float32x4_t d3 = vld1q_f32(dz + 12);
1121
              float32x4_t d4 = vld1q_f32(dz + 16);
1122
              float32x4_t ds1x2 = vaddq_f32(d1, d2);
1123
              float32x4_t ds3x4 = vaddq_f32(d3, d4);
1124
              ds1x2 = vaddq_f32(ds1x2, bias4);
1125
              vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1126
              float32x4_t dn1x2 = vsubq_f32(d1, d2);
1127
              float32x4_t dn3x4 = vsubq_f32(d3, d4);
1128
              dn3x4 = vaddq_f32(dn3x4, dn3x4);
1129
              dn1x2 = vaddq_f32(dn1x2, bias4);
1130
              vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1131
              ds3x4 = vaddq_f32(ds3x4, ds3x4);
1132
              ds3x4 = vaddq_f32(ds3x4, ds3x4);
1133
              vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1134
              bpz += binc[1] * binc[2];
1135
            } unroll_endfor
1136
            break;
1137
          case 4:
1138
            unroll_for(dy, z[0], 4) {
1139
              const float* const dz = d + dy * 24;
1140
              float32x4_t d0 = vld1q_f32(dz);
1141
              float32x4_t d1 = vld1q_f32(dz + 4);
1142
              float32x4_t d2 = vld1q_f32(dz + 8);
1143
              float32x4_t d3 = vld1q_f32(dz + 12);
1144
              float32x4_t d4 = vld1q_f32(dz + 16);
1145
              float32x4_t ds1x2 = vaddq_f32(d1, d2);
1146
              float32x4_t ds3x4 = vaddq_f32(d3, d4);
1147
              ds1x2 = vaddq_f32(ds1x2, bias4);
1148
              vst1q_f32(bpz, vaddq_f32(ds1x2, vaddq_f32(d0, ds3x4)));
1149
              float32x4_t dn1x2 = vsubq_f32(d1, d2);
1150
              float32x4_t dn3x4 = vsubq_f32(d3, d4);
1151
              dn3x4 = vaddq_f32(dn3x4, dn3x4);
1152
              dn1x2 = vaddq_f32(dn1x2, bias4);
1153
              vst1q_f32(bpz + binc[2], vaddq_f32(dn1x2, dn3x4));
1154
              ds3x4 = vaddq_f32(ds3x4, ds3x4);
1155
              ds3x4 = vaddq_f32(ds3x4, ds3x4);
1156
              vst1q_f32(bpz + 2 * binc[2], vaddq_f32(ds1x2, ds3x4));
1157
              float32x4_t d5 = vld1q_f32(dz + 20);
1158
              dn3x4 = vaddq_f32(dn3x4, dn3x4);
1159
              dn3x4 = vaddq_f32(dn3x4, dn3x4);
1160
              vst1q_f32(bpz + 3 * binc[2], vaddq_f32(vaddq_f32(dn1x2, d5), dn3x4));
1161
              bpz += binc[1] * binc[2];
1162
            } unroll_endfor
1163
            break;
1164
        };
1165
      }
1166
    }
1167
  } parallel_endfor
1168
  ccfree(workmem);
1169
  return CCV_NNC_EXEC_SUCCESS;
1170
}
1171
#endif
1172
1173
int _ccv_nnc_conv_forw_4x4_3x3_winograd_cpu_opt(const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_t* const w, const ccv_nnc_tensor_t* const bias, const ccv_nnc_hint_t hint, ccv_nnc_tensor_view_t* const b)
1174
106
{
1175
106
#if defined(HAVE_SSE2)
1176
106
  if (w->info.dim[0] % 4 == 0)
1177
106
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_sse2(a, w, bias, hint, b);
1178
106
#elif defined(HAVE_NEON)
1179
  if (w->info.dim[0] % 4 == 0)
1180
    return _ccv_nnc_conv_forw_4x4_3x3_winograd_neon(a, w, bias, hint, b);
1181
#endif
1182
0
  return _ccv_nnc_conv_forw_4x4_3x3_winograd_ref(a, w, bias, hint, b);
1183
106
}
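
The 4x4-output, 3x3-kernel Winograd transforms used throughout this file only appear implicitly, as the expanded coefficient comments inside _ccv_nnc_winograd_4x4_3x3_gwtg_ref and the "BT.d", "BT.d.B" and output blocks above. As a reading aid, the LaTeX sketch below collects those coefficients into the conventional F(4x4, 3x3) matrices; the names G, B^T and A^T are standard labels rather than identifiers from the source, and if the comments and code ever disagree, the code is authoritative.

% Kernel transform (rows match the g[0..17] assignments, e.g. g[0] = w[i] / 4, g[3] = -(w[i] + w[3c+i] + w[6c+i]) / 6):
G = \begin{pmatrix}
\tfrac{1}{4} & 0 & 0 \\
-\tfrac{1}{6} & -\tfrac{1}{6} & -\tfrac{1}{6} \\
-\tfrac{1}{6} & \tfrac{1}{6} & -\tfrac{1}{6} \\
\tfrac{1}{24} & \tfrac{1}{12} & \tfrac{1}{6} \\
\tfrac{1}{24} & -\tfrac{1}{12} & \tfrac{1}{6} \\
0 & 0 & 1
\end{pmatrix}
\qquad
% Input transform (rows match the "BT.d" block, e.g. d[j] = 4*g0 - 5*g12 + g24):
B^\top = \begin{pmatrix}
4 & 0 & -5 & 0 & 1 & 0 \\
0 & -4 & -4 & 1 & 1 & 0 \\
0 & 4 & -4 & -1 & 1 & 0 \\
0 & -2 & -1 & 2 & 1 & 0 \\
0 & 2 & -1 & -2 & 1 & 0 \\
0 & 4 & 0 & -5 & 0 & 1
\end{pmatrix}
\qquad
% Output transform (rows match the final d[0..23] block, e.g. d[0] = q[0] + q[6] + q[12] + q[18] + q[24]):
A^\top = \begin{pmatrix}
1 & 1 & 1 & 1 & 1 & 0 \\
0 & 1 & -1 & 2 & -2 & 0 \\
0 & 1 & 1 & 4 & 4 & 0 \\
0 & 1 & -1 & 8 & -8 & 1
\end{pmatrix}

With these labels, each 4x4 output tile for output channel k is

Y_k = A^\top \Big( \sum_{c} (G\, w_{k,c}\, G^\top) \odot (B^\top d_c\, B) \Big) A + b_k,

which corresponds to the gwtg buffer (G w G^T, precomputed once per kernel), the btdb buffer (B^T d B, computed per input tile), the element-wise multiply-accumulate over channels in the k-loop, and the biasval added in the final unrolled stores.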