Coverage Report

Created: 2026-04-14 20:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/norm/ccv_nnc_rmsnorm_cpu_ref.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#ifdef USE_OPENMP
7
#include <omp.h>
8
#endif
9
#ifdef USE_DISPATCH
10
#include <dispatch/dispatch.h>
11
#endif
12
13
// Shared methods.
14
#include "../_ccv_nnc_cpu_ref.h"
15
16
/**
 * CPU reference implementation of the RMSNorm forward pass (float32 data is assumed).
 *
 * inputs[0]:  a, the tensor to normalize.
 * inputs[1]:  scale, only dereferenced when cmd.info.rmsnorm.elementwise_affine is set.
 * outputs[0]: b, must have the same dimensions as a; receives a * inv_std (* scale when affine).
 * outputs[1]: saved_inv_std; a dimension of 1 means a is reduced over that axis.
 *             Other dimensions are presumably equal to a's (the code only tests for == 1) - TODO confirm.
 *
 * For every element of saved_inv_std: inv_std = 1 / sqrt(sum(a * a) / n + epsilon), where n is
 * the number of elements of a folded into that element.
 *
 * hint, flags and stream_context are not used. Always returns CCV_NNC_EXEC_SUCCESS.
 *
 * NOTE(review): only a is indexed with its innermost stride (astride[3]); b, scale and
 * saved_inv_std are indexed with a plain x, i.e. they are treated as contiguous in the
 * last dimension - confirm this holds for all callers.
 */
