/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_group_norm_cpu_ref.c
#include "ccv.h"
#include "ccv_internal.h"
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"
#include "nnc/ccv_nnc_internal.h"
#ifdef USE_OPENMP
#include <omp.h>
#endif
#ifdef USE_DISPATCH
#include <dispatch/dispatch.h>
#endif

// Shared methods.
#include "../_ccv_nnc_cpu_ref.h"

static int _ccv_nnc_group_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 3 || input_size == 1);
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t*)inputs[1] : 0;
	ccv_nnc_tensor_view_t* const bias = input_size >= 3 ? (ccv_nnc_tensor_view_t*)inputs[2] : 0;
	ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)outputs[1];
	ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[2];
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int adim[CCV_NNC_MAX_DIM_ALLOC];
	int rdim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(a, adim);
	ccv_nnc_tensor_view_get_dim(saved_mean, rdim);
	assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
	assert(ccv_nnc_tensor_view_check_dim(b, adim));
	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int bstride[CCV_NNC_MAX_DIM_ALLOC];
	int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
	int bias_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(a, astride);
	if (scale)
		ccv_nnc_tensor_view_get_stride(scale, scale_stride);
	if (bias)
		ccv_nnc_tensor_view_get_stride(bias, bias_stride);
	ccv_nnc_tensor_view_get_stride(b, bstride);
	// The epsilon is added to the (biased) variance inside the sqrt when computing inv_std below.
	const float epsilon = cmd.info.lnorm.epsilon;
	int saved_mean_stride[CCV_NNC_MAX_DIM_ALLOC];
	int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(saved_mean, saved_mean_stride);
	ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride);
	int x;
	int n = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n *= adim[x];
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n /= rdim[x];
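	// n is the number of input elements that share one group statistic: the product
	// of the input dims divided by the product of the reduced dims. For example
	// (illustrative, not from this file): an NHWC input [N, H, W, C] reduced to
	// [N, 1, 1, G] statistics gives n = H * W * (C / G).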
	const float inv_n = 1. / n;
	int i[CCV_NNC_MAX_DIM + 2];
	float* const ap = a->data.f32;
	float* const meanp = saved_mean->data.f32;
	ccv_nnc_tensor_zero(saved_mean);
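	// Indexing idiom used throughout: when a reduced dim is smaller than the input
	// dim (rdim[d] < adim[d]), the input coordinate i is mapped onto the group
	// statistic it belongs to with the integer scaling i * rdim[d] / adim[d];
	// otherwise the coordinate is used as-is.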
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const ap0 = ap + i[0] * astride[0];
		float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* ap1 = ap0 + i[1] * astride[1];
			float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
				if (rdim[3] < adim[3])
					for (x = 0; x < adim[3]; x++)
						meanp2[x * rdim[3] / adim[3]] += ap1[x];
				else
					for (x = 0; x < adim[3]; x++)
						meanp2[x] += ap1[x];
				ap1 += astride[2];
			}
		}
	}
	for (i[0] = 0; i[0] < rdim[0]; i[0]++)
	{
		float* const meanp0 = meanp + i[0] * saved_mean_stride[0];
		for (i[1] = 0; i[1] < rdim[1]; i[1]++)
		{
			float* const meanp1 = meanp0 + i[1] * saved_mean_stride[1];
			for (i[2] = 0; i[2] < rdim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + i[2] * saved_mean_stride[2];
				for (x = 0; x < rdim[3]; x++)
					meanp2[x] = meanp2[x] * inv_n;
			}
		}
	}
	ccv_nnc_tensor_zero(saved_inv_std);
	float* const varp = saved_inv_std->data.f32;
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const ap0 = ap + i[0] * astride[0];
		float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
		float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* ap1 = ap0 + i[1] * astride[1];
			float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
			float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
				float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
				if (rdim[3] < adim[3])
					for (x = 0; x < adim[3]; x++)
					{
						float w = ap1[x] - meanp2[x * rdim[3] / adim[3]];
						varp2[x * rdim[3] / adim[3]] += w * w;
					}
				else
					for (x = 0; x < adim[3]; x++)
					{
						float w = ap1[x] - meanp2[x];
						varp2[x] += w * w;
					}
				ap1 += astride[2];
			}
		}
	}
	for (i[0] = 0; i[0] < rdim[0]; i[0]++)
	{
		float* const varp0 = varp + i[0] * saved_inv_std_stride[0];
		for (i[1] = 0; i[1] < rdim[1]; i[1]++)
		{
			float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1];
			for (i[2] = 0; i[2] < rdim[2]; i[2]++)
			{
				float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2];
				for (x = 0; x < rdim[3]; x++)
					varp2[x] = 1. / sqrtf(varp2[x] * inv_n + epsilon);
			}
		}
	}
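	// At this point saved_mean holds the per-group mean and saved_inv_std holds
	// 1 / sqrt(var + epsilon); both are command outputs so the backward pass can
	// reuse them without recomputing.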
	if (cmd.info.gnorm.elementwise_affine)
	{
		float* const scalep = scale->data.f32;
		float* const biasp = bias->data.f32;
		int sdim[CCV_NNC_MAX_DIM_ALLOC];
		ccv_nnc_tensor_view_get_dim(scale, sdim);
		int bias_dim[CCV_NNC_MAX_DIM_ALLOC];
		ccv_nnc_tensor_view_get_dim(bias, bias_dim);
		// Do the straight-forward one, y = (x - mean) * inv_std * scale + bias, we cannot allocate extra memory to help.
		// There is no need for precompute since scale / bias is per element.
		float* const bp = b->data.f32;
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const ap0 = ap + i[0] * astride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
			float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
			float* const scalep0 = scalep + (sdim[0] < adim[0] ? i[0] * sdim[0] / adim[0] : i[0]) * scale_stride[0];
			float* const biasp0 = biasp + (bias_dim[0] < adim[0] ? i[0] * bias_dim[0] / adim[0] : i[0]) * bias_stride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* ap1 = ap0 + i[1] * astride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
				float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
				float* const scalep1 = scalep0 + (sdim[1] < adim[1] ? i[1] * sdim[1] / adim[1] : i[1]) * scale_stride[1];
				float* const biasp1 = biasp0 + (bias_dim[1] < adim[1] ? i[1] * bias_dim[1] / adim[1] : i[1]) * bias_stride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
					float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
					float* const scalep2 = scalep1 + (sdim[2] < adim[2] ? i[2] * sdim[2] / adim[2] : i[2]) * scale_stride[2];
					float* const biasp2 = biasp1 + (bias_dim[2] < adim[2] ? i[2] * bias_dim[2] / adim[2] : i[2]) * bias_stride[2];
					if (rdim[3] < adim[3])
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x * rdim[3] / adim[3]]) * varp2[x * rdim[3] / adim[3]] * scalep2[x * sdim[3] / adim[3]] + biasp2[x * bias_dim[3] / adim[3]];
					else
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x]) * varp2[x] * scalep2[x * sdim[3] / adim[3]] + biasp2[x * bias_dim[3] / adim[3]];
					ap1 += astride[2];
					bp1 += bstride[2];
				}
			}
		}
	} else {
		// Do the straight-forward one, y = (x - mean) * inv_std, we cannot allocate extra memory to help.
		float* const bp = b->data.f32;
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const ap0 = ap + i[0] * astride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const meanp0 = meanp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_mean_stride[0];
			float* const varp0 = varp + (rdim[0] < adim[0] ? i[0] * rdim[0] / adim[0] : i[0]) * saved_inv_std_stride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* ap1 = ap0 + i[1] * astride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* const meanp1 = meanp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_mean_stride[1];
				float* const varp1 = varp0 + (rdim[1] < adim[1] ? i[1] * rdim[1] / adim[1] : i[1]) * saved_inv_std_stride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					float* const meanp2 = meanp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_mean_stride[2];
					float* const varp2 = varp1 + (rdim[2] < adim[2] ? i[2] * rdim[2] / adim[2] : i[2]) * saved_inv_std_stride[2];
					if (rdim[3] < adim[3])
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x * rdim[3] / adim[3]]) * varp2[x * rdim[3] / adim[3]];
					else
						for (x = 0; x < adim[3]; x++)
							bp1[x] = (ap1[x] - meanp2[x]) * varp2[x];
					ap1 += astride[2];
					bp1 += bstride[2];
				}
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
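
// A minimal, self-contained sketch of the arithmetic above for a single group of
// n values, with tensor views, strides and broadcasting stripped away. This
// helper is hypothetical (not part of ccv), for illustration only, and unused:
static void _group_norm_1d_sketch(const float* const x, float* const y, const int n, const float epsilon)
{
	int i;
	float mean = 0;
	for (i = 0; i < n; i++)
		mean += x[i]; // accumulate the group sum
	mean /= n;
	float var = 0;
	for (i = 0; i < n; i++)
		var += (x[i] - mean) * (x[i] - mean); // accumulate squared deviations
	const float inv_std = 1. / sqrtf(var / n + epsilon); // biased variance, epsilon inside the sqrt
	for (i = 0; i < n; i++)
		y[i] = (x[i] - mean) * inv_std; // normalize
}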

static int _ccv_nnc_group_norm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 9 || input_size == 7);
	assert(output_size >= 1);
	const int elementwise_affine = cmd.info.gnorm.elementwise_affine;
	ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[3];
	ccv_nnc_tensor_view_t* const scale = elementwise_affine ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
	ccv_nnc_tensor_view_t* const saved_mean = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 7 : 5];
	ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 8 : 6];
	ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
	ccv_nnc_tensor_view_t* const dbias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;
	assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int gdim[CCV_NNC_MAX_DIM_ALLOC];
	int rdim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(g, gdim);
	ccv_nnc_tensor_view_get_dim(saved_mean, rdim);
	assert(ccv_nnc_tensor_view_check_dim(saved_inv_std, rdim));
	int sdim[CCV_NNC_MAX_DIM_ALLOC];
	if (scale)
		ccv_nnc_tensor_view_get_dim(scale, sdim);
	if (dscale)
		{ assert(ccv_nnc_tensor_view_check_dim(dscale, sdim)); }
	assert(ccv_nnc_tensor_view_check_dim(a, gdim));
	assert(ccv_nnc_tensor_view_check_dim(h, gdim));
	if (dbias)
		_ccv_nnc_reduce_sum_forw_cpu_ref(g, dbias);
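	// dbias is just g summed over every axis along which the bias was broadcast,
	// so it can be delegated to the shared reduce-sum reference kernel.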
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int gstride[CCV_NNC_MAX_DIM_ALLOC];
	int hstride[CCV_NNC_MAX_DIM_ALLOC];
	int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
	int mean_stride[CCV_NNC_MAX_DIM_ALLOC];
	int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
	int dscale_stride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(a, astride);
	ccv_nnc_tensor_view_get_stride(g, gstride);
	ccv_nnc_tensor_view_get_stride(h, hstride);
	if (scale)
		ccv_nnc_tensor_view_get_stride(scale, scale_stride);
	ccv_nnc_tensor_view_get_stride(saved_mean, mean_stride);
	ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride);
	if (dscale)
		ccv_nnc_tensor_view_get_stride(dscale, dscale_stride);
	// Need to allocate additional memory for the intermediates:
	// 1. the normalized input (ah);
	// 2. g * scale * inv_std (gss), plus its reductions (gssr and ahgssr below).
	assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC));
	int x;
	int n = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n *= gdim[x];
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		n /= rdim[x];
	int gcount = 1, rcount = 1;
	for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
		gcount *= gdim[x], rcount *= rdim[x];
	float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount * 2 + sizeof(float) * rcount * 2, CCV_TENSOR_CPU_MEMORY);
	float* const gss = ah + gcount; // g * scale * inv_std
	float* const gssr = gss + gcount; // gss reduced to inv_std dimension
	float* const ahgssr = gssr + rcount; // ah * gss then reduced to inv_std dimension.
	int i[CCV_NNC_MAX_DIM + 2];
	float* ahp = ah;
	const float* const meanp = saved_mean->data.f32;
	const float* const inv_stdp = saved_inv_std->data.f32;
	const float* const ap = a->data.f32;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		const float* const ap0 = ap + i[0] * astride[0];
		const float* const meanp0 = meanp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * mean_stride[0];
		const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			const float* ap1 = ap0 + i[1] * astride[1];
			const float* const meanp1 = meanp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * mean_stride[1];
			const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				const float* const meanp2 = meanp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * mean_stride[2];
				const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						ahp[x] = (ap1[x] - meanp2[x * rdim[3] / gdim[3]]) * inv_stdp2[x * rdim[3] / gdim[3]];
				else
					for (x = 0; x < gdim[3]; x++)
						ahp[x] = (ap1[x] - meanp2[x]) * inv_stdp2[x];
				ap1 += astride[2];
				ahp += gdim[3];
			}
		}
	}
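	// ah now holds the normalized input (a - mean) * inv_std for every element,
	// laid out densely in gdim order.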
	if (dscale)
	{
		ccv_nnc_tensor_zero(dscale);
		ahp = ah;
		float* gssp = gss;
		const float* const gp = g->data.f32;
		const float* const scalep = scale->data.f32;
		float* const dscalep = dscale->data.f32;
		for (i[0] = 0; i[0] < gdim[0]; i[0]++)
		{
			const float* const gp0 = gp + i[0] * gstride[0];
			const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
			const float* const scalep0 = scalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * scale_stride[0];
			float* const dscalep0 = dscalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * dscale_stride[0];
			for (i[1] = 0; i[1] < gdim[1]; i[1]++)
			{
				const float* gp1 = gp0 + i[1] * gstride[1];
				const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
				const float* const scalep1 = scalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * scale_stride[1];
				float* const dscalep1 = dscalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * dscale_stride[1];
				for (i[2] = 0; i[2] < gdim[2]; i[2]++)
				{
					const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
					const float* const scalep2 = scalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * scale_stride[2];
					float* const dscalep2 = dscalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * dscale_stride[2];
					if (sdim[3] < gdim[3])
						for (x = 0; x < gdim[3]; x++)
						{
							gssp[x] = gp1[x] * scalep2[x * sdim[3] / gdim[3]] * inv_stdp2[x * rdim[3] / gdim[3]];
							dscalep2[x * sdim[3] / gdim[3]] += ahp[x] * gp1[x];
						}
					else
						for (x = 0; x < gdim[3]; x++)
						{
							gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[x * rdim[3] / gdim[3]];
							dscalep2[x] += ahp[x] * gp1[x];
						}
					gp1 += gstride[2];
					ahp += gdim[3];
					gssp += gdim[3];
				}
			}
		}
	} else {
		ahp = ah;
		float* gssp = gss;
		const float* const gp = g->data.f32;
		if (elementwise_affine)
		{
			const float* const scalep = scale->data.f32;
			for (i[0] = 0; i[0] < gdim[0]; i[0]++)
			{
				const float* const gp0 = gp + i[0] * gstride[0];
				const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
				const float* const scalep0 = scalep + (sdim[0] < gdim[0] ? i[0] * sdim[0] / gdim[0] : i[0]) * scale_stride[0];
				for (i[1] = 0; i[1] < gdim[1]; i[1]++)
				{
					const float* gp1 = gp0 + i[1] * gstride[1];
					const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
					const float* const scalep1 = scalep0 + (sdim[1] < gdim[1] ? i[1] * sdim[1] / gdim[1] : i[1]) * scale_stride[1];
					for (i[2] = 0; i[2] < gdim[2]; i[2]++)
					{
						const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
						const float* const scalep2 = scalep1 + (sdim[2] < gdim[2] ? i[2] * sdim[2] / gdim[2] : i[2]) * scale_stride[2];
						if (sdim[3] < gdim[3])
							for (x = 0; x < gdim[3]; x++)
								gssp[x] = gp1[x] * scalep2[x * sdim[3] / gdim[3]] * inv_stdp2[x * rdim[3] / gdim[3]];
						else
							for (x = 0; x < gdim[3]; x++)
								gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[x * rdim[3] / gdim[3]];
						gp1 += gstride[2];
						ahp += gdim[3];
						gssp += gdim[3];
					}
				}
			}
		} else {
			for (i[0] = 0; i[0] < gdim[0]; i[0]++)
			{
				const float* const gp0 = gp + i[0] * gstride[0];
				const float* const inv_stdp0 = inv_stdp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * inv_std_stride[0];
				for (i[1] = 0; i[1] < gdim[1]; i[1]++)
				{
					const float* gp1 = gp0 + i[1] * gstride[1];
					const float* const inv_stdp1 = inv_stdp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * inv_std_stride[1];
					for (i[2] = 0; i[2] < gdim[2]; i[2]++)
					{
						const float* const inv_stdp2 = inv_stdp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * inv_std_stride[2];
						for (x = 0; x < gdim[3]; x++)
							gssp[x] = gp1[x] * inv_stdp2[x * rdim[3] / gdim[3]];
						gp1 += gstride[2];
						ahp += gdim[3];
						gssp += gdim[3];
					}
				}
			}
		}
	}
	ccv_nnc_tensor_t gssrt = ccv_nnc_tensor(gssr, saved_mean->info, 0);
	ccv_nnc_tensor_zero(&gssrt);
	float* gssp = gss;
	float* const gssrp = gssr;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const gssrp0 = gssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* const gssrp1 = gssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				float* const gssrp2 = gssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						gssrp2[x * rdim[3] / gdim[3]] += gssp[x];
				else
					for (x = 0; x < gdim[3]; x++)
						gssrp2[x] += gssp[x];
				gssp += gdim[3];
			}
		}
	}
	ahp = ah;
	gssp = gss;
	ccv_nnc_tensor_t ahgssrt = ccv_nnc_tensor(ahgssr, saved_mean->info, 0);
	ccv_nnc_tensor_zero(&ahgssrt);
	float* const ahgssrp = ahgssr;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const ahgssrp0 = ahgssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* const ahgssrp1 = ahgssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				float* const ahgssrp2 = ahgssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						ahgssrp2[x * rdim[3] / gdim[3]] += ahp[x] * gssp[x];
				else
					for (x = 0; x < gdim[3]; x++)
						ahgssrp2[x] += ahp[x] * gssp[x];
				ahp += gdim[3];
				gssp += gdim[3];
			}
		}
	}
	// Now the part to compute dx (h).
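	// With ah = (a - mean) * inv_std and gss = g * scale * inv_std, the loop below
	// implements h = gss - (1 / n) * (gssr + ah * ahgssr), i.e.
	// h = gss - (1 / n) * (sum(gss) + ah * sum(ah * gss)), where both sums run over
	// the elements that share one group statistic.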
	float* const hp = h->data.f32;
	ahp = ah;
	const float inv_n = 1. / n;
	gssp = gss;
	for (i[0] = 0; i[0] < gdim[0]; i[0]++)
	{
		float* const hp0 = hp + i[0] * hstride[0];
		const float* const gssrp0 = gssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		const float* const ahgssrp0 = ahgssrp + (rdim[0] < gdim[0] ? i[0] * rdim[0] / gdim[0] : i[0]) * rdim[1] * rdim[2] * rdim[3];
		for (i[1] = 0; i[1] < gdim[1]; i[1]++)
		{
			float* hp1 = hp0 + i[1] * hstride[1];
			const float* const gssrp1 = gssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			const float* const ahgssrp1 = ahgssrp0 + (rdim[1] < gdim[1] ? i[1] * rdim[1] / gdim[1] : i[1]) * rdim[2] * rdim[3];
			for (i[2] = 0; i[2] < gdim[2]; i[2]++)
			{
				const float* const gssrp2 = gssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				const float* const ahgssrp2 = ahgssrp1 + (rdim[2] < gdim[2] ? i[2] * rdim[2] / gdim[2] : i[2]) * rdim[3];
				if (rdim[3] < gdim[3])
					for (x = 0; x < gdim[3]; x++)
						hp1[x] = gssp[x] - inv_n * (gssrp2[x * rdim[3] / gdim[3]] + ahp[x] * ahgssrp2[x * rdim[3] / gdim[3]]);
				else
					for (x = 0; x < gdim[3]; x++)
						hp1[x] = gssp[x] - inv_n * (gssrp2[x] + ahp[x] * ahgssrp2[x]);
				hp1 += hstride[2];
				ahp += gdim[3];
				gssp += gdim[3];
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_GROUP_NORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_group_norm_forw;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_GROUP_NORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_group_norm_back;
}
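
// Usage sketch (illustrative, not from this file): the reference backends above
// are selected when a CCV_NNC_GROUP_NORM_FORWARD / _BACKWARD command runs on CPU
// float32 tensors. A forward invocation passes (input, scale, bias) and receives
// (output, saved_mean, saved_inv_std), roughly:
//
//   ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0,
//       TENSOR_LIST(a, scale, bias),
//       TENSOR_LIST(b, saved_mean, saved_inv_std), 0);
//
// where cmd is a group norm command whose gnorm parameters (groups, epsilon,
// elementwise_affine, ...) were set up through the usual ccv_nnc command
// constructors; the exact constructor macro signature is assumed here rather
// than quoted, check ccv_nnc_easy.h / the command definitions for the real one.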