/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/adam/ccv_nnc_adam.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "nnc/ccv_nnc.h" |
3 | | #include "nnc/ccv_nnc_internal.h" |
4 | | |
5 | | static int _ccv_nnc_adam_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
6 | 24 | { |
7 | 24 | if (cmd.adam.amsgrad) |
8 | 0 | { |
9 | | // 5 inputs (gradient, x, momentum, velocity, v_max) |
10 | | // 4 outputs (y, new momentum, new velocity, new v_max) |
11 | 0 | if (input_bitmasks[0] == 31u && output_bitmasks[0] == 15u) |
12 | 0 | return 1; |
13 | 24 | } else { |
14 | | // 4 inputs (gradient, x, momentum, velocity) |
15 | | // 3 outputs (y, new momentum, new velocity) |
16 | 24 | if (input_bitmasks[0] == 15u && output_bitmasks[0] == 7u8 ) |
17 | 8 | return 1; |
18 | 24 | } |
19 | 16 | return 0; |
20 | 24 | } |
21 | | |
22 | | static int _ccv_nnc_adam_allow_inplace(const ccv_nnc_cmd_param_t cmd, const int input_idx, const int input_size, const int output_idx, const int output_size) |
23 | 2.08k | { |
24 | 2.08k | if (input_idx == output_idx + 1) |
25 | 454 | return 1; |
26 | 1.63k | return 0; |
27 | 2.08k | } |
28 | | |
29 | | static int _ccv_nnc_adam_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
30 | 0 | { |
31 | | // Doesn't support. |
32 | 0 | return 0; |
33 | 0 | } |
34 | | |
35 | | static void _ccv_nnc_adam_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size) |
36 | 1.13k | { |
37 | 1.13k | int i; |
38 | 4.55k | for (i = 0; i < output_size; i++3.41k ) |
39 | 3.41k | outputs[i] = inputs[0]; |
40 | 1.13k | } |
41 | | |
42 | | static void _ccv_nnc_adam_tensor_auto_back(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size) |
43 | 0 | { |
44 | | // Doesn't support. |
45 | 0 | } |
46 | | |
47 | | REGISTER_COMMAND(CCV_NNC_ADAM_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
48 | | FIND_BACKEND(ccv_nnc_adam_cpu_ref.c, gpu/ccv_nnc_adam_gpu_ref.cu, mps/ccv_nnc_adam_mps.m) |
49 | 1 | { |
50 | 1 | registry->bitmask = _ccv_nnc_adam_forw_bitmask; |
51 | 1 | registry->tensor_auto = _ccv_nnc_adam_tensor_auto_forw; |
52 | 1 | registry->allow_inplace = _ccv_nnc_adam_allow_inplace; |
53 | 1 | } |
54 | | |
55 | | REGISTER_COMMAND(CCV_NNC_ADAM_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
56 | | FIND_BACKEND(ccv_nnc_adam_cpu_ref.c, gpu/ccv_nnc_adam_gpu_ref.cu, mps/ccv_nnc_adam_mps.m) |
57 | 1 | { |
58 | 1 | registry->bitmask = _ccv_nnc_adam_back_bitmask; |
59 | 1 | registry->tensor_auto = _ccv_nnc_adam_tensor_auto_back; |
60 | 1 | } |
61 | | |
62 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_ADAM_FORWARD) |
63 | | #define CMD_ADAM_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon, _amsgrad) ccv_nnc_cmd(CCV_NNC_ADAM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon,.amsgrad=_amsgrad}}), 0) |
64 | | |
65 | | REGISTER_COMMAND(CCV_NNC_ADAMW_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
66 | | FIND_BACKEND(ccv_nnc_adamw_cpu_ref.c, gpu/ccv_nnc_adamw_gpu_ref.cu, mps/ccv_nnc_adamw_mps.m) |
67 | 1 | { |
68 | 1 | registry->bitmask = _ccv_nnc_adam_forw_bitmask; |
69 | 1 | registry->tensor_auto = _ccv_nnc_adam_tensor_auto_forw; |
70 | 1 | registry->allow_inplace = _ccv_nnc_adam_allow_inplace; |
71 | 1 | } |
72 | | |
73 | | REGISTER_COMMAND(CCV_NNC_ADAMW_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
74 | | FIND_BACKEND(ccv_nnc_adamw_cpu_ref.c, gpu/ccv_nnc_adamw_gpu_ref.cu, mps/ccv_nnc_adamw_mps.m) |
75 | 1 | { |
76 | 1 | registry->bitmask = _ccv_nnc_adam_back_bitmask; |
77 | 1 | registry->tensor_auto = _ccv_nnc_adam_tensor_auto_back; |
78 | 1 | } |
79 | | |
80 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_ADAMW_FORWARD) |
81 | | #define CMD_ADAMW_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon, _amsgrad) ccv_nnc_cmd(CCV_NNC_ADAMW_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon,.amsgrad=_amsgrad}}), 0) |