static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
  assert(input_size == 2 || input_size == 1);
  // outputs[1] (saved_inv_std) is read unconditionally below.
  assert(output_size >= 2);
  ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
  ccv_nnc_tensor_view_t* const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t*)inputs[1] : 0;
  ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
  ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[1];
  assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
  assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
  // Assuming this is float 32.
  int adim[CCV_NNC_MAX_DIM_ALLOC];
  int rdim[CCV_NNC_MAX_DIM_ALLOC];
  ccv_nnc_tensor_view_get_dim(a, adim);
  ccv_nnc_tensor_view_get_dim(saved_inv_std, rdim);
  assert(ccv_nnc_tensor_view_check_dim(b, adim));
  assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
  int astride[CCV_NNC_MAX_DIM_ALLOC];
  int bstride[CCV_NNC_MAX_DIM_ALLOC];
  int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
  ccv_nnc_tensor_view_get_stride(a, astride);
  if (scale)
    ccv_nnc_tensor_view_get_stride(scale, scale_stride);
  ccv_nnc_tensor_view_get_stride(b, bstride);
  // The epsilon is added to the mean of squares inside the sqrt (see the inv_std computation below).
  const float epsilon = cmd.info.rmsnorm.epsilon;
  int saved_inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
  ccv_nnc_tensor_view_get_stride(saved_inv_std, saved_inv_std_stride);
  int x;
  // n = number of elements of a that are reduced into one element of saved_inv_std.
  int n = 1;
  for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
    n *= adim[x];
  for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
    n /= rdim[x];
  const float inv_n = 1. / n;
  // saved_inv_std doubles as the accumulator for the sum of squares, hence it starts at zero.
  ccv_nnc_tensor_zero(saved_inv_std);
  float* const ap = a->data.f32;
  float* const varp = saved_inv_std->data.f32;
  int i[CCV_NNC_MAX_DIM + 2];
  // Pass 1: accumulate a * a into saved_inv_std. A dimension of 1 in rdim pins the
  // accumulator pointer so that axis is reduced.
  for (i[0] = 0; i[0] < adim[0]; i[0]++)
  {
    float* const ap0 = ap + i[0] * astride[0];
    float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
    for (i[1] = 0; i[1] < adim[1]; i[1]++)
    {
      float* ap1 = ap0 + i[1] * astride[1];
      float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
      for (i[2] = 0; i[2] < adim[2]; i[2]++)
      {
        float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
        if (rdim[3] == 1)
          for (x = 0; x < adim[3]; x++)
          {
            float w = ap1[x * astride[3]];
            varp2[0] += w * w;
          }
        else
          for (x = 0; x < adim[3]; x++)
          {
            float w = ap1[x * astride[3]];
            varp2[x] += w * w;
          }
        ap1 += astride[2];
      }
    }
  }
  // Pass 2: turn each sum of squares into inv_std = 1 / sqrt(mean square + epsilon), in place.
  for (i[0] = 0; i[0] < rdim[0]; i[0]++)
  {
    float* const varp0 = varp + i[0] * saved_inv_std_stride[0];
    for (i[1] = 0; i[1] < rdim[1]; i[1]++)
    {
      float* const varp1 = varp0 + i[1] * saved_inv_std_stride[1];
      for (i[2] = 0; i[2] < rdim[2]; i[2]++)
      {
        float* const varp2 = varp1 + i[2] * saved_inv_std_stride[2];
        for (x = 0; x < rdim[3]; x++)
          varp2[x] = 1. / sqrtf(varp2[x] * inv_n + epsilon);
      }
    }
  }
  if (cmd.info.rmsnorm.elementwise_affine)
  {
    // scale comes from input_size, the branch from the command flag; make sure they agree
    // before dereferencing scale.
    assert(scale);
    float* const scalep = scale->data.f32;
    int sdim[CCV_NNC_MAX_DIM_ALLOC];
    ccv_nnc_tensor_view_get_dim(scale, sdim);
    // Do the straight-forward one, y = x * inv_std * scale, we cannot allocate extra memory to help.
    // There is no need for precompute since scale is per element.
    // A dimension of 1 in sdim means scale is broadcast along that axis.
    float* const bp = b->data.f32;
    for (i[0] = 0; i[0] < adim[0]; i[0]++)
    {
      float* const ap0 = ap + i[0] * astride[0];
      float* const bp0 = bp + i[0] * bstride[0];
      float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
      float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
      for (i[1] = 0; i[1] < adim[1]; i[1]++)
      {
        float* ap1 = ap0 + i[1] * astride[1];
        float* bp1 = bp0 + i[1] * bstride[1];
        float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
        float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
        for (i[2] = 0; i[2] < adim[2]; i[2]++)
        {
          float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
          float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
          if (rdim[3] == 1)
            for (x = 0; x < adim[3]; x++)
              bp1[x] = ap1[x * astride[3]] * varp2[0] * scalep2[sdim[3] == 1 ? 0 : x];
          else
            for (x = 0; x < adim[3]; x++)
              bp1[x] = ap1[x * astride[3]] * varp2[x] * scalep2[sdim[3] == 1 ? 0 : x];
          ap1 += astride[2];
          bp1 += bstride[2];
        }
      }
    }
  } else {
    // Do the straight-forward one, y = x * inv_std, we cannot allocate extra memory to help.
    float* const bp = b->data.f32;
    for (i[0] = 0; i[0] < adim[0]; i[0]++)
    {
      float* const ap0 = ap + i[0] * astride[0];
      float* const bp0 = bp + i[0] * bstride[0];
      float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
      for (i[1] = 0; i[1] < adim[1]; i[1]++)
      {
        float* ap1 = ap0 + i[1] * astride[1];
        float* bp1 = bp0 + i[1] * bstride[1];
        float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
        for (i[2] = 0; i[2] < adim[2]; i[2]++)
        {
          float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
          if (rdim[3] == 1)
            for (x = 0; x < adim[3]; x++)
              bp1[x] = ap1[x * astride[3]] * varp2[0];
          else
            for (x = 0; x < adim[3]; x++)
              bp1[x] = ap1[x * astride[3]] * varp2[x];
          ap1 += astride[2];
          bp1 += bstride[2];
        }
      }
    }
  }
  return CCV_NNC_EXEC_SUCCESS;
}
160
161
/**
 * CPU reference implementation of the RMSNorm backward pass (float32 data is assumed).
 *
 * inputs[0]:      g, the incoming gradient.
 * inputs[2]:      a, the forward input.
 * inputs[3]:      scale, only read when cmd.info.rmsnorm.elementwise_affine is set.
 * inputs[4 or 5]: saved_inv_std (index 5 when elementwise_affine is set, 4 otherwise).
 * outputs[0]:     h, the gradient w.r.t. a; must have the same dimensions as g.
 * outputs[1]:     dscale (optional, may be absent or NULL), the gradient w.r.t. scale;
 *                 must have the same dimensions as scale.
 *
 * With ah = a * inv_std and gss = g * inv_std (* scale when affine):
 *   dscale = sum(ah * g), reduced over the axes where scale has dimension 1;
 *   h      = gss - ah * sum(ah * gss) / n,
 * where the inner sum is reduced to the dimensions of saved_inv_std and n is the number
 * of elements folded into each element of saved_inv_std.
 *
 * hint is not used. Scratch memory is taken from the workspace of stream_context, so
 * flags must not contain CCV_NNC_ZERO_MEMORY_ALLOC. Always returns CCV_NNC_EXEC_SUCCESS.
 *
 * NOTE(review): a, g and h are indexed with a plain x in the innermost dimension (their
 * stride[3] is never used), i.e. they are treated as contiguous in the last dimension -
 * confirm this holds for all callers.
 */
