/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/adam/ccv_nnc_adamw_cpu_ref.c
#include "ccv.h"
#include "ccv_internal.h"
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"
#include "nnc/ccv_nnc_internal.h"
#ifdef USE_OPENMP
#include <omp.h>
#endif
#ifdef USE_DISPATCH
#include <dispatch/dispatch.h>
#endif

// Shared methods.
#include "../_ccv_nnc_cpu_ref.h"

static int _ccv_nnc_adamw_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size >= 4);
	assert(output_size >= 3);
	ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[1];
	ccv_nnc_tensor_view_t* const m = (ccv_nnc_tensor_view_t*)inputs[2];
	ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[3];
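	// vm (input) and um (output) are the optional running-max second-moment
	// buffers; they are only read and written when amsgrad is enabled.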
	ccv_nnc_tensor_view_t* const vm = input_size >= 5 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
	ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
	ccv_nnc_tensor_view_t* const n = (ccv_nnc_tensor_view_t*)outputs[1];
	ccv_nnc_tensor_view_t* const u = (ccv_nnc_tensor_view_t*)outputs[2];
	ccv_nnc_tensor_view_t* const um = output_size >= 4 ? (ccv_nnc_tensor_view_t*)outputs[3] : 0;
	assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
	assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
	// Assuming this is float 32.
	int adim[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_dim(a, adim);
	assert(ccv_nnc_tensor_view_check_dim(g, adim));
	assert(ccv_nnc_tensor_view_check_dim(m, adim));
	assert(ccv_nnc_tensor_view_check_dim(v, adim));
	if (vm)
		{ assert(ccv_nnc_tensor_view_check_dim(vm, adim)); }
	assert(ccv_nnc_tensor_view_check_dim(b, adim));
	assert(ccv_nnc_tensor_view_check_dim(n, adim));
	assert(ccv_nnc_tensor_view_check_dim(u, adim));
	if (um)
		{ assert(ccv_nnc_tensor_view_check_dim(um, adim)); }
	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
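	// Strides are gathered per tensor so the loops below can walk
	// non-contiguous tensor views as well as packed tensors.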
	int gstride[CCV_NNC_MAX_DIM_ALLOC];
	int astride[CCV_NNC_MAX_DIM_ALLOC];
	int mstride[CCV_NNC_MAX_DIM_ALLOC];
	int vstride[CCV_NNC_MAX_DIM_ALLOC];
	int vmstride[CCV_NNC_MAX_DIM_ALLOC];
	int bstride[CCV_NNC_MAX_DIM_ALLOC];
	int nstride[CCV_NNC_MAX_DIM_ALLOC];
	int ustride[CCV_NNC_MAX_DIM_ALLOC];
	int umstride[CCV_NNC_MAX_DIM_ALLOC];
	ccv_nnc_tensor_view_get_stride(g, gstride);
	ccv_nnc_tensor_view_get_stride(a, astride);
	ccv_nnc_tensor_view_get_stride(m, mstride);
	ccv_nnc_tensor_view_get_stride(v, vstride);
	if (vm)
		ccv_nnc_tensor_view_get_stride(vm, vmstride);
	ccv_nnc_tensor_view_get_stride(b, bstride);
	ccv_nnc_tensor_view_get_stride(n, nstride);
	ccv_nnc_tensor_view_get_stride(u, ustride);
	if (um)
		ccv_nnc_tensor_view_get_stride(um, umstride);
	const int step = cmd.info.adam.step;
	const float rate = cmd.info.adam.rate;
	const float scale = cmd.info.adam.scale;
	const float beta1 = cmd.info.adam.beta1;
	const float beta2 = cmd.info.adam.beta2;
	const float decay = cmd.info.adam.decay;
	const float epsilon = cmd.info.adam.epsilon;
	assert(step >= 1);
	const float rate_inv_bias_correction1 = rate / (1 - powf(beta1, step));
	const float inv_bias_correction2 = 1. / (1 - powf(beta2, step));
	const float rate_decay = rate * decay;
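	// With bias corrections m_hat = m / (1 - beta1^step) and
	// v_hat = v / (1 - beta2^step), the decoupled-weight-decay (AdamW)
	// update computed below is
	//   b = a - rate * decay * a - rate * m_hat / (sqrt(v_hat) + epsilon),
	// with rate / (1 - beta1^step) pre-folded into rate_inv_bias_correction1.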
	int i[CCV_NNC_MAX_DIM + 1];
	int x;
	float* const gp = g->data.f32;
	float* const ap = a->data.f32;
	float* const mp = m->data.f32;
	float* const vp = v->data.f32;
	float* const bp = b->data.f32;
	float* const np = n->data.f32;
	float* const up = u->data.f32;
	if (cmd.info.adam.amsgrad && vm && um)
	{
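		// AMSGrad variant: um keeps the element-wise running maximum of the
		// bias-corrected second moment (seeded from vm), and the denominator
		// uses that maximum instead of the current v_hat.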
		float* const vmp = vm->data.f32;
		float* const ump = um->data.f32;
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const gp0 = gp + i[0] * gstride[0];
			float* const ap0 = ap + i[0] * astride[0];
			float* const mp0 = mp + i[0] * mstride[0];
			float* const vp0 = vp + i[0] * vstride[0];
			float* const vmp0 = vmp + i[0] * vmstride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const np0 = np + i[0] * nstride[0];
			float* const up0 = up + i[0] * ustride[0];
			float* const ump0 = ump + i[0] * umstride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* gp1 = gp0 + i[1] * gstride[1];
				float* ap1 = ap0 + i[1] * astride[1];
				float* mp1 = mp0 + i[1] * mstride[1];
				float* vp1 = vp0 + i[1] * vstride[1];
				float* vmp1 = vmp0 + i[1] * vmstride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* np1 = np0 + i[1] * nstride[1];
				float* up1 = up0 + i[1] * ustride[1];
				float* ump1 = ump0 + i[1] * umstride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					for (x = 0; x < adim[3]; x++)
					{
						const float grad = scale * gp1[x];
						const float mom = np1[x] = beta1 * mp1[x] + (1 - beta1) * grad;
						const float vel = up1[x] = beta2 * vp1[x] + (1 - beta2) * grad * grad;
						const float vel_hat = vel * inv_bias_correction2;
						const float vel_max_hat = ump1[x] = ccv_max(vmp1[x], vel_hat);
						bp1[x] = ap1[x] - rate_decay * ap1[x] - (mom * rate_inv_bias_correction1) / (sqrtf(vel_max_hat) + epsilon);
					}
					gp1 += gstride[2];
					ap1 += astride[2];
					mp1 += mstride[2];
					vp1 += vstride[2];
					vmp1 += vmstride[2];
					bp1 += bstride[2];
					np1 += nstride[2];
					up1 += ustride[2];
					ump1 += umstride[2];
				}
			}
		}
	} else {
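		// Plain AdamW path: same update, but the denominator uses the current
		// bias-corrected second moment directly.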
		for (i[0] = 0; i[0] < adim[0]; i[0]++)
		{
			float* const gp0 = gp + i[0] * gstride[0];
			float* const ap0 = ap + i[0] * astride[0];
			float* const mp0 = mp + i[0] * mstride[0];
			float* const vp0 = vp + i[0] * vstride[0];
			float* const bp0 = bp + i[0] * bstride[0];
			float* const np0 = np + i[0] * nstride[0];
			float* const up0 = up + i[0] * ustride[0];
			for (i[1] = 0; i[1] < adim[1]; i[1]++)
			{
				float* gp1 = gp0 + i[1] * gstride[1];
				float* ap1 = ap0 + i[1] * astride[1];
				float* mp1 = mp0 + i[1] * mstride[1];
				float* vp1 = vp0 + i[1] * vstride[1];
				float* bp1 = bp0 + i[1] * bstride[1];
				float* np1 = np0 + i[1] * nstride[1];
				float* up1 = up0 + i[1] * ustride[1];
				for (i[2] = 0; i[2] < adim[2]; i[2]++)
				{
					for (x = 0; x < adim[3]; x++)
					{
						const float grad = scale * gp1[x];
						const float mom = np1[x] = beta1 * mp1[x] + (1 - beta1) * grad;
						const float vel = up1[x] = beta2 * vp1[x] + (1 - beta2) * grad * grad;
						bp1[x] = ap1[x] - rate_decay * ap1[x] - (mom * rate_inv_bias_correction1) / (sqrtf(vel * inv_bias_correction2) + epsilon);
					}
					gp1 += gstride[2];
					ap1 += astride[2];
					mp1 += mstride[2];
					vp1 += vstride[2];
					bp1 += bstride[2];
					np1 += nstride[2];
					up1 += ustride[2];
				}
			}
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}

static int _ccv_nnc_adamw_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
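	// No backward pass is provided for this optimizer command; executing it
	// is reported as an error.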
	return CCV_NNC_EXEC_INVALID;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_ADAMW_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_adamw_forw;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_ADAMW_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_adamw_back;
}
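
// Usage sketch (not part of the original file): one way a caller might drive
// this backend through the public API. The CMD_ADAMW_FORWARD macro and its
// parameter order (step, rate, scale, beta1, beta2, decay, epsilon) are an
// assumption inferred from the cmd.info.adam fields read above; TENSOR_LIST
// and ccv_nnc_no_hint come from nnc/ccv_nnc_easy.h and nnc/ccv_nnc.h.
static void example_adamw_step(ccv_nnc_tensor_t* const g, ccv_nnc_tensor_t* const a, ccv_nnc_tensor_t* const m, ccv_nnc_tensor_t* const v)
{
	// Inputs: gradient, parameter, 1st and 2nd moment estimates; outputs
	// write the updated parameter and both moments back in place.
	ccv_nnc_cmd_exec(CMD_ADAMW_FORWARD(1, 0.001, 1, 0.9, 0.999, 0.01, 1e-8), ccv_nnc_no_hint, 0, TENSOR_LIST(g, a, m, v), TENSOR_LIST(a, m, v), 0);
}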