/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_batch_norm_cpu_ref.c
Line | Count | Source
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #ifdef USE_OPENMP |
7 | | #include <omp.h> |
8 | | #endif |
9 | | #ifdef USE_DISPATCH |
10 | | #include <dispatch/dispatch.h> |
11 | | #endif |
12 | | |
13 | | // Shared methods. |
14 | | #include "../_ccv_nnc_cpu_ref.h" |
15 | | |
16 | | static int _ccv_nnc_batch_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
17 | 24 | { |
18 | 24 | assert(input_size == 5); |
19 | 24 | ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0]; |
20 | 24 | ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[1]; |
21 | 24 | ccv_nnc_tensor_view_t* const bias = (ccv_nnc_tensor_view_t*)inputs[2]; |
22 | 24 | ccv_nnc_tensor_view_t* const mean = (ccv_nnc_tensor_view_t*)inputs[3]; |
23 | 24 | ccv_nnc_tensor_view_t* const var = (ccv_nnc_tensor_view_t*)inputs[4]; |
24 | 24 | ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0]; |
25 | 24 | assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2); |
26 | 24 | assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2); |
27 | | // Assuming this is float 32. |
28 | 24 | int adim[CCV_NNC_MAX_DIM_ALLOC]; |
29 | 24 | int rdim[CCV_NNC_MAX_DIM_ALLOC]; |
30 | 24 | ccv_nnc_tensor_view_get_dim(a, adim); |
31 | 24 | ccv_nnc_tensor_view_get_dim(scale, rdim); |
32 | 24 | assert(ccv_nnc_tensor_view_check_dim(bias, rdim)); |
33 | 24 | assert(ccv_nnc_tensor_view_check_dim(mean, rdim)); |
34 | 24 | assert(ccv_nnc_tensor_view_check_dim(var, rdim)); |
35 | 24 | assert(ccv_nnc_tensor_view_check_dim(b, adim)); |
36 | 24 | assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number. |
37 | 24 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
38 | 24 | int bstride[CCV_NNC_MAX_DIM_ALLOC]; |
39 | 24 | int scale_stride[CCV_NNC_MAX_DIM_ALLOC]; |
40 | 24 | int bias_stride[CCV_NNC_MAX_DIM_ALLOC]; |
41 | 24 | ccv_nnc_tensor_view_get_stride(a, astride); |
42 | 24 | ccv_nnc_tensor_view_get_stride(scale, scale_stride); |
43 | 24 | ccv_nnc_tensor_view_get_stride(bias, bias_stride); |
44 | 24 | ccv_nnc_tensor_view_get_stride(b, bstride); |
45 | 24 | const float epsilon = cmd.info.bnorm.epsilon; |
46 | 24 | if (!cmd.info.bnorm.is_test) |
47 | 24 | { |
48 | 24 | assert(output_size == 5); |
49 | | // Both are inplace. |
50 | 24 | assert(inputs[3]->data.f32 == outputs[1]->data.f32); |
51 | 24 | assert(inputs[4]->data.f32 == outputs[2]->data.f32); |
52 | 24 | ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)outputs[3]; |
53 | 24 | ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[4]; |
54 | 24 | assert(ccv_nnc_tensor_view_check_dim(saved_mean, rdim)); |
55 | 24 | assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim)); |
56 | 24 | int saved_mean_stride[CCV_NNC_MAX_DIM_ALLOC]; |
57 | 24 | int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC]; |
58 | 24 | ccv_nnc_tensor_view_get_stride(saved_mean, saved_mean_stride); |
59 | 24 | ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride); |
60 | 24 | int i[CCV_NNC_MAX_DIM + 2]; |
61 | 24 | int x; |
62 | 24 | int batch_size = 1; |
63 | 120 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
64 | 96 | batch_size *= adim[x];
65 | 120 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
66 | 96 | batch_size /= rdim[x];
67 | 24 | const float inv_batch_size = 1. / batch_size; |
68 | 24 | _ccv_nnc_reduce_sum_forw_cpu_ref(a, saved_mean); |
69 | 24 | _ccv_nnc_mul_forw_cpu_ref(inv_batch_size, saved_mean, 0, saved_mean); |
70 | | // Copy this into running mean / var. |
71 | 24 | _ccv_nnc_add_forw_cpu_ref(cmd.info.bnorm.momentum, 1. - cmd.info.bnorm.momentum, mean, saved_mean, mean); |
72 | 24 | ccv_nnc_tensor_zero(saved_inv_std); |
73 | 24 | float* const ap = a->data.f32; |
74 | 24 | float* const meanp = saved_mean->data.f32; |
75 | 24 | float* const varp = saved_inv_std->data.f32; |
76 | 174 | for (i[0] = 0; i[0] < adim[0]; i[0]++)
77 | 150 | { |
78 | 150 | float* const ap0 = ap + i[0] * astride[0]; |
79 | 150 | float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * saved_mean_stride[0];
80 | 150 | float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
81 | 722 | for (i[1] = 0; i[1] < adim[1]; i[1]++)
82 | 572 | { |
83 | 572 | float* ap1 = ap0 + i[1] * astride[1]; |
84 | 572 | float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * saved_mean_stride[1];
85 | 572 | float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
86 | 2.80k | for (i[2] = 0; i[2] < adim[2]; i[2]++)
87 | 2.23k | {
88 | 2.23k | float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * saved_mean_stride[2];
89 | 2.23k | float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
90 | 2.23k | if (rdim[3] == 1) |
91 | 0 | for (x = 0; x < adim[3]; x++) |
92 | 0 | { |
93 | 0 | float w = ap1[x] - meanp2[0]; |
94 | 0 | varp2[0] += w * w; |
95 | 0 | } |
96 | 2.23k | else |
97 | 24.5k | for (x = 0; x < adim[3]; x++)
98 | 22.3k | { |
99 | 22.3k | float w = ap1[x] - meanp2[x]; |
100 | 22.3k | varp2[x] += w * w; |
101 | 22.3k | } |
102 | 2.23k | ap1 += astride[2]; |
103 | 2.23k | } |
104 | 572 | } |
105 | 150 | } |
106 | 24 | _ccv_nnc_mul_forw_cpu_ref(inv_batch_size, saved_inv_std, 0, saved_inv_std); |
107 | 24 | _ccv_nnc_add_forw_cpu_ref(cmd.info.bnorm.momentum, 1. - cmd.info.bnorm.momentum, var, saved_inv_std, var); |
108 | 48 | for (i[0] = 0; i[0] < rdim[0]; i[0]++)
109 | 24 | { |
110 | 24 | float* const varp0 = varp + i[0] * saved_inv_std_stride[0]; |
111 | 48 | for (i[1] = 0; i[1] < rdim[1]; i[1]++)
112 | 24 | { |
113 | 24 | float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1]; |
114 | 48 | for (i[2] = 0; i[2] < rdim[2]; i[2]++)
115 | 24 | { |
116 | 24 | float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2]; |
117 | 264 | for (x = 0; x < rdim[3]; x++)
118 | 240 | varp2[x] = 1. / sqrtf(varp2[x] + epsilon); |
119 | 24 | } |
120 | 24 | } |
121 | 24 | } |
122 | 24 | float* const scalep = scale->data.f32; |
123 | 24 | float* const biasp = bias->data.f32; |
124 | | // Now that mean and inv_std are computed, normalize, scale and shift a into b. |
125 | 24 | if (flags & CCV_NNC_ZERO_MEMORY_ALLOC) |
126 | 0 | { |
127 | | // Do the straightforward one, y = (x - mean) * inv_std * scale + bias; we cannot allocate extra memory to help. |
128 | 0 | float* const bp = b->data.f32; |
129 | 0 | for (i[0] = 0; i[0] < adim[0]; i[0]++) |
130 | 0 | { |
131 | 0 | float* const ap0 = ap + i[0] * astride[0]; |
132 | 0 | float* const bp0 = bp + i[0] * bstride[0]; |
133 | 0 | float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * saved_mean_stride[0]; |
134 | 0 | float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0]; |
135 | 0 | float* const scalep0 = rdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0]; |
136 | 0 | float* const biasp0 = rdim[0] == 1 ? biasp : biasp + i[0] * bias_stride[0]; |
137 | 0 | for (i[1] = 0; i[1] < adim[1]; i[1]++) |
138 | 0 | { |
139 | 0 | float* ap1 = ap0 + i[1] * astride[1]; |
140 | 0 | float* bp1 = bp0 + i[1] * bstride[1]; |
141 | 0 | float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * saved_mean_stride[1]; |
142 | 0 | float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1]; |
143 | 0 | float* const scalep1 = rdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1]; |
144 | 0 | float* const biasp1 = rdim[1] == 1 ? biasp0 : biasp0 + i[1] * bias_stride[1]; |
145 | 0 | for (i[2] = 0; i[2] < adim[2]; i[2]++) |
146 | 0 | { |
147 | 0 | float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * saved_mean_stride[2]; |
148 | 0 | float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2]; |
149 | 0 | float* const scalep2 = rdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2]; |
150 | 0 | float* const biasp2 = rdim[2] == 1 ? biasp1 : biasp1 + i[2] * bias_stride[2]; |
151 | 0 | if (rdim[3] == 1) |
152 | 0 | for (x = 0; x < adim[3]; x++) |
153 | 0 | bp1[x] = (ap1[x] - meanp2[0]) * varp2[0] * scalep2[0] + biasp2[0]; |
154 | 0 | else |
155 | 0 | for (x = 0; x < adim[3]; x++) |
156 | 0 | bp1[x] = (ap1[x] - meanp2[x]) * varp2[x] * scalep2[x] + biasp2[x]; |
157 | 0 | ap1 += astride[2]; |
158 | 0 | bp1 += bstride[2]; |
159 | 0 | } |
160 | 0 | } |
161 | 0 | } |
162 | 24 | } else { |
163 | | // If we allocate extra memory, we can convert y = (x - mean) * inv_std * scale + bias |
164 | | // to y = x * inv_std * scale + (bias - mean * inv_std * scale), |
165 | | // so we can pre-compute nscale = inv_std * scale and nbias = bias - mean * inv_std * scale. |
166 | 24 | int count = 1; |
167 | 120 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
168 | 96 | count *= rdim[x]; |
169 | 24 | float* const nscalep = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * count * 2, CCV_TENSOR_CPU_MEMORY); |
170 | 24 | float* const nbiasp = nscalep + count; |
171 | 48 | for (i[0] = 0; i[0] < rdim[0]; i[0]++)
172 | 24 | { |
173 | 24 | float* const meanp0 = meanp + i[0] * saved_mean_stride[0]; |
174 | 24 | float* const varp0 = varp + i[0] * saved_inv_std_stride[0]; |
175 | 24 | float* const scalep0 = scalep + i[0] * scale_stride[0]; |
176 | 24 | float* const biasp0 = biasp + i[0] * bias_stride[0]; |
177 | 24 | float* const nscalep0 = nscalep + i[0] * rdim[1] * rdim[2] * rdim[3]; |
178 | 24 | float* const nbiasp0 = nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3]; |
179 | 48 | for (i[1] = 0; i[1] < rdim[1]; i[1]++)
180 | 24 | { |
181 | 24 | float* const meanp1 = meanp0 + i[1] * saved_mean_stride[1]; |
182 | 24 | float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1]; |
183 | 24 | float* const scalep1 = scalep0 + i[1] * scale_stride[1]; |
184 | 24 | float* const biasp1 = biasp0 + i[1] * bias_stride[1]; |
185 | 24 | float* const nscalep1 = nscalep0 + i[1] * rdim[2] * rdim[3]; |
186 | 24 | float* const nbiasp1 = nbiasp0 + i[1] * rdim[2] * rdim[3]; |
187 | 48 | for (i[2] = 0; i[2] < rdim[2]; i[2]++)
188 | 24 | { |
189 | 24 | float* const meanp2 = meanp1 + i[2] * saved_mean_stride[2]; |
190 | 24 | float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2]; |
191 | 24 | float* const scalep2 = scalep1 + i[2] * scale_stride[2]; |
192 | 24 | float* const biasp2 = biasp1 + i[2] * bias_stride[2]; |
193 | 24 | float* const nscalep2 = nscalep1 + i[2] * rdim[3]; |
194 | 24 | float* const nbiasp2 = nbiasp1 + i[2] * rdim[3]; |
195 | 264 | for (x = 0; x < rdim[3]; x++)
196 | 240 | { |
197 | 240 | const float w = varp2[x] * scalep2[x]; |
198 | 240 | nscalep2[x] = w; |
199 | 240 | nbiasp2[x] = biasp2[x] - meanp2[x] * w; |
200 | 240 | } |
201 | 24 | } |
202 | 24 | } |
203 | 24 | } |
204 | 24 | float* const bp = b->data.f32; |
205 | 174 | for (i[0] = 0; i[0] < adim[0]; i[0]++)
206 | 150 | { |
207 | 150 | float* const ap0 = ap + i[0] * astride[0]; |
208 | 150 | float* const bp0 = bp + i[0] * bstride[0]; |
209 | 150 | float* const nscalep0 = rdim[0] == 1 ? nscalep : nscalep + i[0] * rdim[1] * rdim[2] * rdim[3];
210 | 150 | float* const nbiasp0 = rdim[0] == 1 ? nbiasp : nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3];
211 | 722 | for (i[1] = 0; i[1] < adim[1]; i[1]++)
212 | 572 | { |
213 | 572 | float* ap1 = ap0 + i[1] * astride[1]; |
214 | 572 | float* bp1 = bp0 + i[1] * bstride[1]; |
215 | 572 | float* const nscalep1 = rdim[1] == 1 ? nscalep0 : nscalep0 + i[1] * rdim[2] * rdim[3];
216 | 572 | float* const nbiasp1 = rdim[1] == 1 ? nbiasp0 : nbiasp0 + i[1] * rdim[2] * rdim[3];
217 | 2.80k | for (i[2] = 0; i[2] < adim[2]; i[2]++)
218 | 2.23k | {
219 | 2.23k | float* const nscalep2 = rdim[2] == 1 ? nscalep1 : nscalep1 + i[2] * rdim[3];
220 | 2.23k | float* const nbiasp2 = rdim[2] == 1 ? nbiasp1 : nbiasp1 + i[2] * rdim[3];
221 | 2.23k | if (rdim[3] == 1) |
222 | 0 | for (x = 0; x < adim[3]; x++) |
223 | 0 | bp1[x] = ap1[x] * nscalep2[0] + nbiasp2[0]; |
224 | 2.23k | else |
225 | 24.5k | for (x = 0; x < adim[3]; x++)
226 | 22.3k | bp1[x] = ap1[x] * nscalep2[x] + nbiasp2[x]; |
227 | 2.23k | ap1 += astride[2]; |
228 | 2.23k | bp1 += bstride[2]; |
229 | 2.23k | } |
230 | 572 | } |
231 | 150 | } |
232 | 24 | } |
233 | 24 | } else { |
234 | 0 | assert(output_size >= 1); |
235 | 0 | int mean_stride[CCV_NNC_MAX_DIM_ALLOC]; |
236 | 0 | int var_stride[CCV_NNC_MAX_DIM_ALLOC]; |
237 | 0 | ccv_nnc_tensor_view_get_stride(mean, mean_stride); |
238 | 0 | ccv_nnc_tensor_view_get_stride(var, var_stride); |
239 | 0 | int i[CCV_NNC_MAX_DIM + 2]; |
240 | 0 | int x; |
241 | 0 | assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC)); |
242 | 0 | int count = 1; |
243 | 0 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++) |
244 | 0 | count *= rdim[x]; |
245 | 0 | float* const meanp = mean->data.f32; |
246 | 0 | float* const varp = var->data.f32; |
247 | 0 | float* const scalep = scale->data.f32; |
248 | 0 | float* const biasp = bias->data.f32; |
249 | 0 | float* const nscalep = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * count * 2, CCV_TENSOR_CPU_MEMORY); |
250 | 0 | float* const nbiasp = nscalep + count; |
251 | 0 | for (i[0] = 0; i[0] < rdim[0]; i[0]++) |
252 | 0 | { |
253 | 0 | float* const meanp0 = meanp + i[0] * mean_stride[0]; |
254 | 0 | float* const varp0 = varp + i[0] * var_stride[0]; |
255 | 0 | float* const scalep0 = scalep + i[0] * scale_stride[0]; |
256 | 0 | float* const biasp0 = biasp + i[0] * bias_stride[0]; |
257 | 0 | float* const nscalep0 = nscalep + i[0] * rdim[1] * rdim[2] * rdim[3]; |
258 | 0 | float* const nbiasp0 = nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3]; |
259 | 0 | for (i[1] = 0; i[1] < rdim[1]; i[1]++) |
260 | 0 | { |
261 | 0 | float* const meanp1 = meanp0 + i[1] * mean_stride[1]; |
262 | 0 | float* const varp1 = varp0 + i[1] * var_stride[1]; |
263 | 0 | float* const scalep1 = scalep0 + i[1] * scale_stride[1]; |
264 | 0 | float* const biasp1 = biasp0 + i[1] * bias_stride[1]; |
265 | 0 | float* const nscalep1 = nscalep0 + i[1] * rdim[2] * rdim[3]; |
266 | 0 | float* const nbiasp1 = nbiasp0 + i[1] * rdim[2] * rdim[3]; |
267 | 0 | for (i[2] = 0; i[2] < rdim[2]; i[2]++) |
268 | 0 | { |
269 | 0 | float* const meanp2 = meanp1 + i[2] * mean_stride[2]; |
270 | 0 | float* const varp2 = varp1 + i[2] * var_stride[2]; |
271 | 0 | float* const scalep2 = scalep1 + i[2] * scale_stride[2]; |
272 | 0 | float* const biasp2 = biasp1 + i[2] * bias_stride[2]; |
273 | 0 | float* const nscalep2 = nscalep1 + i[2] * rdim[3]; |
274 | 0 | float* const nbiasp2 = nbiasp1 + i[2] * rdim[3]; |
275 | 0 | for (x = 0; x < rdim[3]; x++) |
276 | 0 | { |
277 | 0 | const float w = scalep2[x] / (sqrtf(varp2[x]) + epsilon); |
278 | 0 | nscalep2[x] = w; |
279 | 0 | nbiasp2[x] = biasp2[x] - meanp2[x] * w; |
280 | 0 | } |
281 | 0 | } |
282 | 0 | } |
283 | 0 | } |
284 | 0 | float* const ap = a->data.f32; |
285 | 0 | float* const bp = b->data.f32; |
286 | 0 | for (i[0] = 0; i[0] < adim[0]; i[0]++) |
287 | 0 | { |
288 | 0 | float* const ap0 = ap + i[0] * astride[0]; |
289 | 0 | float* const bp0 = bp + i[0] * bstride[0]; |
290 | 0 | float* const nscalep0 = rdim[0] == 1 ? nscalep : nscalep + i[0] * rdim[1] * rdim[2] * rdim[3]; |
291 | 0 | float* const nbiasp0 = rdim[0] == 1 ? nbiasp : nbiasp + i[0] * rdim[1] * rdim[2] * rdim[3]; |
292 | 0 | for (i[1] = 0; i[1] < adim[1]; i[1]++) |
293 | 0 | { |
294 | 0 | float* ap1 = ap0 + i[1] * astride[1]; |
295 | 0 | float* bp1 = bp0 + i[1] * bstride[1]; |
296 | 0 | float* const nscalep1 = rdim[1] == 1 ? nscalep0 : nscalep0 + i[1] * rdim[2] * rdim[3]; |
297 | 0 | float* const nbiasp1 = rdim[1] == 1 ? nbiasp0 : nbiasp0 + i[1] * rdim[2] * rdim[3]; |
298 | 0 | for (i[2] = 0; i[2] < adim[2]; i[2]++) |
299 | 0 | { |
300 | 0 | float* const nscalep2 = rdim[2] == 1 ? nscalep1 : nscalep1 + i[2] * rdim[3]; |
301 | 0 | float* const nbiasp2 = rdim[2] == 1 ? nbiasp1 : nbiasp1 + i[2] * rdim[3]; |
302 | 0 | if (rdim[3] == 1) |
303 | 0 | for (x = 0; x < adim[3]; x++) |
304 | 0 | bp1[x] = ap1[x] * nscalep2[0] + nbiasp2[0]; |
305 | 0 | else |
306 | 0 | for (x = 0; x < adim[3]; x++) |
307 | 0 | bp1[x] = ap1[x] * nscalep2[x] + nbiasp2[x]; |
308 | 0 | ap1 += astride[2]; |
309 | 0 | bp1 += bstride[2]; |
310 | 0 | } |
311 | 0 | } |
312 | 0 | } |
313 | 0 | } |
314 | 24 | return CCV_NNC_EXEC_SUCCESS; |
315 | 24 | } |
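For reference, the training-mode forward path above implements the standard batch-norm formulas: mean and (biased) variance are reduced over every axis where scale has extent 1, the running statistics are blended as stat = momentum * stat + (1 - momentum) * batch_stat, inv_std = 1 / sqrt(var + epsilon), and the output is y = (x - mean) * inv_std * scale + bias, which the workspace branch folds into y = x * nscale + nbias with nscale = inv_std * scale and nbias = bias - mean * inv_std * scale. The sketch below restates that arithmetic on a plain contiguous NHWC float buffer; the fixed 2x2x3x4 shape and every name in it are illustrative assumptions, not part of the ccv API.

#include <math.h>
#include <stdio.h>

// Illustrative shape: N x H x W x C, with statistics reduced per channel over N, H, W.
#define BN_N 2
#define BN_H 2
#define BN_W 3
#define BN_C 4

static void batch_norm_forw_sketch(const float* const x, const float* const scale, const float* const bias, float* const run_mean, float* const run_var, float* const y, const float epsilon, const float momentum)
{
	const int m = BN_N * BN_H * BN_W; // per-channel batch size
	float mean[BN_C] = { 0 }, var[BN_C] = { 0 };
	int n, c;
	for (n = 0; n < m; n++)
		for (c = 0; c < BN_C; c++)
			mean[c] += x[n * BN_C + c];
	for (c = 0; c < BN_C; c++)
		mean[c] /= m;
	for (n = 0; n < m; n++)
		for (c = 0; c < BN_C; c++)
		{
			const float w = x[n * BN_C + c] - mean[c];
			var[c] += w * w; // sum of squared deviations
		}
	for (c = 0; c < BN_C; c++)
	{
		var[c] /= m; // biased variance, matching the 1 / batch_size scaling above
		run_mean[c] = momentum * run_mean[c] + (1 - momentum) * mean[c];
		run_var[c] = momentum * run_var[c] + (1 - momentum) * var[c];
		const float inv_std = 1.f / sqrtf(var[c] + epsilon);
		// Fold scale / bias so the per-element work is one multiply-add, as the workspace branch does.
		const float nscale = inv_std * scale[c];
		const float nbias = bias[c] - mean[c] * nscale;
		for (n = 0; n < m; n++)
			y[n * BN_C + c] = x[n * BN_C + c] * nscale + nbias;
	}
}

int main(void)
{
	float x[BN_N * BN_H * BN_W * BN_C], y[BN_N * BN_H * BN_W * BN_C];
	float scale[BN_C] = { 1, 1, 1, 1 }, bias[BN_C] = { 0, 0, 0, 0 };
	float run_mean[BN_C] = { 0 }, run_var[BN_C] = { 0 };
	int i;
	for (i = 0; i < BN_N * BN_H * BN_W * BN_C; i++)
		x[i] = 0.01f * i;
	batch_norm_forw_sketch(x, scale, bias, run_mean, run_var, y, 1e-5f, 0.9f);
	printf("y[0] = %f, running mean[0] = %f\n", y[0], run_mean[0]);
	return 0;
}

Pre-computing the nscale / nbias pair is what keeps the hot loop at one multiply-add per element; the CCV_NNC_ZERO_MEMORY_ALLOC branch in the kernel above produces the same values without the temporary workspace.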
316 | | |
317 | | static int _ccv_nnc_batch_norm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
318 | 3 | { |
319 | 3 | assert(input_size == 15); |
320 | 3 | assert(output_size >= 3); |
321 | 3 | ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0]; |
322 | 3 | ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[5]; |
323 | 3 | ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[6]; |
324 | 3 | ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)inputs[13]; |
325 | 3 | ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[14]; |
326 | 3 | ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0]; |
327 | 3 | ccv_nnc_tensor_view_t* const dscale = (ccv_nnc_tensor_view_t*)outputs[1]; |
328 | 3 | ccv_nnc_tensor_view_t* const dbias = (ccv_nnc_tensor_view_t*)outputs[2]; |
329 | 3 | assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2); |
330 | 3 | assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2); |
331 | 3 | assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2); |
332 | | // Assuming this is float 32. |
333 | 3 | int gdim[CCV_NNC_MAX_DIM_ALLOC]; |
334 | 3 | int rdim[CCV_NNC_MAX_DIM_ALLOC]; |
335 | 3 | ccv_nnc_tensor_view_get_dim(g, gdim); |
336 | 3 | ccv_nnc_tensor_view_get_dim(scale, rdim); |
337 | 3 | assert(ccv_nnc_tensor_view_check_dim(saved_mean, rdim)); |
338 | 3 | assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim)); |
339 | 3 | assert(ccv_nnc_tensor_view_check_dim(dscale, rdim)); |
340 | 3 | assert(ccv_nnc_tensor_view_check_dim(dbias, rdim)); |
341 | 3 | assert(ccv_nnc_tensor_view_check_dim(a, gdim)); |
342 | 3 | assert(ccv_nnc_tensor_view_check_dim(h, gdim)); |
343 | 3 | _ccv_nnc_reduce_sum_forw_cpu_ref(g, dbias); |
344 | 3 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
345 | 3 | int gstride[CCV_NNC_MAX_DIM_ALLOC]; |
346 | 3 | int hstride[CCV_NNC_MAX_DIM_ALLOC]; |
347 | 3 | int mean_stride[CCV_NNC_MAX_DIM_ALLOC]; |
348 | 3 | int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC]; |
349 | 3 | int dscale_stride[CCV_NNC_MAX_DIM_ALLOC]; |
350 | 3 | int dbias_stride[CCV_NNC_MAX_DIM_ALLOC]; |
351 | 3 | ccv_nnc_tensor_view_get_stride(a, astride); |
352 | 3 | ccv_nnc_tensor_view_get_stride(g, gstride); |
353 | 3 | ccv_nnc_tensor_view_get_stride(h, hstride); |
354 | 3 | ccv_nnc_tensor_view_get_stride(saved_mean, mean_stride); |
355 | 3 | ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride); |
356 | 3 | ccv_nnc_tensor_view_get_stride(dscale, dscale_stride); |
357 | 3 | ccv_nnc_tensor_view_get_stride(dbias, dbias_stride); |
358 | | // Need to allocate two additional buffers: |
359 | | // 1. normalized a; |
360 | | // 2. scale * inv_std / batch_size; |
361 | 3 | assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC)); |
362 | 3 | int x; |
363 | 3 | int batch_size = 1; |
364 | 15 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
365 | 12 | batch_size *= gdim[x];
366 | 15 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
367 | 12 | batch_size /= rdim[x];
368 | 3 | int gcount = 1, rcount = 1;
369 | 15 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
370 | 12 | gcount *= gdim[x], rcount *= rdim[x]; |
371 | 3 | float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount + sizeof(float) * rcount, CCV_TENSOR_CPU_MEMORY); |
372 | 3 | float* const sisb = ah + gcount; |
373 | 3 | ccv_nnc_tensor_t sisbt = ccv_nnc_tensor(sisb, scale->info, 0); |
374 | 3 | _ccv_nnc_mul_forw_cpu_ref(1. / batch_size, scale, saved_inv_std, (ccv_nnc_tensor_view_t*)&sisbt); |
375 | 3 | int i[CCV_NNC_MAX_DIM + 2]; |
376 | 3 | float* const ap = a->data.f32; |
377 | 3 | float* ahp = ah; |
378 | 3 | float* const meanp = saved_mean->data.f32; |
379 | 3 | float* const inv_stdp = saved_inv_std->data.f32; |
380 | 9 | for (i[0] = 0; i[0] < gdim[0]; i[0]++)
381 | 6 | { |
382 | 6 | float* const ap0 = ap + i[0] * astride[0]; |
383 | 6 | float* const meanp0 = rdim[0] == 1 ? meanp : meanp + i[0] * mean_stride[0];
384 | 6 | float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
385 | 18 | for (i[1] = 0; i[1] < gdim[1]; i[1]++)
386 | 12 | { |
387 | 12 | float* ap1 = ap0 + i[1] * astride[1]; |
388 | 12 | float* const meanp1 = rdim[1] == 1 ? meanp0 : meanp0 + i[1] * mean_stride[1];
389 | 12 | float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
390 | 36 | for (i[2] = 0; i[2] < gdim[2]; i[2]++)
391 | 24 | {
392 | 24 | float* const meanp2 = rdim[2] == 1 ? meanp1 : meanp1 + i[2] * mean_stride[2];
393 | 24 | float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
394 | 24 | if (rdim[3] == 1) |
395 | 0 | for (x = 0; x < gdim[3]; x++) |
396 | 0 | ahp[x] = (ap1[x] - meanp2[0]) * inv_stdp2[0]; |
397 | 24 | else |
398 | 264 | for (x = 0; x < gdim[3]; x++)
399 | 240 | ahp[x] = (ap1[x] - meanp2[x]) * inv_stdp2[x]; |
400 | 24 | ap1 += astride[2]; |
401 | 24 | ahp += gdim[3]; |
402 | 24 | } |
403 | 12 | } |
404 | 6 | } |
405 | 3 | ccv_nnc_tensor_zero(dscale); |
406 | 3 | ahp = ah; |
407 | 3 | float* const gp = g->data.f32; |
408 | 3 | float* const dscalep = dscale->data.f32; |
409 | 9 | for (i[0] = 0; i[0] < gdim[0]; i[0]++)
410 | 6 | { |
411 | 6 | float* const gp0 = gp + i[0] * gstride[0]; |
412 | 6 | float* const dscalep0 = rdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0];
413 | 18 | for (i[1] = 0; i[1] < gdim[1]; i[1]++)
414 | 12 | { |
415 | 12 | float* gp1 = gp0 + i[1] * gstride[1]; |
416 | 12 | float* const dscalep1 = rdim[1] == 1 ? dscalep0 : dscalep0 + i[1] * dscale_stride[1];
417 | 36 | for (i[2] = 0; i[2] < gdim[2]; i[2]++)
418 | 24 | { |
419 | 24 | float* const dscalep2 = rdim[2] == 1 ? dscalep1 : dscalep1 + i[2] * dscale_stride[2];
420 | 24 | if (rdim[3] == 1) |
421 | 0 | for (x = 0; x < gdim[3]; x++) |
422 | 0 | dscalep2[0] += ahp[x] * gp1[x]; |
423 | 24 | else |
424 | 264 | for (x = 0; x < gdim[3]; x++)
425 | 240 | dscalep2[x] += ahp[x] * gp1[x]; |
426 | 24 | gp1 += gstride[2]; |
427 | 24 | ahp += gdim[3]; |
428 | 24 | } |
429 | 12 | } |
430 | 6 | } |
431 | | // Now the part to compute dx (h). |
432 | 3 | float* const hp = h->data.f32; |
433 | 3 | ahp = ah; |
434 | 3 | float* const sisbp = sisb; |
435 | 3 | float* const dbiasp = dbias->data.f32; |
436 | 9 | for (i[0] = 0; i[0] < gdim[0]; i[0]++)
437 | 6 | { |
438 | 6 | float* const gp0 = gp + i[0] * gstride[0]; |
439 | 6 | float* const hp0 = hp + i[0] * hstride[0]; |
440 | 6 | float* const sisbp0 = rdim[0] == 1 ? sisbp : sisbp + i[0] * rdim[1] * rdim[2] * rdim[3];
441 | 6 | float* const dscalep0 = rdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0];
442 | 6 | float* const dbiasp0 = rdim[0] == 1 ? dbiasp : dbiasp + i[0] * dbias_stride[0];
443 | 18 | for (i[1] = 0; i[1] < gdim[1]; i[1]++)
444 | 12 | { |
445 | 12 | float* gp1 = gp0 + i[1] * gstride[1]; |
446 | 12 | float* hp1 = hp0 + i[1] * hstride[1]; |
447 | 12 | float* const sisbp1 = rdim[1] == 1 ? sisbp0 : sisbp0 + i[1] * rdim[2] * rdim[3];
448 | 12 | float* const dscalep1 = rdim[1] == 1 ? dscalep0 : dscalep0 + i[1] * dscale_stride[1];
449 | 12 | float* const dbiasp1 = rdim[1] == 1 ? dbiasp0 : dbiasp0 + i[1] * dbias_stride[1];
450 | 36 | for (i[2] = 0; i[2] < gdim[2]; i[2]++)
451 | 24 | { |
452 | 24 | float* const sisbp2 = rdim[2] == 1 ? sisbp1 : sisbp1 + i[2] * rdim[3];
453 | 24 | float* const dscalep2 = rdim[2] == 1 ? dscalep1 : dscalep1 + i[2] * dscale_stride[2];
454 | 24 | float* const dbiasp2 = rdim[2] == 1 ? dbiasp1 : dbiasp1 + i[2] * dbias_stride[2];
455 | 24 | if (rdim[3] == 1) |
456 | 0 | for (x = 0; x < gdim[3]; x++) |
457 | 0 | hp1[x] = sisbp2[0] * (batch_size * gp1[x] - dbiasp2[0] - ahp[x] * dscalep2[0]); |
458 | 24 | else |
459 | 264 | for (x = 0; x < gdim[3]; x++)
460 | 240 | hp1[x] = sisbp2[x] * (batch_size * gp1[x] - dbiasp2[x] - ahp[x] * dscalep2[x]); |
461 | 24 | gp1 += gstride[2]; |
462 | 24 | hp1 += hstride[2]; |
463 | 24 | ahp += gdim[3]; |
464 | 24 | } |
465 | 12 | } |
466 | 6 | } |
467 | 3 | return CCV_NNC_EXEC_SUCCESS; |
468 | 3 | } |
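The backward path above is the standard batch-norm gradient written per channel: with x_hat = (x - mean) * inv_std and m the per-channel batch size, dbias = sum(g), dscale = sum(g * x_hat), and dh = (scale * inv_std / m) * (m * g - dbias - x_hat * dscale), where scale * inv_std / m is exactly the precomputed sisb buffer. The single-channel sketch below restates that math, assuming the m samples of one channel have already been gathered into contiguous arrays; all names are illustrative placeholders, not ccv API.

// One channel of the batch-norm gradient; mean and inv_std are the values saved by the forward pass.
static void batch_norm_back_channel_sketch(const float* const g, const float* const x, const int m, const float mean, const float inv_std, const float scale, float* const dscale, float* const dbias, float* const dx)
{
	int n;
	float sum_g = 0, sum_g_xhat = 0;
	for (n = 0; n < m; n++)
	{
		const float xhat = (x[n] - mean) * inv_std;
		sum_g += g[n];
		sum_g_xhat += g[n] * xhat;
	}
	*dbias = sum_g;
	*dscale = sum_g_xhat;
	const float sisb = scale * inv_std / m; // same factor as the sisb workspace above
	for (n = 0; n < m; n++)
	{
		const float xhat = (x[n] - mean) * inv_std;
		dx[n] = sisb * (m * g[n] - sum_g - xhat * sum_g_xhat);
	}
}

In the kernel above the normalized input is staged in the ah workspace so that x_hat is computed once and reused for both the dscale reduction and the dx pass.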
469 | | |
470 | | REGISTER_COMMAND_BACKEND(CCV_NNC_BATCH_NORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
471 | 1 | { |
472 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN; |
473 | 1 | registry->tensor_datatypes = CCV_32F; |
474 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
475 | 1 | registry->algorithms = 1; |
476 | 1 | registry->exec = _ccv_nnc_batch_norm_forw; |
477 | 1 | } |
478 | | |
479 | | REGISTER_COMMAND_BACKEND(CCV_NNC_BATCH_NORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
480 | 1 | { |
481 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN; |
482 | 1 | registry->tensor_datatypes = CCV_32F; |
483 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
484 | 1 | registry->algorithms = 1; |
485 | 1 | registry->exec = _ccv_nnc_batch_norm_back; |
486 | 1 | } |