static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
  assert(input_size == 6 || input_size == 5);
  assert(output_size >= 1);
  const int elementwise_affine = cmd.info.rmsnorm.elementwise_affine;
  // saved_inv_std is read from index 5 (affine) or 4 (not affine); make sure that slot exists.
  assert(input_size > (elementwise_affine ? 5 : 4));
  ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
  ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[2];
  ccv_nnc_tensor_view_t* const scale = elementwise_affine ? (ccv_nnc_tensor_view_t*)inputs[3] : 0;
  ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 5 : 4];
  ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0];
  ccv_nnc_tensor_view_t* const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
  // The dscale path reads scale's dimensions and data, so dscale without scale is invalid.
  assert(!dscale || scale);
  assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2);
  assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
  assert(ccv_nnc_tensor_nd(h->info.dim) <= CCV_NNC_MAX_DIM + 2);
  // Assuming this is float 32.
  int gdim[CCV_NNC_MAX_DIM_ALLOC];
  int rdim[CCV_NNC_MAX_DIM_ALLOC];
  ccv_nnc_tensor_view_get_dim(g, gdim);
  ccv_nnc_tensor_view_get_dim(saved_inv_std, rdim);
  int sdim[CCV_NNC_MAX_DIM_ALLOC];
  if (scale)
    ccv_nnc_tensor_view_get_dim(scale, sdim);
  if (dscale)
    { assert(ccv_nnc_tensor_view_check_dim(dscale, sdim)); }
  assert(ccv_nnc_tensor_view_check_dim(a, gdim));
  assert(ccv_nnc_tensor_view_check_dim(h, gdim));
  int astride[CCV_NNC_MAX_DIM_ALLOC];
  int gstride[CCV_NNC_MAX_DIM_ALLOC];
  int hstride[CCV_NNC_MAX_DIM_ALLOC];
  int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
  int inv_std_stride[CCV_NNC_MAX_DIM_ALLOC];
  int dscale_stride[CCV_NNC_MAX_DIM_ALLOC];
  ccv_nnc_tensor_view_get_stride(a, astride);
  ccv_nnc_tensor_view_get_stride(g, gstride);
  ccv_nnc_tensor_view_get_stride(h, hstride);
  if (scale)
    ccv_nnc_tensor_view_get_stride(scale, scale_stride);
  ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride);
  if (dscale)
    ccv_nnc_tensor_view_get_stride(dscale, dscale_stride);
  // Need to allocate three additional, densely packed buffers from the workspace:
  // 1. ah: normalized a (a * inv_std), gcount elements;
  // 2. gss: g * scale * inv_std (g * inv_std when not affine), gcount elements;
  // 3. ahgssr: ah * gss reduced to the inv_std dimensions, rcount elements.
  assert(!(flags & CCV_NNC_ZERO_MEMORY_ALLOC));
  int x;
  // n = number of elements of g that map onto one element of saved_inv_std.
  int n = 1;
  for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
    n *= gdim[x];
  for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
    n /= rdim[x];
  int gcount = 1, rcount = 1;
  for (x = 0; x < CCV_NNC_MAX_DIM + 2; x++)
    gcount *= gdim[x], rcount *= rdim[x];
  float* const ah = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * gcount * 2 + sizeof(float) * rcount, CCV_TENSOR_CPU_MEMORY);
  float* const gss = ah + gcount; // g * scale * inv_std
  float* const ahgssr = gss + gcount; // ah * gss then reduced to inv_std dimension.
  int i[CCV_NNC_MAX_DIM + 2];
  float* ahp = ah;
  const float* const inv_stdp = saved_inv_std->data.f32;
  const float* const ap = a->data.f32;
  // Step 1: ah = a * inv_std. A dimension of 1 in rdim means inv_std is broadcast along that axis.
  for (i[0] = 0; i[0] < gdim[0]; i[0]++)
  {
    const float* const ap0 = ap + i[0] * astride[0];
    const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
    for (i[1] = 0; i[1] < gdim[1]; i[1]++)
    {
      const float* ap1 = ap0 + i[1] * astride[1];
      const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
      for (i[2] = 0; i[2] < gdim[2]; i[2]++)
      {
        const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
        if (rdim[3] == 1)
          for (x = 0; x < gdim[3]; x++)
            ahp[x] = ap1[x] * inv_stdp2[0];
        else
          for (x = 0; x < gdim[3]; x++)
            ahp[x] = ap1[x] * inv_stdp2[x];
        ap1 += astride[2];
        ahp += gdim[3];
      }
    }
  }
  // Step 2: gss = g * inv_std (* scale), and, when requested, dscale = sum(ah * g).
  if (dscale)
  {
    ccv_nnc_tensor_zero(dscale);
    ahp = ah;
    float* gssp = gss;
    const float* const gp = g->data.f32;
    const float* const scalep = scale->data.f32;
    float* const dscalep = dscale->data.f32;
    for (i[0] = 0; i[0] < gdim[0]; i[0]++)
    {
      const float* const gp0 = gp + i[0] * gstride[0];
      const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
      const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
      float* const dscalep0 = sdim[0] == 1 ? dscalep : dscalep + i[0] * dscale_stride[0];
      for (i[1] = 0; i[1] < gdim[1]; i[1]++)
      {
        const float* gp1 = gp0 + i[1] * gstride[1];
        const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
        const float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
        float* const dscalep1 = sdim[1] == 1 ? dscalep0 : dscalep0 + i[1] * dscale_stride[1];
        for (i[2] = 0; i[2] < gdim[2]; i[2]++)
        {
          const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
          const float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
          float* const dscalep2 = sdim[2] == 1 ? dscalep1 : dscalep1 + i[2] * dscale_stride[2];
          if (sdim[3] == 1)
            for (x = 0; x < gdim[3]; x++)
            {
              gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x];
              dscalep2[0] += ahp[x] * gp1[x];
            }
          else
            for (x = 0; x < gdim[3]; x++)
            {
              gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x];
              dscalep2[x] += ahp[x] * gp1[x];
            }
          gp1 += gstride[2];
          ahp += gdim[3];
          gssp += gdim[3];
        }
      }
    }
  } else {
    float* gssp = gss;
    const float* const gp = g->data.f32;
    if (elementwise_affine)
    {
      // Affine, but no dscale requested: gss = g * scale * inv_std only.
      const float* const scalep = scale->data.f32;
      for (i[0] = 0; i[0] < gdim[0]; i[0]++)
      {
        const float* const gp0 = gp + i[0] * gstride[0];
        const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
        const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
        for (i[1] = 0; i[1] < gdim[1]; i[1]++)
        {
          const float* gp1 = gp0 + i[1] * gstride[1];
          const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
          const float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
          for (i[2] = 0; i[2] < gdim[2]; i[2]++)
          {
            const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
            const float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
            if (sdim[3] == 1)
              for (x = 0; x < gdim[3]; x++)
                gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x];
            else
              for (x = 0; x < gdim[3]; x++)
                gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x];
            gp1 += gstride[2];
            gssp += gdim[3];
          }
        }
      }
    } else {
      // Not affine: gss = g * inv_std.
      for (i[0] = 0; i[0] < gdim[0]; i[0]++)
      {
        const float* const gp0 = gp + i[0] * gstride[0];
        const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
        for (i[1] = 0; i[1] < gdim[1]; i[1]++)
        {
          const float* gp1 = gp0 + i[1] * gstride[1];
          const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
          for (i[2] = 0; i[2] < gdim[2]; i[2]++)
          {
            const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
            if (rdim[3] == 1)
              for (x = 0; x < gdim[3]; x++)
                gssp[x] = gp1[x] * inv_stdp2[0];
            else
              for (x = 0; x < gdim[3]; x++)
                gssp[x] = gp1[x] * inv_stdp2[x];
            gp1 += gstride[2];
            gssp += gdim[3];
          }
        }
      }
    }
  }
  // Step 3: ahgssr = sum(ah * gss), reduced to the dimensions of saved_inv_std.
  // ahgssr is densely packed, so its offsets are computed from rdim rather than from a stride array.
  ahp = ah;
  float* gssp = gss;
  ccv_nnc_tensor_t ahgssrt = ccv_nnc_tensor(ahgssr, saved_inv_std->info, 0);
  ccv_nnc_tensor_zero(&ahgssrt);
  float* const ahgssrp = ahgssr;
  for (i[0] = 0; i[0] < gdim[0]; i[0]++)
  {
    float* const ahgssrp0 = rdim[0] == 1 ? ahgssrp : ahgssrp + i[0] * rdim[1] * rdim[2] * rdim[3];
    for (i[1] = 0; i[1] < gdim[1]; i[1]++)
    {
      float* const ahgssrp1 = rdim[1] == 1 ? ahgssrp0 : ahgssrp0 + i[1] * rdim[2] * rdim[3];
      for (i[2] = 0; i[2] < gdim[2]; i[2]++)
      {
        float* const ahgssrp2 = rdim[2] == 1 ? ahgssrp1 : ahgssrp1 + i[2] * rdim[3];
        if (rdim[3] == 1)
          for (x = 0; x < gdim[3]; x++)
            ahgssrp2[0] += ahp[x] * gssp[x];
        else
          for (x = 0; x < gdim[3]; x++)
            ahgssrp2[x] += ahp[x] * gssp[x];
        ahp += gdim[3];
        gssp += gdim[3];
      }
    }
  }
  // Now the part to compute dx (h): h = gss - ah * ahgssr / n.
  float* const hp = h->data.f32;
  ahp = ah;
  const float inv_n = 1. / n;
  gssp = gss;
  for (i[0] = 0; i[0] < gdim[0]; i[0]++)
  {
    float* const hp0 = hp + i[0] * hstride[0];
    const float* const ahgssrp0 = rdim[0] == 1 ? ahgssrp : ahgssrp + i[0] * rdim[1] * rdim[2] * rdim[3];
    for (i[1] = 0; i[1] < gdim[1]; i[1]++)
    {
      float* hp1 = hp0 + i[1] * hstride[1];
      const float* const ahgssrp1 = rdim[1] == 1 ? ahgssrp0 : ahgssrp0 + i[1] * rdim[2] * rdim[3];
      for (i[2] = 0; i[2] < gdim[2]; i[2]++)
      {
        const float* const ahgssrp2 = rdim[2] == 1 ? ahgssrp1 : ahgssrp1 + i[2] * rdim[3];
        if (rdim[3] == 1)
          for (x = 0; x < gdim[3]; x++)
            hp1[x] = gssp[x] - inv_n * ahp[x] * ahgssrp2[0];
        else
          for (x = 0; x < gdim[3]; x++)
            hp1[x] = gssp[x] - inv_n * ahp[x] * ahgssrp2[x];
        hp1 += hstride[2];
        ahp += gdim[3];
        gssp += gdim[3];
      }
    }
  }
  return CCV_NNC_EXEC_SUCCESS;
}
397
398
// Fills in the CPU reference backend entry for the RMSNorm forward command.
REGISTER_COMMAND_BACKEND(CCV_NNC_RMSNORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
  // Entry point and the number of algorithms.
  registry->exec = _ccv_nnc_rmsnorm_forw;
  registry->algorithms = 1;
  // Supported tensors: 32-bit float in CPU memory, in any of the NHWC / NCHW / CHWN formats.
  registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
  registry->tensor_datatypes = CCV_32F;
  registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
}
406
407
// Fills in the CPU reference backend entry for the RMSNorm backward command.
REGISTER_COMMAND_BACKEND(CCV_NNC_RMSNORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
  // Entry point and the number of algorithms.
  registry->exec = _ccv_nnc_rmsnorm_back;
  registry->algorithms = 1;
  // Supported tensors: 32-bit float in CPU memory, in any of the NHWC / NCHW / CHWN formats.
  registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
  registry->tensor_datatypes = CCV_32F;
  registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
}