Coverage Report

Created: 2025-02-24 17:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_group_norm_cpu_ref.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#ifdef USE_OPENMP
7
#include <omp.h>
8
#endif
9
#ifdef USE_DISPATCH
10
#include <dispatch/dispatch.h>
11
#endif
12
13
// Shared methods.
14
#include "../_ccv_nnc_cpu_ref.h"
15
16
// CPU reference implementation of group norm forward:
//   b = (a - mean) * inv_std [* scale + bias]
// mean / inv_std are reduced over every axis d where rdim[d] < adim[d]
// (rdim comes from the saved_mean output) and are stored into
// outputs[1] (saved_mean) and outputs[2] (saved_inv_std) for the backward pass.
static int _ccv_nnc_group_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 3 || input_size == 1);
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
	// scale / bias are optional (absent when elementwise affine is off).
	ccv_nnc_tensor_view_t* const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t*)inputs[1] : 0;
	ccv_nnc_tensor_view_t* const bias = input_size >= 3 ? (ccv_nnc_tensor_view_t*)inputs[2] : 0;
	ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)outputs[1];
	ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[2];
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int adim[CCV_NNC_MAX_DIM_ALLOC];
	int rdim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(a, adim);
	ccv_nnc_tensor_view_get_dim(saved_mean, rdim);
	assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
	assert(ccv_nnc_tensor_view_check_dim(b, adim));
	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int bstride[CCV_NNC_MAX_DIM_ALLOC];
	int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
	int bias_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(a, astride);
	if (scale)
		ccv_nnc_tensor_view_get_stride(scale, scale_stride);
	if (bias)
		ccv_nnc_tensor_view_get_stride(bias, bias_stride);
	ccv_nnc_tensor_view_get_stride(b, bstride);
	// NOTE(review): epsilon is read through the lnorm union member while
	// elementwise_affine below is read through gnorm — presumably the two
	// layouts alias epsilon at the same offset; confirm against ccv_nnc_cmd_param_t.
	const float epsilon = cmd.info.lnorm.epsilon;
	int saved_mean_stride[CCV_NNC_MAX_DIM_ALLOC];
	int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(saved_mean, saved_mean_stride);
	ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride);
	int x;
	// n = number of elements reduced into each mean / inv_std entry.
	int n = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n *= adim[x];
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n /= rdim[x];
	const float inv_n = 1. / n;
	int i[CCV_NNC_MAX_DIM + 2];
	float* const ap = a->data.f32;
	float* const meanp = saved_mean->data.f32;
	// Pass 1: accumulate sums into saved_mean (collapsing reduced axes via
	// the i * rdim / adim index mapping), then scale by 1/n.
	ccv_nnc_tensor_zero(saved_mean);
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const ap0 = ap + i[0] * astride[0];
		float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* ap1 = ap0 + i[1] * astride[1];
			float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
				if (rdim[3] < adim[3])
					for (x = 0; x < adim[3]; x++)
						meanp2[x * rdim[3] / adim[3]] += ap1[x];
				else
					for (x = 0; x < adim[3]; x++)
						meanp2[x] += ap1[x];
				ap1 += astride[2];
			}
		}
	}
	for (i[0] = 0; i[0] < rdim[0]; i[0]++)
	{
		float* const meanp0 = meanp + i[0] * saved_mean_stride[0];
		for (i[1] = 0; i[1] < rdim[1]; i[1]++)
		{
			float* const meanp1 = meanp0 + i[1] * saved_mean_stride[1];
			for (i[2] = 0; i[2] < rdim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + i[2] * saved_mean_stride[2];
				for (x = 0; x < rdim[3]; x++)
					meanp2[x] = meanp2[x] * inv_n;
			}
		}
	}
	// Pass 2: accumulate squared deviations into saved_inv_std, then convert
	// in place to 1 / sqrt(var + epsilon). Note epsilon sits inside the sqrt.
	ccv_nnc_tensor_zero(saved_inv_std);
	float* const varp = saved_inv_std->data.f32;
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const ap0 = ap + i[0] * astride[0];
		float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
		float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* ap1 = ap0 + i[1] * astride[1];
			float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
			float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
				float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
				if (rdim[3] < adim[3])
					for (x = 0; x < adim[3]; x++)
					{
						float w = ap1[x] - meanp2[x * rdim[3] / adim[3]];
						varp2[x * rdim[3] / adim[3]] += w * w;
					}
				else
					for (x = 0; x < adim[3]; x++)
					{
						float w = ap1[x] - meanp2[x];
						varp2[x] += w * w;
					}
				ap1 += astride[2];
			}
		}
	}
	for (i[0] = 0; i[0] < rdim[0]; i[0]++)
	{
		float* const varp0 = varp + i[0] * saved_inv_std_stride[0];
		for (i[1] = 0; i[1] < rdim[1]; i[1]++)
		{
			float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1];
			for (i[2] = 0; i[2] < rdim[2]; i[2]++)
			{
				float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2];
				for (x = 0; x < rdim[3]; x++)
					varp2[x] = 1. / sqrtf(varp2[x] * inv_n + epsilon);
			}
		}
	}
	if (cmd.info.gnorm.elementwise_affine)
	{
		float* const scalep = scale->data.f32;
		float* const biasp = bias->data.f32;
		int sdim[CCV_NNC_MAX_DIM_ALLOC];
		ccv_nnc_tensor_view_get_dim(scale, sdim);
		int bias_dim[CCV_NNC_MAX_DIM_ALLOC];
		ccv_nnc_tensor_view_get_dim(bias, bias_dim);
		// Do the straight-forward one, y = (x - mean) * inv_std * scale + bias, we cannot allocate extra memory to help.
		// There is no need for precompute since scale / bias is per element.
		float* const bp = b->data.f32;
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const ap0 = ap + i[0] * astride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
			float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
			float* const scalep0 = scalep + (sdim[0] < adim[0] ? i[0] * sdim[0] / adim[0] : i[0]) * scale_stride[0];
			float* const biasp0 = biasp + (bias_dim[0] < adim[0] ? i[0] * bias_dim[0] / adim[0] : i[0]) * bias_stride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* ap1 = ap0 + i[1] * astride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
				float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
				float* const scalep1 = scalep0 + (sdim[1] < adim[1] ? i[1] * sdim[1] / adim[1] : i[1]) * scale_stride[1];
				float* const biasp1 = biasp0 + (bias_dim[1] < adim[1] ? i[1] * bias_dim[1] / adim[1] : i[1]) * bias_stride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
					float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
					float* const scalep2 = scalep1 + (sdim[2] < adim[2] ? i[2] * sdim[2] / adim[2] : i[2]) * scale_stride[2];
					float* const biasp2 = biasp1 + (bias_dim[2] < adim[2] ? i[2] * bias_dim[2] / adim[2] : i[2]) * bias_stride[2];
					if (rdim[3] < adim[3])
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x * rdim[3] / adim[3]]) * varp2[x * rdim[3] / adim[3]] * scalep2[x * sdim[3] / adim[3]] + biasp2[x * bias_dim[3] / adim[3]];
					else
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x]) * varp2[x] * scalep2[x * sdim[3] / adim[3]] + biasp2[x * bias_dim[3] / adim[3]];
					ap1 += astride[2];
					bp1 += bstride[2];
				}
			}
		}
	} else {
		// Do the straight-forward one, y = (x - mean) * inv_std, we cannot allocate extra memory to help.
		float* const bp = b->data.f32;
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const ap0 = ap + i[0] * astride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
			float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* ap1 = ap0 + i[1] * astride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
				float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
					float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
					if (rdim[3] < adim[3])
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x * rdim[3] / adim[3]]) * varp2[x * rdim[3] / adim[3]];
					else
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x]) * varp2[x];
					ap1 += astride[2];
					bp1 += bstride[2];
				}
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
220
221
// CPU reference implementation of group norm backward. Given the incoming
// gradient g, the original input a, and the saved mean / inv_std from the
// forward pass, computes:
//   h      = dL/dx = gss - (1/n) * (sum(gss) + ah * sum(ah * gss))
//   dscale = sum over reduced axes of (ah * g)        (when requested)
//   dbias  = sum over reduced axes of g               (when requested)
// where ah = (a - mean) * inv_std and gss = g * scale * inv_std (scale
// omitted when elementwise affine is off). rdim[d] < gdim[d] marks axis d
// as reduced, handled via the i * rdim / gdim index mapping.
static int _ccv_nnc_group_norm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 9 || input_size == 7);
	assert(output_size >= 1);
	const int elementwise_affine = cmd.info.gnorm.elementwise_affine;
	ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[3];
	// The input layout shifts by two slots when scale / bias are present.
	ccv_nnc_tensor_view_t* const scale = elementwise_affine ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
	ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 7 : 5];
	ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 8 : 6];
	ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
	ccv_nnc_tensor_view_t* const dbias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;
	assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int gdim[CCV_NNC_MAX_DIM_ALLOC];
	int rdim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(g, gdim);
	ccv_nnc_tensor_view_get_dim(saved_mean, rdim);
	assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
	int sdim[CCV_NNC_MAX_DIM_ALLOC];
	if (scale)
		ccv_nnc_tensor_view_get_dim(scale, sdim);
	if (dscale)
		{ assert(ccv_nnc_tensor_view_check_dim(dscale, sdim)); }
	assert(ccv_nnc_tensor_view_check_dim(a, gdim));
	assert(ccv_nnc_tensor_view_check_dim(h, gdim));
	// dbias is a plain reduction of g over the non-kept axes.
	if (dbias)
		_ccv_nnc_reduce_sum_forw_cpu_ref(g, dbias);
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int gstride[CCV_NNC_MAX_DIM_ALLOC];
	int hstride[CCV_NNC_MAX_DIM_ALLOC];
	int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
	int mean_stride[CCV_NNC_MAX_DIM_ALLOC];
	int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
	int dscale_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(a, astride);
	ccv_nnc_tensor_view_get_stride(g, gstride);
	ccv_nnc_tensor_view_get_stride(h, hstride);
	if (scale)
		ccv_nnc_tensor_view_get_stride(scale, scale_stride);
	ccv_nnc_tensor_view_get_stride(saved_mean, mean_stride);
	ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride);
	if (dscale)
		ccv_nnc_tensor_view_get_stride(dscale, dscale_stride);
	// Need to allocate two additional memory:
	// 1. normalized a;
	// 2. scale * inv_std / n;
	assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC));
	int x;
	// n = number of elements reduced into each mean / inv_std entry.
	int n = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n *= gdim[x];
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n /= rdim[x];
	int gcount = 1, rcount = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		gcount *= gdim[x], rcount *= rdim[x];
	// Workspace layout: [ah : gcount][gss : gcount][gssr : rcount][ahgssr : rcount].
	float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount * 2 + sizeof(float) * rcount * 2, CCV_TENSOR_CPU_MEMORY);
	float* const gss = ah + gcount; // g * scale * inv_std
	float* const gssr = gss + gcount; // gss reduced to inv_std dimension
	float* const ahgssr = gssr + rcount; // ah * gss then reduced to inv_std dimension.
	int i[CCV_NNC_MAX_DIM + 2];
	float* ahp = ah;
	const float* const meanp = saved_mean->data.f32;
	const float* const inv_stdp = saved_inv_std->data.f32;
	const float* const ap = a->data.f32;
	// Stage 1: ah = (a - mean) * inv_std, written densely into the workspace.
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		const float* const ap0 = ap + i[0] * astride[0];
		const float* const meanp0 = meanp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * mean_stride[0];
		const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			const float* ap1 = ap0 + i[1] * astride[1];
			const float* const meanp1 = meanp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * mean_stride[1];
			const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				const float* const meanp2 = meanp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * mean_stride[2];
				const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						ahp[x] = (ap1[x] - meanp2[x * rdim[3] / gdim[3]]) * inv_stdp2[x * rdim[3] / gdim[3]];
				else
					for (x = 0; x < gdim[3]; x++)
						ahp[x] = (ap1[x] - meanp2[x]) * inv_stdp2[x];
				ap1 += astride[2];
				ahp += gdim[3];
			}
		}
	}
	// Stage 2: gss = g * scale * inv_std (scale omitted when not affine),
	// fused with the dscale accumulation when dscale is requested.
	if (dscale)
	{
		ccv_nnc_tensor_zero(dscale);
		ahp = ah;
		float* gssp = gss;
		const float* const gp = g->data.f32;
		const float* const scalep = scale->data.f32;
		float* const dscalep = dscale->data.f32;
		for (i[0] = 0; i[0] < gdim[0]; i[0]++)
		{
			const float* const gp0 = gp + i[0] * gstride[0];
			const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
			const float* const scalep0 = scalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * scale_stride[0];
			float* const dscalep0 = dscalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * dscale_stride[0];
			for (i[1] = 0; i[1] < gdim[1]; i[1]++)
			{
				const float* gp1 = gp0 + i[1] * gstride[1];
				const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
				const float* const scalep1 = scalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * scale_stride[1];
				float* const dscalep1 = dscalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * dscale_stride[1];
				for (i[2] = 0; i[2] < gdim[2]; i[2]++)
				{
					const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
					const float* const scalep2 = scalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * scale_stride[2];
					float* const dscalep2 = dscalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * dscale_stride[2];
					if (sdim[3] < gdim[3])
						for (x = 0; x < gdim[3]; x++)
						{
							gssp[x] = gp1[x] * scalep2[x * sdim[3] / gdim[3]] * inv_stdp2[x * rdim[3] / gdim[3]];
							dscalep2[x * sdim[3] / gdim[3]] += ahp[x] * gp1[x];
						}
					else
						for (x = 0; x < gdim[3]; x++)
						{
							gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[x * rdim[3] / gdim[3]];
							dscalep2[x] += ahp[x] * gp1[x];
						}
					gp1 += gstride[2];
					ahp += gdim[3];
					gssp += gdim[3];
				}
			}
		}
	} else {
		ahp = ah;
		float* gssp = gss;
		const float* const gp = g->data.f32;
		if (elementwise_affine)
		{
			const float* const scalep = scale->data.f32;
			for (i[0] = 0; i[0] < gdim[0]; i[0]++)
			{
				const float* const gp0 = gp + i[0] * gstride[0];
				const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
				const float* const scalep0 = scalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * scale_stride[0];
				for (i[1] = 0; i[1] < gdim[1]; i[1]++)
				{
					const float* gp1 = gp0 + i[1] * gstride[1];
					const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
					const float* const scalep1 = scalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * scale_stride[1];
					for (i[2] = 0; i[2] < gdim[2]; i[2]++)
					{
						const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
						const float* const scalep2 = scalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * scale_stride[2];
						if (sdim[3] < gdim[3])
							for (x = 0; x < gdim[3]; x++)
								gssp[x] = gp1[x] * scalep2[x * sdim[3] / gdim[3]] * inv_stdp2[x * rdim[3] / gdim[3]];
						else
							for (x = 0; x < gdim[3]; x++)
								gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[x * rdim[3] / gdim[3]];
						gp1 += gstride[2];
						ahp += gdim[3];
						gssp += gdim[3];
					}
				}
			}
		} else {
			for (i[0] = 0; i[0] < gdim[0]; i[0]++)
			{
				const float* const gp0 = gp + i[0] * gstride[0];
				const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
				for (i[1] = 0; i[1] < gdim[1]; i[1]++)
				{
					const float* gp1 = gp0 + i[1] * gstride[1];
					const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
					for (i[2] = 0; i[2] < gdim[2]; i[2]++)
					{
						const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
						for (x = 0; x < gdim[3]; x++)
							gssp[x] = gp1[x] * inv_stdp2[x * rdim[3] / gdim[3]];
						gp1 += gstride[2];
						ahp += gdim[3];
						gssp += gdim[3];
					}
				}
			}
		}
	}
	// Stage 3: gssr = gss reduced to the mean / inv_std shape (dense rdim layout).
	ccv_nnc_tensor_t gssrt = ccv_nnc_tensor(gssr, saved_mean->info, 0);
	ccv_nnc_tensor_zero(&gssrt);
	float* gssp = gss;
	float* const gssrp = gssr;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const gssrp0 = gssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* const gssrp1 = gssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				float* const gssrp2 = gssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						gssrp2[x * rdim[3] / gdim[3]] += gssp[x];
				else
					for (x = 0; x < gdim[3]; x++)
						gssrp2[x] += gssp[x];
				gssp += gdim[3];
			}
		}
	}
	// Stage 4: ahgssr = (ah * gss) reduced to the mean / inv_std shape.
	ahp = ah;
	gssp = gss;
	ccv_nnc_tensor_t ahgssrt = ccv_nnc_tensor(ahgssr, saved_mean->info, 0);
	ccv_nnc_tensor_zero(&ahgssrt);
	float* const ahgssrp = ahgssr;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const ahgssrp0 = ahgssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* const ahgssrp1 = ahgssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				float* const ahgssrp2 = ahgssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						ahgssrp2[x * rdim[3] / gdim[3]] += ahp[x] * gssp[x];
				else
					for (x = 0; x < gdim[3]; x++)
						ahgssrp2[x] += ahp[x] * gssp[x];
				ahp += gdim[3];
				gssp += gdim[3];
			}
		}
	}
	// Now the part to compute dx (h).
	float* const hp = h->data.f32;
	ahp = ah;
	const float inv_n = 1. / n;
	gssp = gss;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const hp0 = hp + i[0] * hstride[0];
		const float* const gssrp0 = gssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		const float* const ahgssrp0 = ahgssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* hp1 = hp0 + i[1] * hstride[1];
			const float* const gssrp1 = gssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			const float* const ahgssrp1 = ahgssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				const float* const gssrp2 = gssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				const float* const ahgssrp2 = ahgssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						hp1[x] = gssp[x] - inv_n * (gssrp2[x * rdim[3] / gdim[3]] + ahp[x] * ahgssrp2[x * rdim[3] / gdim[3]]);
				else
					for (x = 0; x < gdim[3]; x++)
						hp1[x] = gssp[x] - inv_n * (gssrp2[x] + ahp[x] * ahgssrp2[x]);
				hp1 += hstride[2];
				ahp += gdim[3];
				gssp += gdim[3];
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
494
495
// Registers the forward group-norm CPU reference kernel: float32 only,
// CPU memory, all three tensor layouts supported.
REGISTER_COMMAND_BACKEND(CCV_NNC_GROUP_NORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_group_norm_forw;
}
503
504
// Registers the backward group-norm CPU reference kernel: float32 only,
// CPU memory, all three tensor layouts supported.
REGISTER_COMMAND_BACKEND(CCV_NNC_GROUP_NORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_group_norm_back;
}