Coverage Report

Created: 2024-08-18 16:21

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/adam/ccv_nnc_adamw_cpu_ref.c
Line
Count
Source (jump to first uncovered line)
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#ifdef USE_OPENMP
7
#include <omp.h>
8
#endif
9
#ifdef USE_DISPATCH
10
#include <dispatch/dispatch.h>
11
#endif
12
13
// Shared methods.
14
#include "../_ccv_nnc_cpu_ref.h"
15
16
static int _ccv_nnc_adamw_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	// Reference CPU implementation of one AdamW step (decoupled weight decay), float32.
	// inputs:  [gradient, parameter, momentum, velocity, (max velocity — AMSGrad only)]
	// outputs: [updated parameter, updated momentum, updated velocity, (updated max velocity)]
	assert(input_size >= 4);
	assert(output_size >= 3);
	ccv_nnc_tensor_view_t* const tv_g = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const tv_a = (ccv_nnc_tensor_view_t*)inputs[1];
	ccv_nnc_tensor_view_t* const tv_m = (ccv_nnc_tensor_view_t*)inputs[2];
	ccv_nnc_tensor_view_t* const tv_v = (ccv_nnc_tensor_view_t*)inputs[3];
	ccv_nnc_tensor_view_t* const tv_vm = input_size >= 5 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
	ccv_nnc_tensor_view_t* const tv_b = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const tv_n = (ccv_nnc_tensor_view_t*)outputs[1];
	ccv_nnc_tensor_view_t* const tv_u = (ccv_nnc_tensor_view_t*)outputs[2];
	ccv_nnc_tensor_view_t* const tv_um = output_size >= 4 ? (ccv_nnc_tensor_view_t*)outputs[3] : 0;
	assert(ccv_nnc_tensor_nd(tv_a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(tv_b->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Every tensor must match the parameter tensor's dimensions. Assuming this is float 32.
	int dim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(tv_a, dim);
	assert(ccv_nnc_tensor_view_check_dim(tv_g, dim));
	assert(ccv_nnc_tensor_view_check_dim(tv_m, dim));
	assert(ccv_nnc_tensor_view_check_dim(tv_v, dim));
	if (tv_vm)
		{ assert(ccv_nnc_tensor_view_check_dim(tv_vm, dim)); }
	assert(ccv_nnc_tensor_view_check_dim(tv_b, dim));
	assert(ccv_nnc_tensor_view_check_dim(tv_n, dim));
	assert(ccv_nnc_tensor_view_check_dim(tv_u, dim));
	if (tv_um)
		{ assert(ccv_nnc_tensor_view_check_dim(tv_um, dim)); }
	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
	int str_g[CCV_NNC_MAX_DIM_ALLOC];
	int str_a[CCV_NNC_MAX_DIM_ALLOC];
	int str_m[CCV_NNC_MAX_DIM_ALLOC];
	int str_v[CCV_NNC_MAX_DIM_ALLOC];
	int str_vm[CCV_NNC_MAX_DIM_ALLOC];
	int str_b[CCV_NNC_MAX_DIM_ALLOC];
	int str_n[CCV_NNC_MAX_DIM_ALLOC];
	int str_u[CCV_NNC_MAX_DIM_ALLOC];
	int str_um[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(tv_g, str_g);
	ccv_nnc_tensor_view_get_stride(tv_a, str_a);
	ccv_nnc_tensor_view_get_stride(tv_m, str_m);
	ccv_nnc_tensor_view_get_stride(tv_v, str_v);
	if (tv_vm)
		ccv_nnc_tensor_view_get_stride(tv_vm, str_vm);
	ccv_nnc_tensor_view_get_stride(tv_b, str_b);
	ccv_nnc_tensor_view_get_stride(tv_n, str_n);
	ccv_nnc_tensor_view_get_stride(tv_u, str_u);
	if (tv_um)
		ccv_nnc_tensor_view_get_stride(tv_um, str_um);
	const int step = cmd.info.adam.step;
	const float rate = cmd.info.adam.rate;
	const float scale = cmd.info.adam.scale;
	const float beta1 = cmd.info.adam.beta1;
	const float beta2 = cmd.info.adam.beta2;
	const float decay = cmd.info.adam.decay;
	const float epsilon = cmd.info.adam.epsilon;
	assert(step >= 1);
	// Fold the Adam bias corrections into two scalars so the inner loop stays cheap.
	const float rate_inv_bias_correction1 = rate / (1 - powf(beta1, step));
	const float inv_bias_correction2 = 1. / (1 - powf(beta2, step));
	const float rate_decay = rate * decay;
	float* const gp = tv_g->data.f32;
	float* const ap = tv_a->data.f32;
	float* const mp = tv_m->data.f32;
	float* const vp = tv_v->data.f32;
	float* const bp = tv_b->data.f32;
	float* const np = tv_n->data.f32;
	float* const up = tv_u->data.f32;
	int i0, i1, i2, k;
	if (cmd.info.adam.amsgrad && tv_vm && tv_um)
	{
		// AMSGrad variant: also track the elementwise maximum of the bias-corrected velocity
		// and normalize the step with that maximum instead of the current velocity.
		float* const vmp = tv_vm->data.f32;
		float* const ump = tv_um->data.f32;
		for (i0 = 0; i0 < dim[0]; i0++)
		{
			float* const gp0 = gp + i0 * str_g[0];
			float* const ap0 = ap + i0 * str_a[0];
			float* const mp0 = mp + i0 * str_m[0];
			float* const vp0 = vp + i0 * str_v[0];
			float* const vmp0 = vmp + i0 * str_vm[0];
			float* const bp0 = bp + i0 * str_b[0];
			float* const np0 = np + i0 * str_n[0];
			float* const up0 = up + i0 * str_u[0];
			float* const ump0 = ump + i0 * str_um[0];
			for (i1 = 0; i1 < dim[1]; i1++)
			{
				float* gp1 = gp0 + i1 * str_g[1];
				float* ap1 = ap0 + i1 * str_a[1];
				float* mp1 = mp0 + i1 * str_m[1];
				float* vp1 = vp0 + i1 * str_v[1];
				float* vmp1 = vmp0 + i1 * str_vm[1];
				float* bp1 = bp0 + i1 * str_b[1];
				float* np1 = np0 + i1 * str_n[1];
				float* up1 = up0 + i1 * str_u[1];
				float* ump1 = ump0 + i1 * str_um[1];
				for (i2 = 0; i2 < dim[2]; i2++)
				{
					for (k = 0; k < dim[3]; k++)
					{
						const float gk = scale * gp1[k];
						const float mk = np1[k] = beta1 * mp1[k] + (1 - beta1) * gk;
						const float vk = up1[k] = beta2 * vp1[k] + (1 - beta2) * gk * gk;
						const float vk_hat = vk * inv_bias_correction2;
						const float vk_max_hat = ump1[k] = ccv_max(vmp1[k], vk_hat);
						// Decoupled weight decay first, then the bias-corrected Adam step.
						bp1[k] = ap1[k] - rate_decay * ap1[k] - (mk * rate_inv_bias_correction1) / (sqrtf(vk_max_hat) + epsilon);
					}
					gp1 += str_g[2];
					ap1 += str_a[2];
					mp1 += str_m[2];
					vp1 += str_v[2];
					vmp1 += str_vm[2];
					bp1 += str_b[2];
					np1 += str_n[2];
					up1 += str_u[2];
					ump1 += str_um[2];
				}
			}
		}
	} else {
		// Plain AdamW path (no max-velocity tracking).
		for (i0 = 0; i0 < dim[0]; i0++)
		{
			float* const gp0 = gp + i0 * str_g[0];
			float* const ap0 = ap + i0 * str_a[0];
			float* const mp0 = mp + i0 * str_m[0];
			float* const vp0 = vp + i0 * str_v[0];
			float* const bp0 = bp + i0 * str_b[0];
			float* const np0 = np + i0 * str_n[0];
			float* const up0 = up + i0 * str_u[0];
			for (i1 = 0; i1 < dim[1]; i1++)
			{
				float* gp1 = gp0 + i1 * str_g[1];
				float* ap1 = ap0 + i1 * str_a[1];
				float* mp1 = mp0 + i1 * str_m[1];
				float* vp1 = vp0 + i1 * str_v[1];
				float* bp1 = bp0 + i1 * str_b[1];
				float* np1 = np0 + i1 * str_n[1];
				float* up1 = up0 + i1 * str_u[1];
				for (i2 = 0; i2 < dim[2]; i2++)
				{
					for (k = 0; k < dim[3]; k++)
					{
						const float gk = scale * gp1[k];
						const float mk = np1[k] = beta1 * mp1[k] + (1 - beta1) * gk;
						const float vk = up1[k] = beta2 * vp1[k] + (1 - beta2) * gk * gk;
						// Decoupled weight decay first, then the bias-corrected Adam step.
						bp1[k] = ap1[k] - rate_decay * ap1[k] - (mk * rate_inv_bias_correction1) / (sqrtf(vk * inv_bias_correction2) + epsilon);
					}
					gp1 += str_g[2];
					ap1 += str_a[2];
					mp1 += str_m[2];
					vp1 += str_v[2];
					bp1 += str_b[2];
					np1 += str_n[2];
					up1 += str_u[2];
				}
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
175
176
static int _ccv_nnc_adamw_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	// There is no backward pass for an optimizer update; invoking it is an error.
	return CCV_NNC_EXEC_INVALID;
}
180
181
REGISTER_COMMAND_BACKEND(CCV_NNC_ADAMW_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	// Register the reference CPU kernel for the AdamW forward (update) command.
	registry->exec = _ccv_nnc_adamw_forw;
	registry->algorithms = 1;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->tensor_datatypes = CCV_32F; // float32 only.
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
}
189
190
REGISTER_COMMAND_BACKEND(CCV_NNC_ADAMW_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	// Register the backward stub (always fails — optimizers have no gradient).
	registry->exec = _ccv_nnc_adamw_back;
	registry->algorithms = 1;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->tensor_datatypes = CCV_32F; // float32 only.
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
}