/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/ew/ccv_nnc_ew.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "nnc/ccv_nnc.h" |
3 | | #include "nnc/ccv_nnc_internal.h" |
4 | | |
5 | | static int _ccv_nnc_arbitary_inplace(const ccv_nnc_cmd_param_t cmd, const int input_idx, const int input_size, const int output_idx, const int output_size) |
6 | 11.5k | { |
7 | 11.5k | return 1; |
8 | 11.5k | } |
9 | | |
10 | | static int _ccv_nnc_ewsum_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
11 | 118 | { |
12 | 118 | if (output_size == 1 && output_bitmasks[0] == 1) |
13 | 118 | { |
14 | 118 | int i, j, flag = 0; |
15 | 118 | int input_bitcount = 0; |
16 | 236 | for (i = 0; i < input_bitmask_size; i++) |
17 | 118 | { |
18 | 358 | for (j = 0; j < 64; j++) |
19 | 358 | if (input_bitmasks[i] & (uint64_t)1 << j) |
20 | 240 | { |
21 | 240 | if (flag) |
22 | 0 | return 0; |
23 | 240 | } else |
24 | 118 | break; |
25 | 118 | input_bitcount += j; |
26 | | // A trailing zero, even before the end of input_bitmask_size, marks the flag; |
27 | | // if we encounter an additional 1 afterwards, return invalid. |
28 | 118 | if (j < 64) |
29 | 118 | flag = 1; |
30 | | // Always like 1111100000, no 1110010101 |
31 | 7.43k | for (; j < 64; j++) |
32 | 7.31k | if (input_bitmasks[i] & (uint64_t)1 << j) |
33 | 0 | return 0; |
34 | 118 | } |
35 | 118 | return input_size == input_bitcount; |
36 | 118 | } |
37 | 0 | return 0; |
38 | 118 | } |
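
The bitmask check above only accepts input bitmasks whose set bits form a single run starting at bit 0 and whose total popcount equals input_size, i.e. every input slot must be bound with no gaps and nothing beyond the last one. A minimal standalone sketch of that rule, with an illustrative helper name and sample masks that are not part of the library:

#include <stdint.h>
#include <stdio.h>

// Returns 1 when the set bits of mask form one contiguous run starting at bit 0
// (e.g. 0b0001111) and 0 when there is a gap (e.g. 0b0010111); on success
// *bitcount receives the run length.
static int is_contiguous_prefix(const uint64_t mask, int* const bitcount)
{
	int j = 0;
	while (j < 64 && (mask & ((uint64_t)1 << j)))
		j++;
	*bitcount = j;
	for (; j < 64; j++)
		if (mask & ((uint64_t)1 << j))
			return 0; // A 1 after the first 0 is invalid.
	return 1;
}

int main(void)
{
	int n;
	printf("%d\n", is_contiguous_prefix(0x0f, &n)); // 1, n == 4: all four inputs bound
	printf("%d\n", is_contiguous_prefix(0x0b, &n)); // 0: the input at bit 2 is missing
	return 0;
}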
39 | | |
40 | | static int _ccv_nnc_ewsum_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
41 | 303 | { |
42 | 303 | if (input_size >= 1 && (input_bitmasks[0] & 1u) == 1u) |
43 | 235 | { |
44 | 235 | int i, j, flag = 0; |
45 | 235 | int output_bitcount = 0; |
46 | 458 | for (i = 0; i < output_bitmask_size; i++) |
47 | 235 | { |
48 | 664 | for (j = 0; j < 64; j++) |
49 | 664 | if (output_bitmasks[i] & (uint64_t)1 << j) |
50 | 429 | { |
51 | 429 | if (flag) |
52 | 0 | return 0; |
53 | 429 | } else |
54 | 235 | break; |
55 | 235 | output_bitcount += j; |
56 | | // A trailing zero, even before the end of output_bitmask_size, marks the flag; |
57 | | // if we encounter an additional 1 afterwards, return invalid. |
58 | 235 | if (j < 64) |
59 | 235 | flag = 1; |
60 | | // Always like 1111100000, no 1110010101 |
61 | 14.0k | for (; j < 64; j++) |
62 | 13.8k | if (output_bitmasks[i] & (uint64_t)1 << j) |
63 | 12 | return 0; |
64 | 235 | } |
65 | 223 | return output_size == output_bitcount; |
66 | 235 | } |
67 | 68 | return 0; |
68 | 303 | } |
69 | | |
70 | | REGISTER_COMMAND(CCV_NNC_EWSUM_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
71 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_cudnn.cu, mps/ccv_nnc_ew_mps.m) |
72 | 1 | { |
73 | 1 | registry->bitmask = _ccv_nnc_ewsum_forw_bitmask; |
74 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
75 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
76 | 1 | } |
77 | | |
78 | | REGISTER_COMMAND(CCV_NNC_EWSUM_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
79 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_cudnn.cu, mps/ccv_nnc_ew_mps.m) |
80 | 1 | { |
81 | 1 | registry->flags = CCV_NNC_CMD_ATTR_PASSTHROUGH | CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
82 | 1 | registry->bitmask = _ccv_nnc_ewsum_back_bitmask; |
83 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
84 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
85 | 1 | } |
86 | | |
87 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWSUM_FORWARD) |
88 | | #define CMD_EWSUM_FORWARD() ccv_nnc_cmd(CCV_NNC_EWSUM_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
89 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWSUM_BACKWARD) |
90 | | #define CMD_EWSUM_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWSUM_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
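
A minimal usage sketch for the macro above, assuming the public nnc API (ccv_nnc_cmd_exec, ccv_nnc_tensor_new/free, ccv_nnc_no_hint) and the CPU_TENSOR_NHWC/TENSOR_LIST helpers from nnc/ccv_nnc_easy.h; the tensor shape, fill values, and function name are illustrative only:

#include "ccv.h"
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"

// Sums two 4-element tensors into a third; assumes ccv_nnc_init() has already
// been called. EWSUM takes a variable number of inputs, all summed element-wise
// into the single output.
void ewsum_example(void)
{
	ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	int i;
	for (i = 0; i < 4; i++)
		a->data.f32[i] = i, b->data.f32[i] = 10 * i;
	ccv_nnc_cmd_exec(CMD_EWSUM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, b), TENSOR_LIST(c), 0);
	// c->data.f32 now holds {0, 11, 22, 33}.
	ccv_nnc_tensor_free(a);
	ccv_nnc_tensor_free(b);
	ccv_nnc_tensor_free(c);
}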
91 | | |
92 | | static int _ccv_nnc_ewprod_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
93 | 47 | { |
94 | 47 | if (output_size == 1 && output_bitmasks[0] == 1) |
95 | 47 | { |
96 | 47 | int i, j, flag = 0; |
97 | 47 | int input_bitcount = 0; |
98 | 94 | for (i = 0; i < input_bitmask_size; i++) |
99 | 47 | { |
100 | 141 | for (j = 0; j < 64; j++) |
101 | 141 | if (input_bitmasks[i] & (uint64_t)1 << j) |
102 | 94 | { |
103 | 94 | if (flag) |
104 | 0 | return 0; |
105 | 94 | } else |
106 | 47 | break; |
107 | 47 | input_bitcount += j; |
108 | | // A trailing zero, even before the end of input_bitmask_size, marks the flag; |
109 | | // if we encounter an additional 1 afterwards, return invalid. |
110 | 47 | if (j < 64) |
111 | 47 | flag = 1; |
112 | | // Always like 1111100000, no 1110010101 |
113 | 2.96k | for (; j < 64; j++) |
114 | 2.91k | if (input_bitmasks[i] & (uint64_t)1 << j) |
115 | 0 | return 0; |
116 | 47 | } |
117 | 47 | return input_size == input_bitcount; |
118 | 47 | } |
119 | 0 | return 0; |
120 | 47 | } |
121 | | |
122 | | static int _ccv_nnc_ewprod_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
123 | 11.2k | { |
124 | 11.2k | int i, j; |
125 | 11.2k | int input_flag = 0; |
126 | 11.2k | int input_bitcount = 0; |
127 | 15.8k | for (i = 0; i < input_bitmask_size; i++) |
128 | 11.2k | { |
129 | 33.9k | for (j = 0; j < 64; j++) |
130 | 33.9k | if (input_bitmasks[i] & (uint64_t)1 << j) |
131 | 22.6k | { |
132 | 22.6k | if (input_flag) |
133 | 0 | return 0; |
134 | 22.6k | } else |
135 | 11.2k | break; |
136 | 11.2k | input_bitcount += j; |
137 | 11.2k | if (j < 64) |
138 | 11.2k | input_flag = 1; |
139 | | // Always like 1111100000, no 1110010101 |
140 | 291k | for (; j < 64; j++) |
141 | 287k | if (input_bitmasks[i] & (uint64_t)1 << j) |
142 | 6.76k | return 0; |
143 | 11.2k | } |
144 | 4.52k | int output_flag = 0; |
145 | 4.52k | int output_bitcount = 0; |
146 | 9.05k | for (i = 0; i < output_bitmask_size; i++) |
147 | 4.52k | { |
148 | 13.5k | for (j = 0; j < 64; j++) |
149 | 13.5k | if ((output_bitmasks[i] & (uint64_t)1 << j)) |
150 | 9.03k | { |
151 | 9.03k | if (output_flag) |
152 | 0 | return 0; |
153 | 9.03k | } else |
154 | 4.52k | break; |
155 | 4.52k | output_bitcount += j; |
156 | 4.52k | if (j < 64) |
157 | 4.52k | output_flag = 1; |
158 | 285k | for (; j < 64; j++) |
159 | 280k | if (output_bitmasks[i] & (uint64_t)1 << j) |
160 | 2 | return 0; |
161 | 4.52k | } |
162 | 4.52k | if (output_bitcount != output_size) |
163 | 10 | return 0; |
164 | 4.51k | return output_bitcount + 2 /* Gradient + Original output */ == input_bitcount; |
165 | 4.52k | } |
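
The final check encodes the backward signature of EWPROD: for N forward inputs, the backward command receives the incoming gradient, the N original inputs, and the original output (N + 2 tensors in total, hence the "+ 2" above) and emits N gradients, so input_bitcount must equal output_bitcount + 2. For example, with N = 3 the input bitmask must cover 5 tensors and the output bitmask 3.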
166 | | |
167 | | REGISTER_COMMAND(CCV_NNC_EWPROD_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
168 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c) |
169 | 1 | { |
170 | 1 | registry->bitmask = _ccv_nnc_ewprod_forw_bitmask; |
171 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
172 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
173 | 1 | } |
174 | | |
175 | | REGISTER_COMMAND(CCV_NNC_EWPROD_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
176 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c) |
177 | 1 | { |
178 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
179 | 1 | registry->bitmask = _ccv_nnc_ewprod_back_bitmask; |
180 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
181 | 1 | } |
182 | | |
183 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWPROD_FORWARD) |
184 | | #define CMD_EWPROD_FORWARD() ccv_nnc_cmd(CCV_NNC_EWPROD_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
185 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWPROD_BACKWARD) |
186 | | #define CMD_EWPROD_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWPROD_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
187 | | |
188 | | static int _ccv_nnc_ewdiv_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
189 | 4 | { |
190 | 4 | if ((input_bitmasks[0] & 3u) == ((1u << 0) | (1u << 1)) && output_bitmasks[0] == 1u) |
191 | 2 | return 1; |
192 | | // Numerator can be null (meaning 1). |
193 | 2 | if ((input_bitmasks[0] & 3u) == ((0u << 0) | (1u << 1)) && output_bitmasks[0] == 1u) |
194 | 2 | return 1; |
195 | 0 | return 0; |
196 | 2 | } |
197 | | |
198 | | static int _ccv_nnc_ewdiv_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
199 | 68 | { |
200 | 68 | if ((input_bitmasks[0] & (15u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2) | (1u << 3)) && output_bitmasks[0] == ((1u << 0) | (1u << 1))) |
201 | 5 | return 1; |
202 | | // We don't need to know the original output. |
203 | 63 | if ((input_bitmasks[0] & (15u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2) | (0u << 3)) && output_bitmasks[0] == ((1u << 0) | (0u << 1))) |
204 | 0 | return 1; |
205 | 63 | if ((input_bitmasks[0] & (15u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2) | (1u << 3)) && output_bitmasks[0] == ((0u << 0) | (1u << 1))) |
206 | 15 | return 1; |
207 | 48 | return 0; |
208 | 63 | } |
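
The three accepted patterns follow from the derivatives of c = a / b: da = g / b needs only the gradient (bit 0) and the divisor (bit 2), while db = -g * a / b^2 = -g * c / b also needs the forward output c (bit 3). The numerator a (bit 1) is never required, and the pattern without c may therefore only produce the first output.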
209 | | |
210 | | static void _ccv_nnc_ewdiv_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size) |
211 | 46 | { |
212 | 46 | assert(output_size >= 1); |
213 | 46 | assert(input_size >= 2); |
214 | 46 | int i; |
215 | 92 | for (i = 0; i < output_size; i++) |
216 | 46 | outputs[i] = inputs[1]; |
217 | 46 | } |
218 | | |
219 | | REGISTER_COMMAND(CCV_NNC_EWDIV_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
220 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
221 | 1 | { |
222 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
223 | 1 | registry->bitmask = _ccv_nnc_ewdiv_forw_bitmask; |
224 | 1 | registry->tensor_auto = _ccv_nnc_ewdiv_tensor_auto_forw; |
225 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
226 | 1 | } |
227 | | |
228 | | REGISTER_COMMAND(CCV_NNC_EWDIV_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
229 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
230 | 1 | { |
231 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
232 | 1 | registry->bitmask = _ccv_nnc_ewdiv_back_bitmask; |
233 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
234 | 1 | } |
235 | | |
236 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWDIV_FORWARD) |
237 | | #define CMD_EWDIV_FORWARD() ccv_nnc_cmd(CCV_NNC_EWDIV_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
238 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWDIV_BACKWARD) |
239 | | #define CMD_EWDIV_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWDIV_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
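
Because the forward registry sets CCV_NNC_CMD_ATTR_NULL_IS_ONES and _ccv_nnc_ewdiv_forw_bitmask accepts a missing numerator, passing a null first input turns EWDIV into an element-wise reciprocal; _ccv_nnc_ewdiv_tensor_auto_forw accordingly derives the output shape from inputs[1], the divisor. A minimal sketch, assuming the same public nnc API and includes as the CMD_EWSUM_FORWARD example above (the function name is illustrative):

// c = 1 / b element-wise: the numerator slot is passed as null, matching the
// ((0u << 0) | (1u << 1)) case accepted by _ccv_nnc_ewdiv_forw_bitmask.
void ewdiv_reciprocal(ccv_nnc_tensor_t* const b, ccv_nnc_tensor_t* const c)
{
	ccv_nnc_cmd_exec(CMD_EWDIV_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(0, b), TENSOR_LIST(c), 0);
}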
240 | | |
241 | | static int _ccv_nnc_ewexp_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
242 | 0 | { |
243 | 0 | if ((input_bitmasks[0] & 1u) == 1u && output_bitmasks[0] == 1u) |
244 | 0 | return 1; |
245 | 0 | return 0; |
246 | 0 | } |
247 | | |
248 | | static int _ccv_nnc_ewexp_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
249 | 48 | { |
250 | | // We don't care about the original input. |
251 | 48 | if ((input_bitmasks[0] & (7u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2)) && output_bitmasks[0] == 1u) |
252 | 16 | return 1; |
253 | 32 | return 0; |
254 | 48 | } |
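
The accepted pattern reflects that d/dx exp(x) = exp(x): the backward pass needs only the incoming gradient (bit 0) and the forward output y = exp(x) (bit 2), the original input (bit 1) is masked out, and the gradient is g * y.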
255 | | |
256 | | REGISTER_COMMAND(CCV_NNC_EWEXP_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
257 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
258 | 1 | { |
259 | 1 | registry->bitmask = _ccv_nnc_ewexp_forw_bitmask; |
260 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
261 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
262 | 1 | } |
263 | | |
264 | | REGISTER_COMMAND(CCV_NNC_EWEXP_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
265 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
266 | 1 | { |
267 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
268 | 1 | registry->bitmask = _ccv_nnc_ewexp_back_bitmask; |
269 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
270 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
271 | 1 | } |
272 | | |
273 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWEXP_FORWARD) |
274 | | #define CMD_EWEXP_FORWARD() ccv_nnc_cmd(CCV_NNC_EWEXP_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
275 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWEXP_BACKWARD) |
276 | | #define CMD_EWEXP_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWEXP_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
277 | | |
278 | | static int _ccv_nnc_ewlog_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
279 | 0 | { |
280 | 0 | if ((input_bitmasks[0] & 1u) == 1u && output_bitmasks[0] == 1u) |
281 | 0 | return 1; |
282 | 0 | return 0; |
283 | 0 | } |
284 | | |
285 | | static int _ccv_nnc_ewlog_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
286 | 1.31k | { |
287 | | // We don't care about the original output. |
288 | 1.31k | if ((input_bitmasks[0] & 3u) == 3u && output_bitmasks[0] == 1u) |
289 | 438 | return 1; |
290 | 872 | return 0; |
291 | 1.31k | } |
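
Conversely, d/dx log(x) = 1/x, so EWLOG's backward needs the incoming gradient (bit 0) and the original input x (bit 1) but not the forward output; the gradient is g / x.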
292 | | |
293 | | REGISTER_COMMAND(CCV_NNC_EWLOG_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
294 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
295 | 1 | { |
296 | 1 | registry->bitmask = _ccv_nnc_ewlog_forw_bitmask; |
297 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
298 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
299 | 1 | } |
300 | | |
301 | | REGISTER_COMMAND(CCV_NNC_EWLOG_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
302 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
303 | 1 | { |
304 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
305 | 1 | registry->bitmask = _ccv_nnc_ewlog_back_bitmask; |
306 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
307 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
308 | 1 | } |
309 | | |
310 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWLOG_FORWARD) |
311 | | #define CMD_EWLOG_FORWARD() ccv_nnc_cmd(CCV_NNC_EWLOG_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
312 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWLOG_BACKWARD) |
313 | | #define CMD_EWLOG_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWLOG_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
314 | | |
315 | | static int _ccv_nnc_ewsqrt_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
316 | 0 | { |
317 | 0 | if ((input_bitmasks[0] & 1u) == 1u && output_bitmasks[0] == 1u) |
318 | 0 | return 1; |
319 | 0 | return 0; |
320 | 0 | } |
321 | | |
322 | | static int _ccv_nnc_ewsqrt_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
323 | 6 | { |
324 | | // We don't care about the original input. |
325 | 6 | if ((input_bitmasks[0] & (7u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2)) && output_bitmasks[0] == 1u) |
326 | 2 | return 1; |
327 | 4 | return 0; |
328 | 6 | } |
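
For EWSQRT, d/dx sqrt(x) = 1 / (2 * sqrt(x)) is expressible through the forward output y = sqrt(x) alone, so the backward again requires only the gradient (bit 0) and the forward output (bit 2); the gradient is g / (2 * y).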
329 | | |
330 | | REGISTER_COMMAND(CCV_NNC_EWSQRT_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
331 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
332 | 1 | { |
333 | 1 | registry->bitmask = _ccv_nnc_ewsqrt_forw_bitmask; |
334 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
335 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
336 | 1 | } |
337 | | |
338 | | REGISTER_COMMAND(CCV_NNC_EWSQRT_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
339 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
340 | 1 | { |
341 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
342 | 1 | registry->bitmask = _ccv_nnc_ewsqrt_back_bitmask; |
343 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
344 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
345 | 1 | } |
346 | | |
347 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWSQRT_FORWARD) |
348 | | #define CMD_EWSQRT_FORWARD() ccv_nnc_cmd(CCV_NNC_EWSQRT_FORWARD, 0, ccv_nnc_cmd_auto, 0) |
349 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_EWSQRT_BACKWARD) |
350 | | #define CMD_EWSQRT_BACKWARD() ccv_nnc_cmd(CCV_NNC_EWSQRT_BACKWARD, 0, ccv_nnc_cmd_auto, 0) |
351 | | |
352 | | static int _ccv_nnc_clamp_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
353 | 0 | { |
354 | 0 | if ((input_bitmasks[0] & 1u) == 1u && output_bitmasks[0] == 1u) |
355 | 0 | return 1; |
356 | 0 | return 0; |
357 | 0 | } |
358 | | |
359 | | static int _ccv_nnc_clamp_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
360 | 0 | { |
361 | | // We don't care about the original input. |
362 | 0 | if ((input_bitmasks[0] & (7u & ~((uint64_t)1u << 1))) == ((1u << 0) | (0u << 1) | (1u << 2)) && output_bitmasks[0] == 1u) |
363 | 0 | return 1; |
364 | 0 | return 0; |
365 | 0 | } |
366 | | |
367 | | REGISTER_COMMAND(CCV_NNC_CLAMP_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
368 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
369 | 1 | { |
370 | 1 | registry->bitmask = _ccv_nnc_clamp_forw_bitmask; |
371 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_forward_from_inputs; |
372 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
373 | 1 | } |
374 | | |
375 | | REGISTER_COMMAND(CCV_NNC_CLAMP_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
376 | | FIND_BACKEND(ccv_nnc_ew_cpu_ref.c, gpu/ccv_nnc_ew_gpu_ref.cu, mps/ccv_nnc_ew_mps.m) |
377 | 1 | { |
378 | 1 | registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES; |
379 | 1 | registry->bitmask = _ccv_nnc_clamp_back_bitmask; |
380 | 1 | registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient; |
381 | 1 | registry->allow_inplace = _ccv_nnc_arbitary_inplace; |
382 | 1 | } |
383 | | |
384 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CLAMP_FORWARD) |
385 | | #define CMD_CLAMP_FORWARD(_min, _max) ccv_nnc_cmd(CCV_NNC_CLAMP_FORWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.clamp={.min=_min,.max=_max}}, 0) |
386 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CLAMP_BACKWARD) |
387 | | #define CMD_CLAMP_BACKWARD(_min, _max) ccv_nnc_cmd(CCV_NNC_CLAMP_BACKWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.clamp={.min=_min,.max=_max}}, 0) |
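
A minimal usage sketch for the clamp macros above, assuming the same public nnc API and includes as the earlier examples; the bounds 0 and 6 and the function name are illustrative:

// b = clamp(a, 0, 6), i.e. b[i] = min(max(a[i], 0), 6) element-wise.
void clamp_example(ccv_nnc_tensor_t* const a, ccv_nnc_tensor_t* const b)
{
	ccv_nnc_cmd_exec(CMD_CLAMP_FORWARD(0, 6), ccv_nnc_no_hint, 0, TENSOR_LIST(a), TENSOR_LIST(b), 0);
}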