/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_rmsnorm_cpu_ref.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #ifdef USE_OPENMP |
7 | | #include <omp.h> |
8 | | #endif |
9 | | #ifdef USE_DISPATCH |
10 | | #include <dispatch/dispatch.h> |
11 | | #endif |
12 | | |
13 | | // Shared methods. |
14 | | #include "../_ccv_nnc_cpu_ref.h" |
15 | | |
16 | | static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
17 | 10 | { |
18 | 10 | assert(input_size == 2 || input_size == 1); |
19 | 10 | ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0]; |
20 | 10 | ccv_nnc_tensor_view_t* const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t*)inputs[1]5 : 05 ; |
21 | 10 | ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0]; |
22 | 10 | ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[1]; |
23 | 10 | assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2); |
24 | 10 | assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2); |
25 | | // Assuming this is float 32. |
26 | 10 | int adim[CCV_NNC_MAX_DIM_ALLOC]; |
27 | 10 | int rdim[CCV_NNC_MAX_DIM_ALLOC]; |
28 | 10 | ccv_nnc_tensor_view_get_dim(a, adim); |
29 | 10 | ccv_nnc_tensor_view_get_dim(saved_inv_std, rdim); |
30 | 10 | assert(ccv_nnc_tensor_view_check_dim(b, adim)); |
31 | 10 | assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number. |
32 | 10 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
33 | 10 | int bstride[CCV_NNC_MAX_DIM_ALLOC]; |
34 | 10 | int scale_stride[CCV_NNC_MAX_DIM_ALLOC]; |
35 | 10 | ccv_nnc_tensor_view_get_stride(a, astride); |
36 | 10 | if (scale) |
37 | 5 | ccv_nnc_tensor_view_get_stride(scale, scale_stride); |
38 | 10 | ccv_nnc_tensor_view_get_stride(b, bstride); |
39 | | // The epsilon is used a little bit differently from batch norm, it is outside of the sqrt in this case. |
40 | 10 | const float epsilon = cmd.info.rmsnorm.epsilon; |
41 | 10 | int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC]; |
42 | 10 | ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride); |
43 | 10 | int x; |
44 | 10 | int n = 1; |
45 | 50 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++40 ) |
46 | 40 | n *= adim[x]; |
47 | 50 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++40 ) |
48 | 40 | n /= rdim[x]; |
49 | 10 | const float inv_n = 1. / n; |
50 | 10 | ccv_nnc_tensor_zero(saved_inv_std); |
51 | 10 | float* const ap = a->data.f32; |
52 | 10 | float* const varp = saved_inv_std->data.f32; |
53 | 10 | int i[CCV_NNC_MAX_DIM + 2]; |
54 | 54 | for (i[0] = 0; i[0] < adim[0]; i[0]++44 ) |
55 | 44 | { |
56 | 44 | float* const ap0 = ap + i[0] * astride[0]; |
57 | 44 | float* const varp0 = rdim[0] == 1 ? varp0 : varp + i[0] * saved_inv_std_stride[0]; |
58 | 196 | for (i[1] = 0; i[1] < adim[1]; i[1]++152 ) |
59 | 152 | { |
60 | 152 | float* ap1 = ap0 + i[1] * astride[1]; |
61 | 152 | float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1]0 ; |
62 | 712 | for (i[2] = 0; i[2] < adim[2]; i[2]++560 ) |
63 | 560 | { |
64 | 560 | float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2]0 ; |
65 | 560 | if (rdim[3] == 1) |
66 | 6.16k | for (x = 0; 560 x < adim[3]; x++5.60k ) |
67 | 5.60k | { |
68 | 5.60k | float w = ap1[x * astride[3]]; |
69 | 5.60k | varp2[0] += w * w; |
70 | 5.60k | } |
71 | 0 | else |
72 | 0 | for (x = 0; x < adim[3]; x++) |
73 | 0 | { |
74 | 0 | float w = ap1[x * astride[3]]; |
75 | 0 | varp2[x] += w * w; |
76 | 0 | } |
77 | 560 | ap1 += astride[2]; |
78 | 560 | } |
79 | 152 | } |
80 | 44 | } |
81 | 54 | for (i[0] = 0; i[0] < rdim[0]; i[0]++44 ) |
82 | 44 | { |
83 | 44 | float* const varp0 = varp + i[0] * saved_inv_std_stride[0]; |
84 | 88 | for (i[1] = 0; i[1] < rdim[1]; i[1]++44 ) |
85 | 44 | { |
86 | 44 | float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1]; |
87 | 88 | for (i[2] = 0; i[2] < rdim[2]; i[2]++44 ) |
88 | 44 | { |
89 | 44 | float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2]; |
90 | 88 | for (x = 0; x < rdim[3]; x++44 ) |
91 | 44 | varp2[x] = 1. / sqrtf(varp2[x] * inv_n + epsilon); |
92 | 44 | } |
93 | 44 | } |
94 | 44 | } |
95 | 10 | if (cmd.info.rmsnorm.elementwise_affine) |
96 | 5 | { |
97 | 5 | float* const scalep = scale->data.f32; |
98 | 5 | int sdim[CCV_NNC_MAX_DIM_ALLOC]; |
99 | 5 | ccv_nnc_tensor_view_get_dim(scale, sdim); |
100 | | // Do the straight-forward one, y = x * inv_std * scale + bias, we cannot allocate extra memory to help. |
101 | | // There is no need for precompute since scale / bias is per element. |
102 | 5 | float* const bp = b->data.f32; |
103 | 27 | for (i[0] = 0; i[0] < adim[0]; i[0]++22 ) |
104 | 22 | { |
105 | 22 | float* const ap0 = ap + i[0] * astride[0]; |
106 | 22 | float* const bp0 = bp + i[0] * bstride[0]; |
107 | 22 | float* const varp0 = rdim[0] == 1 ? varp0 : varp + i[0] * saved_inv_std_stride[0]; |
108 | 22 | float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0]0 ; |
109 | 98 | for (i[1] = 0; i[1] < adim[1]; i[1]++76 ) |
110 | 76 | { |
111 | 76 | float* ap1 = ap0 + i[1] * astride[1]; |
112 | 76 | float* bp1 = bp0 + i[1] * bstride[1]; |
113 | 76 | float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1]0 ; |
114 | 76 | float* const scalep1 = sdim[1] == 1 ? scalep00 : scalep0 + i[1] * scale_stride[1]; |
115 | 356 | for (i[2] = 0; i[2] < adim[2]; i[2]++280 ) |
116 | 280 | { |
117 | 280 | float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2]0 ; |
118 | 280 | float* const scalep2 = sdim[2] == 1 ? scalep10 : scalep1 + i[2] * scale_stride[2]; |
119 | 280 | if (rdim[3] == 1) |
120 | 3.08k | for (x = 0; 280 x < adim[3]; x++2.80k ) |
121 | 2.80k | bp1[x] = ap1[x * astride[3]] * varp2[0] * scalep2[sdim[3] == 1 ? 00 : x]; |
122 | 0 | else |
123 | 0 | for (x = 0; x < adim[3]; x++) |
124 | 0 | bp1[x] = ap1[x * astride[3]] * varp2[x] * scalep2[sdim[3] == 1 ? 0 : x]; |
125 | 280 | ap1 += astride[2]; |
126 | 280 | bp1 += bstride[2]; |
127 | 280 | } |
128 | 76 | } |
129 | 22 | } |
130 | 5 | } else { |
131 | | // Do the straight-forward one, y = x * inv_std, we cannot allocate extra memory to help. |
132 | 5 | float* const bp = b->data.f32; |
133 | 27 | for (i[0] = 0; i[0] < adim[0]; i[0]++22 ) |
134 | 22 | { |
135 | 22 | float* const ap0 = ap + i[0] * astride[0]; |
136 | 22 | float* const bp0 = bp + i[0] * bstride[0]; |
137 | 22 | float* const varp0 = rdim[0] == 1 ? varp0 : varp + i[0] * saved_inv_std_stride[0]; |
138 | 98 | for (i[1] = 0; i[1] < adim[1]; i[1]++76 ) |
139 | 76 | { |
140 | 76 | float* ap1 = ap0 + i[1] * astride[1]; |
141 | 76 | float* bp1 = bp0 + i[1] * bstride[1]; |
142 | 76 | float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1]0 ; |
143 | 356 | for (i[2] = 0; i[2] < adim[2]; i[2]++280 ) |
144 | 280 | { |
145 | 280 | float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2]0 ; |
146 | 280 | if (rdim[3] == 1) |
147 | 3.08k | for (x = 0; 280 x < adim[3]; x++2.80k ) |
148 | 2.80k | bp1[x] = ap1[x * astride[3]] * varp2[0]; |
149 | 0 | else |
150 | 0 | for (x = 0; x < adim[3]; x++) |
151 | 0 | bp1[x] = ap1[x * astride[3]] * varp2[x]; |
152 | 280 | ap1 += astride[2]; |
153 | 280 | bp1 += bstride[2]; |
154 | 280 | } |
155 | 76 | } |
156 | 22 | } |
157 | 5 | } |
158 | 10 | return CCV_NNC_EXEC_SUCCESS; |
159 | 10 | } |
160 | | |
161 | | static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
162 | 6 | { |
163 | 6 | assert(input_size == 6 || input_size == 5); |
164 | 6 | assert(output_size >= 1); |
165 | 6 | const int elementwise_affine = cmd.info.rmsnorm.elementwise_affine; |
166 | 6 | ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0]; |
167 | 6 | ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[2]; |
168 | 6 | ccv_nnc_tensor_view_t* const scale = elementwise_affine ? (ccv_nnc_tensor_view_t*)inputs[3]3 : 03 ; |
169 | 6 | ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 53 : 43 ]; |
170 | 6 | ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0]; |
171 | 6 | ccv_nnc_tensor_view_t* const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1]3 : 03 ; |
172 | 6 | assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2); |
173 | 6 | assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2); |
174 | 6 | assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2); |
175 | | // Assuming this is float 32. |
176 | 6 | int gdim[CCV_NNC_MAX_DIM_ALLOC]; |
177 | 6 | int rdim[CCV_NNC_MAX_DIM_ALLOC]; |
178 | 6 | ccv_nnc_tensor_view_get_dim(g, gdim); |
179 | 6 | ccv_nnc_tensor_view_get_dim(saved_inv_std, rdim); |
180 | 6 | int sdim[CCV_NNC_MAX_DIM_ALLOC]; |
181 | 6 | if (scale) |
182 | 3 | ccv_nnc_tensor_view_get_dim(scale, sdim); |
183 | 6 | if (dscale) |
184 | 2 | { assert(ccv_nnc_tensor_view_check_dim(dscale, sdim)); } |
185 | 6 | assert(ccv_nnc_tensor_view_check_dim(a, gdim)); |
186 | 6 | assert(ccv_nnc_tensor_view_check_dim(h, gdim)); |
187 | 6 | int astride[CCV_NNC_MAX_DIM_ALLOC]; |
188 | 6 | int gstride[CCV_NNC_MAX_DIM_ALLOC]; |
189 | 6 | int hstride[CCV_NNC_MAX_DIM_ALLOC]; |
190 | 6 | int scale_stride[CCV_NNC_MAX_DIM_ALLOC]; |
191 | 6 | int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC]; |
192 | 6 | int dscale_stride[CCV_NNC_MAX_DIM_ALLOC]; |
193 | 6 | ccv_nnc_tensor_view_get_stride(a, astride); |
194 | 6 | ccv_nnc_tensor_view_get_stride(g, gstride); |
195 | 6 | ccv_nnc_tensor_view_get_stride(h, hstride); |
196 | 6 | if (scale) |
197 | 3 | ccv_nnc_tensor_view_get_stride(scale, scale_stride); |
198 | 6 | ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride); |
199 | 6 | if (dscale) |
200 | 2 | ccv_nnc_tensor_view_get_stride(dscale, dscale_stride); |
201 | | // Need to allocate two additional memory: |
202 | | // 1. normalized a; |
203 | | // 2. scale * inv_std / n; |
204 | 6 | assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC)); |
205 | 6 | int x; |
206 | 6 | int n = 1; |
207 | 30 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++24 ) |
208 | 24 | n *= gdim[x]; |
209 | 30 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++24 ) |
210 | 24 | n /= rdim[x]; |
211 | 6 | int gcount = 1, rcount = 1; |
212 | 30 | for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++24 ) |
213 | 24 | gcount *= gdim[x], rcount *= rdim[x]; |
214 | 6 | float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount * 2 + sizeof(float) * rcount, CCV_TENSOR_CPU_MEMORY); |
215 | 6 | float* const gss = ah + gcount; // g * scale * inv_std |
216 | 6 | float* const ahgssr = gss + gcount; // ah * gss then reduced to inv_std dimension. |
217 | 6 | int i[CCV_NNC_MAX_DIM + 2]; |
218 | 6 | float* ahp = ah; |
219 | 6 | const float* const inv_stdp = saved_inv_std->data.f32; |
220 | 6 | const float* const ap = a->data.f32; |
221 | 30 | for (i[0] = 0; i[0] < gdim[0]; i[0]++24 ) |
222 | 24 | { |
223 | 24 | const float* const ap0 = ap + i[0] * astride[0]; |
224 | 24 | const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp0 : inv_stdp + i[0] * inv_std_stride[0]; |
225 | 104 | for (i[1] = 0; i[1] < gdim[1]; i[1]++80 ) |
226 | 80 | { |
227 | 80 | const float* ap1 = ap0 + i[1] * astride[1]; |
228 | 80 | const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1]0 ; |
229 | 368 | for (i[2] = 0; i[2] < gdim[2]; i[2]++288 ) |
230 | 288 | { |
231 | 288 | const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2]0 ; |
232 | 288 | if (rdim[3] == 1) |
233 | 3.16k | for (x = 0; 288 x < gdim[3]; x++2.88k ) |
234 | 2.88k | ahp[x] = ap1[x] * inv_stdp2[0]; |
235 | 0 | else |
236 | 0 | for (x = 0; x < gdim[3]; x++) |
237 | 0 | ahp[x] = ap1[x] * inv_stdp2[x]; |
238 | 288 | ap1 += astride[2]; |
239 | 288 | ahp += gdim[3]; |
240 | 288 | } |
241 | 80 | } |
242 | 24 | } |
243 | 6 | if (dscale) |
244 | 2 | { |
245 | 2 | ccv_nnc_tensor_zero(dscale); |
246 | 2 | ahp = ah; |
247 | 2 | float* gssp = gss; |
248 | 2 | const float* const gp = g->data.f32; |
249 | 2 | const float* const scalep = scale->data.f32; |
250 | 2 | float* const dscalep = dscale->data.f32; |
251 | 12 | for (i[0] = 0; i[0] < gdim[0]; i[0]++10 ) |
252 | 10 | { |
253 | 10 | const float* const gp0 = gp + i[0] * gstride[0]; |
254 | 10 | const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp0 : inv_stdp + i[0] * inv_std_stride[0]; |
255 | 10 | const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0]0 ; |
256 | 10 | float* const dscalep0 = sdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0]0 ; |
257 | 46 | for (i[1] = 0; i[1] < gdim[1]; i[1]++36 ) |
258 | 36 | { |
259 | 36 | const float* gp1 = gp0 + i[1] * gstride[1]; |
260 | 36 | const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1]0 ; |
261 | 36 | const float* const scalep1 = sdim[1] == 1 ? scalep00 : scalep0 + i[1] * scale_stride[1]; |
262 | 36 | float* const dscalep1 = sdim[1] == 1 ? dscalep00 : dscalep0 + i[1] * dscale_stride[1]; |
263 | 172 | for (i[2] = 0; i[2] < gdim[2]; i[2]++136 ) |
264 | 136 | { |
265 | 136 | const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2]0 ; |
266 | 136 | const float* const scalep2 = sdim[2] == 1 ? scalep10 : scalep1 + i[2] * scale_stride[2]; |
267 | 136 | float* const dscalep2 = sdim[2] == 1 ? dscalep10 : dscalep1 + i[2] * dscale_stride[2]; |
268 | 136 | if (sdim[3] == 1) |
269 | 0 | for (x = 0; x < gdim[3]; x++) |
270 | 0 | { |
271 | 0 | gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x]; |
272 | 0 | dscalep2[0] += ahp[x] * gp1[x]; |
273 | 0 | } |
274 | 136 | else |
275 | 1.49k | for (x = 0; 136 x < gdim[3]; x++1.36k ) |
276 | 1.36k | { |
277 | 1.36k | gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x0 ]; |
278 | 1.36k | dscalep2[x] += ahp[x] * gp1[x]; |
279 | 1.36k | } |
280 | 136 | gp1 += gstride[2]; |
281 | 136 | ahp += gdim[3]; |
282 | 136 | gssp += gdim[3]; |
283 | 136 | } |
284 | 36 | } |
285 | 10 | } |
286 | 4 | } else { |
287 | 4 | float* gssp = gss; |
288 | 4 | const float* const gp = g->data.f32; |
289 | 4 | if (elementwise_affine) |
290 | 1 | { |
291 | 1 | const float* const scalep = scale->data.f32; |
292 | 3 | for (i[0] = 0; i[0] < gdim[0]; i[0]++2 ) |
293 | 2 | { |
294 | 2 | const float* const gp0 = gp + i[0] * gstride[0]; |
295 | 2 | const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp0 : inv_stdp + i[0] * inv_std_stride[0]; |
296 | 2 | const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0]0 ; |
297 | 6 | for (i[1] = 0; i[1] < gdim[1]; i[1]++4 ) |
298 | 4 | { |
299 | 4 | const float* gp1 = gp0 + i[1] * gstride[1]; |
300 | 4 | const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1]0 ; |
301 | 4 | const float* const scalep1 = sdim[1] == 1 ? scalep00 : scalep0 + i[1] * scale_stride[1]; |
302 | 12 | for (i[2] = 0; i[2] < gdim[2]; i[2]++8 ) |
303 | 8 | { |
304 | 8 | const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2]0 ; |
305 | 8 | const float* const scalep2 = sdim[2] == 1 ? scalep10 : scalep1 + i[2] * scale_stride[2]; |
306 | 8 | if (sdim[3] == 1) |
307 | 0 | for (x = 0; x < gdim[3]; x++) |
308 | 0 | gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x]; |
309 | 8 | else |
310 | 88 | for (x = 0; 8 x < gdim[3]; x++80 ) |
311 | 80 | gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x0 ]; |
312 | 8 | gp1 += gstride[2]; |
313 | 8 | gssp += gdim[3]; |
314 | 8 | } |
315 | 4 | } |
316 | 2 | } |
317 | 3 | } else { |
318 | 15 | for (i[0] = 0; i[0] < gdim[0]; i[0]++12 ) |
319 | 12 | { |
320 | 12 | const float* const gp0 = gp + i[0] * gstride[0]; |
321 | 12 | const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp0 : inv_stdp + i[0] * inv_std_stride[0]; |
322 | 52 | for (i[1] = 0; i[1] < gdim[1]; i[1]++40 ) |
323 | 40 | { |
324 | 40 | const float* gp1 = gp0 + i[1] * gstride[1]; |
325 | 40 | const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1]0 ; |
326 | 184 | for (i[2] = 0; i[2] < gdim[2]; i[2]++144 ) |
327 | 144 | { |
328 | 144 | const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2]0 ; |
329 | 144 | if (rdim[3] == 1) |
330 | 1.58k | for (x = 0; 144 x < gdim[3]; x++1.44k ) |
331 | 1.44k | gssp[x] = gp1[x] * inv_stdp2[0]; |
332 | 0 | else |
333 | 0 | for (x = 0; x < gdim[3]; x++) |
334 | 0 | gssp[x] = gp1[x] * inv_stdp2[x]; |
335 | 144 | gp1 += gstride[2]; |
336 | 144 | gssp += gdim[3]; |
337 | 144 | } |
338 | 40 | } |
339 | 12 | } |
340 | 3 | } |
341 | 4 | } |
342 | 6 | ahp = ah; |
343 | 6 | float* gssp = gss; |
344 | 6 | ccv_nnc_tensor_t ahgssrt = ccv_nnc_tensor(ahgssr, saved_inv_std->info, 0); |
345 | 6 | ccv_nnc_tensor_zero(&ahgssrt); |
346 | 6 | float* const ahgssrp = ahgssr; |
347 | 30 | for (i[0] = 0; i[0] < gdim[0]; i[0]++24 ) |
348 | 24 | { |
349 | 24 | float* const ahgssrp0 = rdim[0] == 1 ? ahgssrp0 : ahgssrp + i[0] * rdim[1] * rdim[2] * rdim[3]; |
350 | 104 | for (i[1] = 0; i[1] < gdim[1]; i[1]++80 ) |
351 | 80 | { |
352 | 80 | float* const ahgssrp1 = rdim[1] == 1 ? ahgssrp0 : ahgssrp0 + i[1] * rdim[2] * rdim[3]0 ; |
353 | 368 | for (i[2] = 0; i[2] < gdim[2]; i[2]++288 ) |
354 | 288 | { |
355 | 288 | float* const ahgssrp2 = rdim[2] == 1 ? ahgssrp1 : ahgssrp1 + i[2] * rdim[3]0 ; |
356 | 288 | if (rdim[3] == 1) |
357 | 3.16k | for (x = 0; 288 x < gdim[3]; x++2.88k ) |
358 | 2.88k | ahgssrp2[0] += ahp[x] * gssp[x]; |
359 | 0 | else |
360 | 0 | for (x = 0; x < gdim[3]; x++) |
361 | 0 | ahgssrp2[x] += ahp[x] * gssp[x]; |
362 | 288 | ahp += gdim[3]; |
363 | 288 | gssp += gdim[3]; |
364 | 288 | } |
365 | 80 | } |
366 | 24 | } |
367 | | // Now the part to compute dx (h). |
368 | 6 | float* const hp = h->data.f32; |
369 | 6 | ahp = ah; |
370 | 6 | const float inv_n = 1. / n; |
371 | 6 | gssp = gss; |
372 | 30 | for (i[0] = 0; i[0] < gdim[0]; i[0]++24 ) |
373 | 24 | { |
374 | 24 | float* const hp0 = hp + i[0] * hstride[0]; |
375 | 24 | const float* const ahgssrp0 = rdim[0] == 1 ? ahgssrp0 : ahgssrp + i[0] * rdim[1] * rdim[2] * rdim[3]; |
376 | 104 | for (i[1] = 0; i[1] < gdim[1]; i[1]++80 ) |
377 | 80 | { |
378 | 80 | float* hp1 = hp0 + i[1] * hstride[1]; |
379 | 80 | const float* const ahgssrp1 = rdim[1] == 1 ? ahgssrp0 : ahgssrp0 + i[1] * rdim[2] * rdim[3]0 ; |
380 | 368 | for (i[2] = 0; i[2] < gdim[2]; i[2]++288 ) |
381 | 288 | { |
382 | 288 | const float* const ahgssrp2 = rdim[2] == 1 ? ahgssrp1 : ahgssrp1 + i[2] * rdim[3]0 ; |
383 | 288 | if (rdim[3] == 1) |
384 | 3.16k | for (x = 0; 288 x < gdim[3]; x++2.88k ) |
385 | 2.88k | hp1[x] = gssp[x] - inv_n * ahp[x] * ahgssrp2[0]; |
386 | 0 | else |
387 | 0 | for (x = 0; x < gdim[3]; x++) |
388 | 0 | hp1[x] = gssp[x] - inv_n * ahp[x] * ahgssrp2[x]; |
389 | 288 | hp1 += hstride[2]; |
390 | 288 | ahp += gdim[3]; |
391 | 288 | gssp += gdim[3]; |
392 | 288 | } |
393 | 80 | } |
394 | 24 | } |
395 | 6 | return CCV_NNC_EXEC_SUCCESS; |
396 | 6 | } |
397 | | |
398 | | REGISTER_COMMAND_BACKEND(CCV_NNC_RMSNORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
399 | 1 | { |
400 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN; |
401 | 1 | registry->tensor_datatypes = CCV_32F; |
402 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
403 | 1 | registry->algorithms = 1; |
404 | 1 | registry->exec = _ccv_nnc_rmsnorm_forw; |
405 | 1 | } |
406 | | |
407 | | REGISTER_COMMAND_BACKEND(CCV_NNC_RMSNORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
408 | 1 | { |
409 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN; |
410 | 1 | registry->tensor_datatypes = CCV_32F; |
411 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
412 | 1 | registry->algorithms = 1; |
413 | 1 | registry->exec = _ccv_nnc_rmsnorm_back; |
414 | 1 | } |