Coverage Report

Created: 2024-08-18 16:21

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_batch_norm_cpu_ref.c
Line | Count | Source
   1 |       | #include "ccv.h"
   2 |       | #include "ccv_internal.h"
   3 |       | #include "nnc/ccv_nnc.h"
   4 |       | #include "nnc/ccv_nnc_easy.h"
   5 |       | #include "nnc/ccv_nnc_internal.h"
   6 |       | #ifdef USE_OPENMP
   7 |       | #include <omp.h>
   8 |       | #endif
   9 |       | #ifdef USE_DISPATCH
  10 |       | #include <dispatch/dispatch.h>
  11 |       | #endif
  12 |       |
  13 |       | // Shared methods.
  14 |       | #include "../_ccv_nnc_cpu_ref.h"
  15 |       |
  16 |       | static int _ccv_nnc_batch_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
  17 |    24 | {
  18 |    24 |   assert(input_size == 5);
  19 |    24 |   ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
  20 |    24 |   ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[1];
  21 |    24 |   ccv_nnc_tensor_view_t* const bias = (ccv_nnc_tensor_view_t*)inputs[2];
  22 |    24 |   ccv_nnc_tensor_view_t* const mean = (ccv_nnc_tensor_view_t*)inputs[3];
  23 |    24 |   ccv_nnc_tensor_view_t* const var = (ccv_nnc_tensor_view_t*)inputs[4];
  24 |    24 |   ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
  25 |    24 |   assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
  26 |    24 |   assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
  27 |       |   // Assuming this is float 32.
  28 |    24 |   int adim[CCV_NNC_MAX_DIM_ALLOC];
  29 |    24 |   int rdim[CCV_NNC_MAX_DIM_ALLOC];
  30 |    24 |   ccv_nnc_tensor_view_get_dim(a, adim);
  31 |    24 |   ccv_nnc_tensor_view_get_dim(scale, rdim);
  32 |    24 |   assert(ccv_nnc_tensor_view_check_dim(bias, rdim));
  33 |    24 |   assert(ccv_nnc_tensor_view_check_dim(mean, rdim));
  34 |    24 |   assert(ccv_nnc_tensor_view_check_dim(var, rdim));
  35 |    24 |   assert(ccv_nnc_tensor_view_check_dim(b, adim));
  36 |    24 |   assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
  37 |    24 |   int astride[CCV_NNC_MAX_DIM_ALLOC];
  38 |    24 |   int bstride[CCV_NNC_MAX_DIM_ALLOC];
  39 |    24 |   int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
  40 |    24 |   int bias_stride[CCV_NNC_MAX_DIM_ALLOC];
  41 |    24 |   ccv_nnc_tensor_view_get_stride(a, astride);
  42 |    24 |   ccv_nnc_tensor_view_get_stride(scale, scale_stride);
  43 |    24 |   ccv_nnc_tensor_view_get_stride(bias, bias_stride);
  44 |    24 |   ccv_nnc_tensor_view_get_stride(b, bstride);
  45 |    24 |   const float epsilon = cmd.info.bnorm.epsilon;
  46 |    24 |   if (!cmd.info.bnorm.is_test)
  47 |    24 |   {
  48 |    24 |     assert(output_size == 5);
  49 |       |     // Both are inplace.
  50 |    24 |     assert(inputs[3]->data.f32 == outputs[1]->data.f32);
  51 |    24 |     assert(inputs[4]->data.f32 == outputs[2]->data.f32);
  52 |    24 |     ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)outputs[3];
  53 |    24 |     ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[4];
  54 |    24 |     assert(ccv_nnc_tensor_view_check_dim(saved_mean, rdim));
  55 |    24 |     assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
  56 |    24 |     int saved_mean_stride[CCV_NNC_MAX_DIM_ALLOC];
  57 |    24 |     int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
  58 |    24 |     ccv_nnc_tensor_view_get_stride(saved_mean, saved_mean_stride);
  59 |    24 |     ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride);
  60 |    24 |     int i[CCV_NNC_MAX_DIM + 2];
  61 |    24 |     int x;
  62 |    24 |     int batch_size = 1;
  63 |   120 |     for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
  64 |    96 |       batch_size *= adim[x];
  65 |   120 |     for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
  66 |    96 |       batch_size /= rdim[x];
  67 |    24 |     const float inv_batch_size = 1. / batch_size;
  68 |    24 |     _ccv_nnc_reduce_sum_forw_cpu_ref(a, saved_mean);
  69 |    24 |     _ccv_nnc_mul_forw_cpu_ref(inv_batch_size, saved_mean, 0, saved_mean);
  70 |       |     // Copy this into running mean / var.
  71 |    24 |     _ccv_nnc_add_forw_cpu_ref(cmd.info.bnorm.momentum, 1. - cmd.info.bnorm.momentum, mean, saved_mean, mean);
  72 |    24 |     ccv_nnc_tensor_zero(saved_inv_std);
  73 |    24 |     float* const ap = a->data.f32;
  74 |    24 |     float* const meanp = saved_mean->data.f32;
  75 |    24 |     float* const varp = saved_inv_std->data.f32;
  76 |   174 |     for (i[0] = 0; i[0] < adim[0]; i[0]++)
  77 |   150 |     {
  78 |   150 |       float* const ap0 = ap + i[0] * astride[0];
  79 |   150 |       float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * saved_mean_stride[0];
  80 |   150 |       float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
  81 |   722 |       for (i[1] = 0; i[1] < adim[1]; i[1]++)
  82 |   572 |       {
  83 |   572 |         float* ap1 = ap0 + i[1] * astride[1];
  84 |   572 |         float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * saved_mean_stride[1];
  85 |   572 |         float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
  86 | 2.80k |         for (i[2] = 0; i[2] < adim[2]; i[2]++)
  87 | 2.23k |         {
  88 | 2.23k |           float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * saved_mean_stride[2];
  89 | 2.23k |           float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
  90 | 2.23k |           if (rdim[3] == 1)
  91 |     0 |             for (x = 0; x < adim[3]; x++)
  92 |     0 |             {
  93 |     0 |               float w = ap1[x] - meanp2[0];
  94 |     0 |               varp2[0] += w * w;
  95 |     0 |             }
  96 | 2.23k |           else
  97 | 24.5k |             for (x = 0; x < adim[3]; x++)
  98 | 22.3k |             {
  99 | 22.3k |               float w = ap1[x] - meanp2[x];
 100 | 22.3k |               varp2[x] += w * w;
 101 | 22.3k |             }
 102 | 2.23k |           ap1 += astride[2];
 103 | 2.23k |         }
 104 |   572 |       }
 105 |   150 |     }
 106 |    24 |     _ccv_nnc_mul_forw_cpu_ref(inv_batch_size, saved_inv_std, 0, saved_inv_std);
 107 |    24 |     _ccv_nnc_add_forw_cpu_ref(cmd.info.bnorm.momentum, 1. - cmd.info.bnorm.momentum, var, saved_inv_std, var);
 108 |    48 |     for (i[0] = 0; i[0] < rdim[0]; i[0]++)
 109 |    24 |     {
 110 |    24 |       float* const varp0 = varp + i[0] * saved_inv_std_stride[0];
 111 |    48 |       for (i[1] = 0; i[1] < rdim[1]; i[1]++)
 112 |    24 |       {
 113 |    24 |         float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1];
 114 |    48 |         for (i[2] = 0; i[2] < rdim[2]; i[2]++)
 115 |    24 |         {
 116 |    24 |           float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2];
 117 |   264 |           for (x = 0; x < rdim[3]; x++)
 118 |   240 |             varp2[x] = 1. / sqrtf(varp2[x] + epsilon);
 119 |    24 |         }
 120 |    24 |       }
 121 |    24 |     }
 122 |    24 |     float* const scalep = scale->data.f32;
 123 |    24 |     float* const biasp = bias->data.f32;
 124 |       |     // Now, after mean and inv_std computed, go and stretch a.
 125 |    24 |     if (flags & CCV_NNC_ZERO_MEMORY_ALLOC)
 126 |     0 |     {
 127 |       |       // Do the straight-forward one, y = (x - mean) * inv_std * scale + bias, we cannot allocate extra memory to help.
 128 |     0 |       float* const bp = b->data.f32;
 129 |     0 |       for (i[0] = 0; i[0] < adim[0]; i[0]++)
 130 |     0 |       {
 131 |     0 |         float* const ap0 = ap + i[0] * astride[0];
 132 |     0 |         float* const bp0 = bp + i[0] * bstride[0];
 133 |     0 |         float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * saved_mean_stride[0];
 134 |     0 |         float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
 135 |     0 |         float* const scalep0 = rdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
 136 |     0 |         float* const biasp0 = rdim[0] == 1 ? biasp : biasp + i[0] * bias_stride[0];
 137 |     0 |         for (i[1] = 0; i[1] < adim[1]; i[1]++)
 138 |     0 |         {
 139 |     0 |           float* ap1 = ap0 + i[1] * astride[1];
 140 |     0 |           float* bp1 = bp0 + i[1] * bstride[1];
 141 |     0 |           float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * saved_mean_stride[1];
 142 |     0 |           float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
 143 |     0 |           float* const scalep1 = rdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
 144 |     0 |           float* const biasp1 = rdim[1] == 1 ? biasp0 : biasp0 + i[1] * bias_stride[1];
 145 |     0 |           for (i[2] = 0; i[2] < adim[2]; i[2]++)
 146 |     0 |           {
 147 |     0 |             float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * saved_mean_stride[2];
 148 |     0 |             float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
 149 |     0 |             float* const scalep2 = rdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
 150 |     0 |             float* const biasp2 = rdim[2] == 1 ? biasp1 : biasp1 + i[2] * bias_stride[2];
 151 |     0 |             if (rdim[3] == 1)
 152 |     0 |               for (x = 0; x < adim[3]; x++)
 153 |     0 |                 bp1[x] = (ap1[x] - meanp2[0]) * varp2[0] * scalep2[0] + biasp2[0];
 154 |     0 |             else
 155 |     0 |               for (x = 0; x < adim[3]; x++)
 156 |     0 |                 bp1[x] = (ap1[x] - meanp2[x]) * varp2[x] * scalep2[x] + biasp2[x];
 157 |     0 |             ap1 += astride[2];
 158 |     0 |             bp1 += bstride[2];
 159 |     0 |           }
 160 |     0 |         }
 161 |     0 |       }
 162 |    24 |     } else {
 163 |       |       // If we allocate extra memory, we can convert y = (x - mean) * inv_std * scale + bias
 164 |       |       // to y = x * inv_std * scale + (bias - mean * inv_std * scale)
 165 |       |       // we can pre-compute nscale = inv_std * scale, nbias = bias - mean * inv_std * scale
 166 |    24 |       int count = 1;
 167 |   120 |       for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
 168 |    96 |         count *= rdim[x];
 169 |    24 |       float* const nscalep = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * count * 2, CCV_TENSOR_CPU_MEMORY);
 170 |    24 |       float* const nbiasp = nscalep + count;
 171 |    48 |       for (i[0] = 0; i[0] < rdim[0]; i[0]++)
 172 |    24 |       {
 173 |    24 |         float* const meanp0 = meanp + i[0] * saved_mean_stride[0];
 174 |    24 |         float* const varp0 = varp + i[0] * saved_inv_std_stride[0];
 175 |    24 |         float* const scalep0 = scalep + i[0] * scale_stride[0];
 176 |    24 |         float* const biasp0 = biasp + i[0] * bias_stride[0];
 177 |    24 |         float* const nscalep0 = nscalep + i[0] * rdim[1] * rdim[2] * rdim[3];
 178 |    24 |         float* const nbiasp0 = nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3];
 179 |    48 |         for (i[1] = 0; i[1] < rdim[1]; i[1]++)
 180 |    24 |         {
 181 |    24 |           float* const meanp1 = meanp0 + i[1] * saved_mean_stride[1];
 182 |    24 |           float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1];
 183 |    24 |           float* const scalep1 = scalep0 + i[1] * scale_stride[1];
 184 |    24 |           float* const biasp1 = biasp0 + i[1] * bias_stride[1];
 185 |    24 |           float* const nscalep1 = nscalep0 + i[1] * rdim[2] * rdim[3];
 186 |    24 |           float* const nbiasp1 = nbiasp0 + i[1] * rdim[2] * rdim[3];
 187 |    48 |           for (i[2] = 0; i[2] < rdim[2]; i[2]++)
 188 |    24 |           {
 189 |    24 |             float* const meanp2 = meanp1 + i[2] * saved_mean_stride[2];
 190 |    24 |             float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2];
 191 |    24 |             float* const scalep2 = scalep1 + i[2] * scale_stride[2];
 192 |    24 |             float* const biasp2 = biasp1 + i[2] * bias_stride[2];
 193 |    24 |             float* const nscalep2 = nscalep1 + i[2] * rdim[3];
 194 |    24 |             float* const nbiasp2 = nbiasp1 + i[2] * rdim[3];
 195 |   264 |             for (x = 0; x < rdim[3]; x++)
 196 |   240 |             {
 197 |   240 |               const float w = varp2[x] * scalep2[x];
 198 |   240 |               nscalep2[x] = w;
 199 |   240 |               nbiasp2[x] = biasp2[x] - meanp2[x] * w;
 200 |   240 |             }
 201 |    24 |           }
 202 |    24 |         }
 203 |    24 |       }
 204 |    24 |       float* const bp = b->data.f32;
 205 |   174 |       for (i[0] = 0; i[0] < adim[0]; i[0]++)
 206 |   150 |       {
 207 |   150 |         float* const ap0 = ap + i[0] * astride[0];
 208 |   150 |         float* const bp0 = bp + i[0] * bstride[0];
 209 |   150 |         float* const nscalep0 = rdim[0] == 1 ? nscalep : nscalep + i[0] * rdim[1] * rdim[2] * rdim[3];
 210 |   150 |         float* const nbiasp0 = rdim[0] == 1 ? nbiasp : nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3];
 211 |   722 |         for (i[1] = 0; i[1] < adim[1]; i[1]++)
 212 |   572 |         {
 213 |   572 |           float* ap1 = ap0 + i[1] * astride[1];
 214 |   572 |           float* bp1 = bp0 + i[1] * bstride[1];
 215 |   572 |           float* const nscalep1 = rdim[1] == 1 ? nscalep0 : nscalep0 + i[1] * rdim[2] * rdim[3];
 216 |   572 |           float* const nbiasp1 = rdim[1] == 1 ? nbiasp0 : nbiasp0 + i[1] * rdim[2] * rdim[3];
 217 | 2.80k |           for (i[2] = 0; i[2] < adim[2]; i[2]++)
 218 | 2.23k |           {
 219 | 2.23k |             float* const nscalep2 = rdim[2] == 1 ? nscalep1 : nscalep1 + i[2] * rdim[3];
 220 | 2.23k |             float* const nbiasp2 = rdim[2] == 1 ? nbiasp1 : nbiasp1 + i[2] * rdim[3];
 221 | 2.23k |             if (rdim[3] == 1)
 222 |     0 |               for (x = 0; x < adim[3]; x++)
 223 |     0 |                 bp1[x] = ap1[x] * nscalep2[0] + nbiasp2[0];
 224 | 2.23k |             else
 225 | 24.5k |               for (x = 0; x < adim[3]; x++)
 226 | 22.3k |                 bp1[x] = ap1[x] * nscalep2[x] + nbiasp2[x];
 227 | 2.23k |             ap1 += astride[2];
 228 | 2.23k |             bp1 += bstride[2];
 229 | 2.23k |           }
 230 |   572 |         }
 231 |   150 |       }
 232 |    24 |     }
 233 |    24 |   } else {
 234 |     0 |     assert(output_size >= 1);
 235 |     0 |     int mean_stride[CCV_NNC_MAX_DIM_ALLOC];
 236 |     0 |     int var_stride[CCV_NNC_MAX_DIM_ALLOC];
 237 |     0 |     ccv_nnc_tensor_view_get_stride(mean, mean_stride);
 238 |     0 |     ccv_nnc_tensor_view_get_stride(var, var_stride);
 239 |     0 |     int i[CCV_NNC_MAX_DIM + 2];
 240 |     0 |     int x;
 241 |     0 |     assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC));
 242 |     0 |     int count = 1;
 243 |     0 |     for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
 244 |     0 |       count *= rdim[x];
 245 |     0 |     float* const meanp = mean->data.f32;
 246 |     0 |     float* const varp = var->data.f32;
 247 |     0 |     float* const scalep = scale->data.f32;
 248 |     0 |     float* const biasp = bias->data.f32;
 249 |     0 |     float* const nscalep = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * count * 2, CCV_TENSOR_CPU_MEMORY);
 250 |     0 |     float* const nbiasp = nscalep + count;
 251 |     0 |     for (i[0] = 0; i[0] < rdim[0]; i[0]++)
 252 |     0 |     {
 253 |     0 |       float* const meanp0 = meanp + i[0] * mean_stride[0];
 254 |     0 |       float* const varp0 = varp + i[0] * var_stride[0];
 255 |     0 |       float* const scalep0 = scalep + i[0] * scale_stride[0];
 256 |     0 |       float* const biasp0 = biasp + i[0] * bias_stride[0];
 257 |     0 |       float* const nscalep0 = nscalep + i[0] * rdim[1] * rdim[2] * rdim[3];
 258 |     0 |       float* const nbiasp0 = nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3];
 259 |     0 |       for (i[1] = 0; i[1] < rdim[1]; i[1]++)
 260 |     0 |       {
 261 |     0 |         float* const meanp1 = meanp0 + i[1] * mean_stride[1];
 262 |     0 |         float* const varp1 = varp0 + i[1] * var_stride[1];
 263 |     0 |         float* const scalep1 = scalep0 + i[1] * scale_stride[1];
 264 |     0 |         float* const biasp1 = biasp0 + i[1] * bias_stride[1];
 265 |     0 |         float* const nscalep1 = nscalep0 + i[1] * rdim[2] * rdim[3];
 266 |     0 |         float* const nbiasp1 = nbiasp0 + i[1] * rdim[2] * rdim[3];
 267 |     0 |         for (i[2] = 0; i[2] < rdim[2]; i[2]++)
 268 |     0 |         {
 269 |     0 |           float* const meanp2 = meanp1 + i[2] * mean_stride[2];
 270 |     0 |           float* const varp2 = varp1 + i[2] * var_stride[2];
 271 |     0 |           float* const scalep2 = scalep1 + i[2] * scale_stride[2];
 272 |     0 |           float* const biasp2 = biasp1 + i[2] * bias_stride[2];
 273 |     0 |           float* const nscalep2 = nscalep1 + i[2] * rdim[3];
 274 |     0 |           float* const nbiasp2 = nbiasp1 + i[2] * rdim[3];
 275 |     0 |           for (x = 0; x < rdim[3]; x++)
 276 |     0 |           {
 277 |     0 |             const float w = scalep2[x] / (sqrtf(varp2[x]) + epsilon);
 278 |     0 |             nscalep2[x] = w;
 279 |     0 |             nbiasp2[x] = biasp2[x] - meanp2[x] * w;
 280 |     0 |           }
 281 |     0 |         }
 282 |     0 |       }
 283 |     0 |     }
 284 |     0 |     float* const ap = a->data.f32;
 285 |     0 |     float* const bp = b->data.f32;
 286 |     0 |     for (i[0] = 0; i[0] < adim[0]; i[0]++)
 287 |     0 |     {
 288 |     0 |       float* const ap0 = ap + i[0] * astride[0];
 289 |     0 |       float* const bp0 = bp + i[0] * bstride[0];
 290 |     0 |       float* const nscalep0 = rdim[0] == 1 ? nscalep : nscalep + i[0] * rdim[1] * rdim[2] * rdim[3];
 291 |     0 |       float* const nbiasp0 = rdim[0] == 1 ? nbiasp : nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3];
 292 |     0 |       for (i[1] = 0; i[1] < adim[1]; i[1]++)
 293 |     0 |       {
 294 |     0 |         float* ap1 = ap0 + i[1] * astride[1];
 295 |     0 |         float* bp1 = bp0 + i[1] * bstride[1];
 296 |     0 |         float* const nscalep1 = rdim[1] == 1 ? nscalep0 : nscalep0 + i[1] * rdim[2] * rdim[3];
 297 |     0 |         float* const nbiasp1 = rdim[1] == 1 ? nbiasp0 : nbiasp0 + i[1] * rdim[2] * rdim[3];
 298 |     0 |         for (i[2] = 0; i[2] < adim[2]; i[2]++)
 299 |     0 |         {
 300 |     0 |           float* const nscalep2 = rdim[2] == 1 ? nscalep1 : nscalep1 + i[2] * rdim[3];
 301 |     0 |           float* const nbiasp2 = rdim[2] == 1 ? nbiasp1 : nbiasp1 + i[2] * rdim[3];
 302 |     0 |           if (rdim[3] == 1)
 303 |     0 |             for (x = 0; x < adim[3]; x++)
 304 |     0 |               bp1[x] = ap1[x] * nscalep2[0] + nbiasp2[0];
 305 |     0 |           else
 306 |     0 |             for (x = 0; x < adim[3]; x++)
 307 |     0 |               bp1[x] = ap1[x] * nscalep2[x] + nbiasp2[x];
 308 |     0 |           ap1 += astride[2];
 309 |     0 |           bp1 += bstride[2];
 310 |     0 |         }
 311 |     0 |       }
 312 |     0 |     }
 313 |     0 |   }
 314 |    24 |   return CCV_NNC_EXEC_SUCCESS;
 315 |    24 | }
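
Notes on the forward kernel. All 24 recorded executions take the training path (is_test == 0) with workspace allocation allowed: both the CCV_NNC_ZERO_MEMORY_ALLOC fallback (lines 126-161) and the inference branch (lines 234-313) show zero counts. The zero counts on the rdim[3] == 1 fast paths and on the non-broadcast arms of the ternaries suggest the tests only exercise per-channel statistics, i.e. rdim = {1, 1, 1, C} with C > 1. One observation worth flagging: the covered training path computes the inverse standard deviation as 1. / sqrtf(var + epsilon) (line 118), while the uncovered inference branch uses scalep2[x] / (sqrtf(varp2[x]) + epsilon) (line 277), which places epsilon outside the square root; the two expressions are not numerically identical.

The optimization on the hot path (lines 163-199) folds normalization into one affine map per channel: y = (x - mean) * inv_std * scale + bias becomes y = x * nscale + nbias with nscale = inv_std * scale and nbias = bias - mean * nscale. A minimal, self-contained sketch of that folding for a contiguous NHWC float tensor follows; all names are hypothetical, and the real kernel above additionally handles strided views and broadcast dimensions:

#include <math.h>

/* Sketch only: fold batch norm into y = x * nscale + nbias for a contiguous
 * NHWC tensor with per-channel statistics. nscale / nbias are caller-provided
 * scratch buffers of size c, mirroring the workspace at line 169 above. */
static void batch_norm_fold_nhwc(const float* x, float* y, int n, int h, int w, int c,
	const float* mean, const float* var, const float* scale, const float* bias,
	float epsilon, float* nscale, float* nbias)
{
	int k, ch;
	for (ch = 0; ch < c; ch++)
	{
		/* Same inv_std convention as line 118: epsilon inside the sqrt. */
		const float s = scale[ch] / sqrtf(var[ch] + epsilon);
		nscale[ch] = s;
		nbias[ch] = bias[ch] - mean[ch] * s;
	}
	const int count = n * h * w; /* every spatial position shares the per-channel map */
	for (k = 0; k < count; k++)
		for (ch = 0; ch < c; ch++)
			y[k * c + ch] = x[k * c + ch] * nscale[ch] + nbias[ch];
}

As a worked example for a single channel with x = {1, 2, 3, 4}: mean = 2.5, variance = 1.25, inv_std ≈ 0.894, so with scale = 1 and bias = 0 the outputs are roughly {-1.342, -0.447, 0.447, 1.342}.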
 316 |       |
 317 |       | static int _ccv_nnc_batch_norm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
 318 |     3 | {
 319 |     3 |   assert(input_size == 15);
 320 |     3 |   assert(output_size >= 3);
 321 |     3 |   ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
 322 |     3 |   ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[5];
 323 |     3 |   ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[6];
 324 |     3 |   ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)inputs[13];
 325 |     3 |   ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[14];
 326 |     3 |   ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0];
 327 |     3 |   ccv_nnc_tensor_view_t* const dscale = (ccv_nnc_tensor_view_t*)outputs[1];
 328 |     3 |   ccv_nnc_tensor_view_t* const dbias = (ccv_nnc_tensor_view_t*)outputs[2];
 329 |     3 |   assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2);
 330 |     3 |   assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
 331 |     3 |   assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2);
 332 |       |   // Assuming this is float 32.
 333 |     3 |   int gdim[CCV_NNC_MAX_DIM_ALLOC];
 334 |     3 |   int rdim[CCV_NNC_MAX_DIM_ALLOC];
 335 |     3 |   ccv_nnc_tensor_view_get_dim(g, gdim);
 336 |     3 |   ccv_nnc_tensor_view_get_dim(scale, rdim);
 337 |     3 |   assert(ccv_nnc_tensor_view_check_dim(saved_mean, rdim));
 338 |     3 |   assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
 339 |     3 |   assert(ccv_nnc_tensor_view_check_dim(dscale, rdim));
 340 |     3 |   assert(ccv_nnc_tensor_view_check_dim(dbias, rdim));
 341 |     3 |   assert(ccv_nnc_tensor_view_check_dim(a, gdim));
 342 |     3 |   assert(ccv_nnc_tensor_view_check_dim(h, gdim));
 343 |     3 |   _ccv_nnc_reduce_sum_forw_cpu_ref(g, dbias);
 344 |     3 |   int astride[CCV_NNC_MAX_DIM_ALLOC];
 345 |     3 |   int gstride[CCV_NNC_MAX_DIM_ALLOC];
 346 |     3 |   int hstride[CCV_NNC_MAX_DIM_ALLOC];
 347 |     3 |   int mean_stride[CCV_NNC_MAX_DIM_ALLOC];
 348 |     3 |   int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
 349 |     3 |   int dscale_stride[CCV_NNC_MAX_DIM_ALLOC];
 350 |     3 |   int dbias_stride[CCV_NNC_MAX_DIM_ALLOC];
 351 |     3 |   ccv_nnc_tensor_view_get_stride(a, astride);
 352 |     3 |   ccv_nnc_tensor_view_get_stride(g, gstride);
 353 |     3 |   ccv_nnc_tensor_view_get_stride(h, hstride);
 354 |     3 |   ccv_nnc_tensor_view_get_stride(saved_mean, mean_stride);
 355 |     3 |   ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride);
 356 |     3 |   ccv_nnc_tensor_view_get_stride(dscale, dscale_stride);
 357 |     3 |   ccv_nnc_tensor_view_get_stride(dbias, dbias_stride);
 358 |       |   // Need to allocate two additional memory:
 359 |       |   // 1. normalized a;
 360 |       |   // 2. scale * inv_std / batch_size;
 361 |     3 |   assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC));
 362 |     3 |   int x;
 363 |     3 |   int batch_size = 1;
 364 |    15 |   for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
 365 |    12 |     batch_size *= gdim[x];
 366 |    15 |   for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
 367 |    12 |     batch_size /= rdim[x];
 368 |     3 |   int gcount = 1, rcount = 1;
 369 |    15 |   for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
 370 |    12 |     gcount *= gdim[x], rcount *= rdim[x];
 371 |     3 |   float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount + sizeof(float) * rcount, CCV_TENSOR_CPU_MEMORY);
 372 |     3 |   float* const sisb = ah + gcount;
 373 |     3 |   ccv_nnc_tensor_t sisbt = ccv_nnc_tensor(sisb, scale->info, 0);
 374 |     3 |   _ccv_nnc_mul_forw_cpu_ref(1. / batch_size, scale, saved_inv_std, (ccv_nnc_tensor_view_t*)&sisbt);
 375 |     3 |   int i[CCV_NNC_MAX_DIM + 2];
 376 |     3 |   float* const ap = a->data.f32;
 377 |     3 |   float* ahp = ah;
 378 |     3 |   float* const meanp = saved_mean->data.f32;
 379 |     3 |   float* const inv_stdp = saved_inv_std->data.f32;
 380 |     9 |   for (i[0] = 0; i[0] < gdim[0]; i[0]++)
 381 |     6 |   {
 382 |     6 |     float* const ap0 = ap + i[0] * astride[0];
 383 |     6 |     float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * mean_stride[0];
 384 |     6 |     float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
 385 |    18 |     for (i[1] = 0; i[1] < gdim[1]; i[1]++)
 386 |    12 |     {
 387 |    12 |       float* ap1 = ap0 + i[1] * astride[1];
 388 |    12 |       float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * mean_stride[1];
 389 |    12 |       float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
 390 |    36 |       for (i[2] = 0; i[2] < gdim[2]; i[2]++)
 391 |    24 |       {
 392 |    24 |         float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * mean_stride[2];
 393 |    24 |         float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
 394 |    24 |         if (rdim[3] == 1)
 395 |     0 |           for (x = 0; x < gdim[3]; x++)
 396 |     0 |             ahp[x] = (ap1[x] - meanp2[0]) * inv_stdp2[0];
 397 |    24 |         else
 398 |   264 |           for (x = 0; x < gdim[3]; x++)
 399 |   240 |             ahp[x] = (ap1[x] - meanp2[x]) * inv_stdp2[x];
 400 |    24 |         ap1 += astride[2];
 401 |    24 |         ahp += gdim[3];
 402 |    24 |       }
 403 |    12 |     }
 404 |     6 |   }
 405 |     3 |   ccv_nnc_tensor_zero(dscale);
 406 |     3 |   ahp = ah;
 407 |     3 |   float* const gp = g->data.f32;
 408 |     3 |   float* const dscalep = dscale->data.f32;
 409 |     9 |   for (i[0] = 0; i[0] < gdim[0]; i[0]++)
 410 |     6 |   {
 411 |     6 |     float* const gp0 = gp + i[0] * gstride[0];
 412 |     6 |     float* const dscalep0 = rdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0];
 413 |    18 |     for (i[1] = 0; i[1] < gdim[1]; i[1]++)
 414 |    12 |     {
 415 |    12 |       float* gp1 = gp0 + i[1] * gstride[1];
 416 |    12 |       float* const dscalep1 = rdim[1] == 1 ? dscalep0 : dscalep0 + i[1] * dscale_stride[1];
 417 |    36 |       for (i[2] = 0; i[2] < gdim[2]; i[2]++)
 418 |    24 |       {
 419 |    24 |         float* const dscalep2 = rdim[2] == 1 ? dscalep1 : dscalep1 + i[2] * dscale_stride[2];
 420 |    24 |         if (rdim[3] == 1)
 421 |     0 |           for (x = 0; x < gdim[3]; x++)
 422 |     0 |             dscalep2[0] += ahp[x] * gp1[x];
 423 |    24 |         else
 424 |   264 |           for (x = 0; x < gdim[3]; x++)
 425 |   240 |             dscalep2[x] += ahp[x] * gp1[x];
 426 |    24 |         gp1 += gstride[2];
 427 |    24 |         ahp += gdim[3];
 428 |    24 |       }
 429 |    12 |     }
 430 |     6 |   }
 431 |       |   // Now the part to compute dx (h).
 432 |     3 |   float* const hp = h->data.f32;
 433 |     3 |   ahp = ah;
 434 |     3 |   float* const sisbp = sisb;
 435 |     3 |   float* const dbiasp = dbias->data.f32;
 436 |     9 |   for (i[0] = 0; i[0] < gdim[0]; i[0]++)
 437 |     6 |   {
 438 |     6 |     float* const gp0 = gp + i[0] * gstride[0];
 439 |     6 |     float* const hp0 = hp + i[0] * hstride[0];
 440 |     6 |     float* const sisbp0 = rdim[0] == 1 ? sisbp : sisbp + i[0] * rdim[1] * rdim[2] * rdim[3];
 441 |     6 |     float* const dscalep0 = rdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0];
 442 |     6 |     float* const dbiasp0 = rdim[0] == 1 ? dbiasp : dbiasp + i[0] * dbias_stride[0];
 443 |    18 |     for (i[1] = 0; i[1] < gdim[1]; i[1]++)
 444 |    12 |     {
 445 |    12 |       float* gp1 = gp0 + i[1] * gstride[1];
 446 |    12 |       float* hp1 = hp0 + i[1] * hstride[1];
 447 |    12 |       float* const sisbp1 = rdim[1] == 1 ? sisbp0 : sisbp0 + i[1] * rdim[2] * rdim[3];
 448 |    12 |       float* const dscalep1 = rdim[1] == 1 ? dscalep0 : dscalep0 + i[1] * dscale_stride[1];
 449 |    12 |       float* const dbiasp1 = rdim[1] == 1 ? dbiasp0 : dbiasp0 + i[1] * dbias_stride[1];
 450 |    36 |       for (i[2] = 0; i[2] < gdim[2]; i[2]++)
 451 |    24 |       {
 452 |    24 |         float* const sisbp2 = rdim[2] == 1 ? sisbp1 : sisbp1 + i[2] * rdim[3];
 453 |    24 |         float* const dscalep2 = rdim[2] == 1 ? dscalep1 : dscalep1 + i[2] * dscale_stride[2];
 454 |    24 |         float* const dbiasp2 = rdim[2] == 1 ? dbiasp1 : dbiasp1 + i[2] * dbias_stride[2];
 455 |    24 |         if (rdim[3] == 1)
 456 |     0 |           for (x = 0; x < gdim[3]; x++)
 457 |     0 |             hp1[x] = sisbp2[0] * (batch_size * gp1[x] - dbiasp2[0] - ahp[x] * dscalep2[0]);
 458 |    24 |         else
 459 |   264 |           for (x = 0; x < gdim[3]; x++)
 460 |   240 |             hp1[x] = sisbp2[x] * (batch_size * gp1[x] - dbiasp2[x] - ahp[x] * dscalep2[x]);
 461 |    24 |         gp1 += gstride[2];
 462 |    24 |         hp1 += hstride[2];
 463 |    24 |         ahp += gdim[3];
 464 |    24 |       }
 465 |    12 |     }
 466 |     6 |   }
 467 |     3 |   return CCV_NNC_EXEC_SUCCESS;
 468 |     3 | }
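
Notes on the backward kernel. It implements the standard batch-norm gradient: with N = batch_size and x-hat the normalized input (the ah workspace), it computes dbias = sum(g), dscale = sum(g * x-hat), and then dx = (scale * inv_std / N) * (N * g - dbias - x-hat * dscale), where the leading per-channel factor is the precomputed sisb buffer (line 374). The loop body counts (6 / 12 / 24 / 240 across 3 calls) are consistent with each call processing a 2 x 2 x 2 x 10 tensor. A condensed sketch of the same combination for one contiguous channel follows; names are hypothetical and stride / broadcast handling is omitted:

/* Sketch only: batch-norm backward for one channel of n elements.
 * g: incoming gradient; ahat: normalized input saved by the forward pass. */
static void batch_norm_back_1ch(const float* g, const float* ahat, float* dx,
	int n, float scale, float inv_std, float* dscale, float* dbias)
{
	int k;
	float db = 0, ds = 0;
	for (k = 0; k < n; k++)
	{
		db += g[k];           /* reduce g into dbias (line 343 above) */
		ds += g[k] * ahat[k]; /* reduce g * ahat into dscale (lines 409-430) */
	}
	const float sisb = scale * inv_std / n; /* matches the sisb workspace */
	for (k = 0; k < n; k++)
		dx[k] = sisb * (n * g[k] - db - ahat[k] * ds); /* lines 436-466 */
	*dscale = ds;
	*dbias = db;
}

As in the forward pass, the zero counts on the rdim[3] == 1 branches and on the broadcast ternary arms indicate the backward tests also use per-channel statistics only.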
 469 |       |
 470 |       | REGISTER_COMMAND_BACKEND(CCV_NNC_BATCH_NORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 471 |     1 | {
 472 |     1 |   registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
 473 |     1 |   registry->tensor_datatypes = CCV_32F;
 474 |     1 |   registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 475 |     1 |   registry->algorithms = 1;
 476 |     1 |   registry->exec = _ccv_nnc_batch_norm_forw;
 477 |     1 | }
 478 |       |
 479 |       | REGISTER_COMMAND_BACKEND(CCV_NNC_BATCH_NORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 480 |     1 | {
 481 |     1 |   registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
 482 |     1 |   registry->tensor_datatypes = CCV_32F;
 483 |     1 |   registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 484 |     1 |   registry->algorithms = 1;
 485 |     1 |   registry->exec = _ccv_nnc_batch_norm_back;
 486 |     1 | }
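
Notes on registration. Both blocks register the reference kernels for NHWC, NCHW, and CHWN layouts, 32-bit floats, and CPU memory, with a single algorithm each. For orientation, a hedged sketch of driving the forward kernel through the command API: the tensor arrangement follows the asserts in the listing (five inputs a, scale, bias, mean, var; five outputs b, running mean, running var, saved mean, saved inv_std, with the running statistics updated in place per lines 50-51), but the exact parameter order of the CMD_BATCH_NORM_FORWARD macro (epsilon, is_test, momentum, then the axes to reduce over) is an assumption, not confirmed by this report:

#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"

/* Hedged usage sketch; CMD_BATCH_NORM_FORWARD parameter order is assumed. */
static void run_batch_norm_forward(void)
{
	ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 2, 10), 0);
	ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 2, 2, 10), 0);
	ccv_nnc_tensor_t* const scale = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	ccv_nnc_tensor_t* const bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	ccv_nnc_tensor_t* const mean = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	ccv_nnc_tensor_t* const var = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	ccv_nnc_tensor_t* const saved_mean = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	ccv_nnc_tensor_t* const saved_inv_std = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 10), 0);
	/* ... initialize a, scale, bias; zero mean and var ... */
	/* Running mean / var appear in both lists because they are updated
	 * in place, matching the asserts at lines 50-51 of the listing. */
	ccv_nnc_cmd_exec(CMD_BATCH_NORM_FORWARD(1e-4, 0, 0.9, 0, 1, 2), ccv_nnc_no_hint, 0,
		TENSOR_LIST(a, scale, bias, mean, var),
		TENSOR_LIST(b, mean, var, saved_mean, saved_inv_std), 0);
	ccv_nnc_tensor_free(a); ccv_nnc_tensor_free(b); ccv_nnc_tensor_free(scale);
	ccv_nnc_tensor_free(bias); ccv_nnc_tensor_free(mean); ccv_nnc_tensor_free(var);
	ccv_nnc_tensor_free(saved_mean); ccv_nnc_tensor_free(saved_inv_std);
}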