Coverage Report

Created: 2025-02-24 17:43

Legend: jump to next uncovered line (L), next uncovered region (R), or next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/lamb/ccv_nnc_lamb_cpu_ref.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#ifdef USE_OPENMP
7
#include <omp.h>
8
#endif
9
#ifdef USE_DISPATCH
10
#include <dispatch/dispatch.h>
11
#endif
12
13
// Shared methods.
14
#include "../_ccv_nnc_cpu_ref.h"
15
16
// LAMB optimizer forward step (CPU reference implementation).
//
// Inputs:  [0] g - gradient, [1] a - current parameters,
//          [2] m - first moment (momentum), [3] v - second moment (velocity).
// Outputs: [0] b - updated parameters, [1] n - new first moment,
//          [2] u - new second moment.
//
// All tensors must share the same dimensions (asserted below); data is
// assumed to be 32-bit float. Returns CCV_NNC_EXEC_SUCCESS.
static int _ccv_nnc_lamb_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 4);
	assert(output_size == 3);
	ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[1];
	ccv_nnc_tensor_view_t* const m = (ccv_nnc_tensor_view_t*)inputs[2];
	ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[3];
	ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const n = (ccv_nnc_tensor_view_t*)outputs[1];
	ccv_nnc_tensor_view_t* const u = (ccv_nnc_tensor_view_t*)outputs[2];
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int adim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(a, adim);
	// Every tensor in this op must match the parameter tensor's shape.
	assert(ccv_nnc_tensor_view_check_dim(g, adim));
	assert(ccv_nnc_tensor_view_check_dim(m, adim));
	assert(ccv_nnc_tensor_view_check_dim(v, adim));
	assert(ccv_nnc_tensor_view_check_dim(b, adim));
	assert(ccv_nnc_tensor_view_check_dim(n, adim));
	assert(ccv_nnc_tensor_view_check_dim(u, adim));
	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
	int gstride[CCV_NNC_MAX_DIM_ALLOC];
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int mstride[CCV_NNC_MAX_DIM_ALLOC];
	int vstride[CCV_NNC_MAX_DIM_ALLOC];
	int bstride[CCV_NNC_MAX_DIM_ALLOC];
	int nstride[CCV_NNC_MAX_DIM_ALLOC];
	int ustride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(g, gstride);
	ccv_nnc_tensor_view_get_stride(a, astride);
	ccv_nnc_tensor_view_get_stride(m, mstride);
	ccv_nnc_tensor_view_get_stride(v, vstride);
	ccv_nnc_tensor_view_get_stride(b, bstride);
	ccv_nnc_tensor_view_get_stride(n, nstride);
	ccv_nnc_tensor_view_get_stride(u, ustride);
	const int step = cmd.info.lamb.step;
	const float rate = cmd.info.lamb.rate;
	const float scale = cmd.info.lamb.scale;
	const float beta1 = cmd.info.lamb.beta1;
	const float beta2 = cmd.info.lamb.beta2;
	const float decay = cmd.info.lamb.decay;
	const float epsilon = cmd.info.lamb.epsilon;
	assert(step >= 1);
	// Adam-style bias correction: the moment estimates are biased toward zero
	// early on; dividing by (1 - beta^step) compensates.
	const float inv_bias_correction1 = 1. / (1 - powf(beta1, step));
	const float inv_bias_correction2 = 1. / (1 - powf(beta2, step));
	int i[CCV_NNC_MAX_DIM + 1];
	int x;
	float* const gp = g->data.f32;
	float* const ap = a->data.f32;
	float* const mp = m->data.f32;
	float* const vp = v->data.f32;
	float* const bp = b->data.f32;
	float* const np = n->data.f32;
	float* const up = u->data.f32;
	// Scratch buffer holding the raw (pre-trust-ratio) update for every element;
	// needed because the trust ratio depends on the norm over the whole tensor.
	float* const update = (float*)ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * adim[0] * adim[1] * adim[2] * adim[3], CCV_TENSOR_CPU_MEMORY);
	float* updatep = update;
	double update_norm = 0;
	double w_norm = 0;
	// First pass: update moments, compute raw update, and accumulate the
	// squared norms of the weights and of the raw update.
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const gp0 = gp + i[0] * gstride[0];
		float* const ap0 = ap + i[0] * astride[0];
		float* const mp0 = mp + i[0] * mstride[0];
		float* const vp0 = vp + i[0] * vstride[0];
		float* const np0 = np + i[0] * nstride[0];
		float* const up0 = up + i[0] * ustride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* gp1 = gp0 + i[1] * gstride[1];
			float* ap1 = ap0 + i[1] * astride[1];
			float* mp1 = mp0 + i[1] * mstride[1];
			float* vp1 = vp0 + i[1] * vstride[1];
			float* np1 = np0 + i[1] * nstride[1];
			float* up1 = up0 + i[1] * ustride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				for (x = 0; x < adim[3]; x++)
				{
					const float grad = scale * gp1[x];
					const float w = ap1[x];
					const float mom = np1[x] = beta1 * mp1[x] + (1 - beta1) * grad;
					const float vel = up1[x] = beta2 * vp1[x] + (1 - beta2) * grad * grad;
					// Renamed from `update` to avoid shadowing the workspace
					// pointer `update` declared above (-Wshadow).
					const float upd = updatep[x] = (mom * inv_bias_correction1) / (sqrtf(vel * inv_bias_correction2) + epsilon) + w * decay;
					w_norm += w * w;
					update_norm += upd * upd;
				}
				gp1 += gstride[2];
				ap1 += astride[2];
				mp1 += mstride[2];
				vp1 += vstride[2];
				np1 += nstride[2];
				up1 += ustride[2];
				updatep += adim[3];
			}
		}
	}
	w_norm = sqrt(w_norm);
	update_norm = sqrt(update_norm);
	// Layer-wise trust ratio: scales the step by ||w|| / ||update||; falls back
	// to 1 when either norm is zero (e.g. freshly initialized weights).
	const float trust_ratio = w_norm > 0 && update_norm > 0 ? w_norm / update_norm : 1.0;
	const float rate_trust_ratio = rate * trust_ratio;
	updatep = update;
	// Second pass: apply the trust-ratio-scaled update to produce b.
	for (i[0] = 0; i[0] < adim[0]; i[0]++)
	{
		float* const ap0 = ap + i[0] * astride[0];
		float* const bp0 = bp + i[0] * bstride[0];
		for (i[1] = 0; i[1] < adim[1]; i[1]++)
		{
			float* ap1 = ap0 + i[1] * astride[1];
			float* bp1 = bp0 + i[1] * bstride[1];
			for (i[2] = 0; i[2] < adim[2]; i[2]++)
			{
				for (x = 0; x < adim[3]; x++)
					bp1[x] = ap1[x] - rate_trust_ratio * updatep[x];
				ap1 += astride[2];
				bp1 += bstride[2];
				updatep += adim[3];
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
139
140
// Backward pass for LAMB is not implemented: differentiating through the
// optimizer update is unsupported, so always report an invalid execution.
static int _ccv_nnc_lamb_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	return CCV_NNC_EXEC_INVALID;
}
144
145
// Registers the CPU reference backend for the LAMB forward command.
REGISTER_COMMAND_BACKEND(CCV_NNC_LAMB_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	// Only single-precision float tensors in CPU memory are handled.
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	// All three layout conventions are supported since the kernel is
	// layout-agnostic (it only requires matching shapes across tensors).
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_lamb_forw;
}
153
154
// Registers the CPU reference backend for the LAMB backward command.
// The exec stub always fails, but registration keeps the command table uniform.
REGISTER_COMMAND_BACKEND(CCV_NNC_LAMB_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	// Mirror the forward registration: float32, CPU memory, any layout.
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_lamb_back;
}