/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "nnc/ccv_nnc.h" |
3 | | #include "nnc/ccv_nnc_internal.h" |
4 | | |
5 | | static int _ccv_nnc_scaled_dot_product_attention_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
6 | 14 | { |
7 | | // 6 inputs (query, key, value, [attn_mask], [unify head weight], [unify head bias]) |
8 | | // 3 outputs (y, softmax_lse, [qkv]) |
9 | 14 | if (input_size == 6 && (input_bitmasks[0] & 55u) == 55u && (output_bitmasks[0] & 7u) == 7u) |
10 | 4 | return 1; |
11 | 10 | if (input_size == 5 && (input_bitmasks[0] & 23u) == 23u && (output_bitmasks[0] & 7u) == 7u) |
12 | 0 | return 1; |
13 | 10 | if ((input_bitmasks[0] & 55u) == 7u && (output_bitmasks[0] & 3u) == 3u) |
14 | 6 | return 1; |
15 | 4 | return 0; |
16 | 10 | } |
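
The constants in the checks above are slot bitmasks: bit i of input_bitmasks[0] (or output_bitmasks[0]) marks that tensor slot i is bound, with the slot order given by the comments at the top of the function. A small stand-alone sketch (not part of this file, purely a reading aid) that decodes them:

#include <assert.h>

int main(void)
{
	// Inputs: bit 0 query, 1 key, 2 value, 3 attn_mask, 4 unify head weight, 5 unify head bias.
	assert(55u == (1u << 0 | 1u << 1 | 1u << 2 | 1u << 4 | 1u << 5)); // q, k, v, weight, bias; attn_mask stays optional
	assert(23u == (1u << 0 | 1u << 1 | 1u << 2 | 1u << 4)); // q, k, v, weight; no bias
	assert(7u == (1u << 0 | 1u << 1 | 1u << 2)); // plain attention: q, k, v only
	// Outputs: bit 0 y, 1 softmax_lse, 2 qkv.
	assert(7u == (1u << 0 | 1u << 1 | 1u << 2)); // y, softmax_lse, qkv (with projection weights)
	assert(3u == (1u << 0 | 1u << 1)); // y, softmax_lse only
	return 0;
}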
17 | | |
18 | | |
19 | | static int _ccv_nnc_allow_query_inplace(const ccv_nnc_cmd_param_t cmd, const int input_idx, const int input_size, const int output_idx, const int output_size) |
20 | 49 | { |
21 | 49 | if (input_idx == 0 && output_idx == 0) |
22 | 6 | return 1; |
23 | 43 | return 0; |
24 | 49 | } |
25 | | |
26 | | static int _ccv_nnc_scaled_dot_product_attention_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size) |
27 | 1 | { |
28 | | // 1, 0, 0, 8, 16, 32, 64?, 128?, 256?, 512, 1024, 2048? |
29 | | // 1, 2, 4, 8, 16, 32 |
30 | | // Inputs (gradient, 0, 0, q, k, v, [attn_mask], [head weight], [bias], y, saved softmax_lse, qkv) |
31 | | // Output (dquery, dkey, dvalue, [attn mask], dweight, dbias) [cannot diff against attn_mask] |
32 | 1 | if ((input_bitmasks[0] & 4025u) == 4025u && (output_bitmasks[0] & 63u) == 55u) |
33 | 0 | return 1; |
34 | 1 | if ((input_bitmasks[0] & 3769u) == 3769u && (output_bitmasks[0] & 31u) == 23u) |
35 | 0 | return 1; |
36 | 1 | if ((input_bitmasks[0] & 1593u) == 1593u && (output_bitmasks[0] & 7u) == 7u) |
37 | 0 | return 1; |
38 | 1 | return 0; |
39 | 1 | } |
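
The backward masks decompose the same way, following the slot order in the comments above (again a stand-alone reading aid, not part of this file):

#include <assert.h>

int main(void)
{
	// Inputs: bit 0 gradient, 3 q, 4 k, 5 v, 6 attn_mask, 7 weight, 8 bias, 9 y, 10 softmax_lse, 11 qkv.
	assert(4025u == (1u << 0 | 1u << 3 | 1u << 4 | 1u << 5 | 1u << 7 | 1u << 8 | 1u << 9 | 1u << 10 | 1u << 11));
	assert(3769u == (4025u & ~(1u << 8))); // no bias slot
	assert(1593u == (3769u & ~(1u << 7) & ~(1u << 11))); // plain attention: no weight, no intermediate qkv
	// Outputs: bit 0 dquery, 1 dkey, 2 dvalue, 3 dattn_mask (never produced), 4 dweight, 5 dbias.
	assert(55u == (1u << 0 | 1u << 1 | 1u << 2 | 1u << 4 | 1u << 5));
	assert(23u == (55u & ~(1u << 5)));
	assert(7u == (23u & ~(1u << 4)));
	return 0;
}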
40 | | |
41 | | static void _ccv_nnc_scaled_dot_product_attention_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size) |
42 | 19 | { |
43 | 19 | assert(input_size >= 3); |
44 | 19 | assert(output_size >= 1); |
45 | 19 | const int q_nd = ccv_nnc_tensor_nd(inputs[0].dim); |
46 | 19 | assert(q_nd == 3 || q_nd == 4); |
47 | 19 | const int k_nd = ccv_nnc_tensor_nd(inputs[1].dim); |
48 | 19 | assert(k_nd == 3 || k_nd == 4); |
49 | 19 | const int v_nd = ccv_nnc_tensor_nd(inputs[2].dim); |
50 | 19 | assert(v_nd == 3 || v_nd == 4); |
51 | 19 | assert(q_nd == k_nd && k_nd == v_nd); |
52 | 19 | if (input_size > 4) |
53 | 12 | { |
54 | 12 | assert(output_size >= 3); |
55 | 12 | outputs[0] = inputs[0]; |
56 | 12 | outputs[0].dim[1] = inputs[0].dim[1]; // sequence length matches query, embedding size matches value * num_head. |
57 | 12 | outputs[0].dim[2] = inputs[2].dim[v_nd - 1] * (q_nd == 4 ? inputs[0].dim[2] : 1); |
58 | 12 | outputs[0].dim[3] = 0; |
59 | | // This is saved softmax_lse, which would be in 32F if exists. |
60 | 12 | outputs[1] = inputs[0]; |
61 | 12 | outputs[1].dim[q_nd - 3] = inputs[0].dim[q_nd - 2]; |
62 | 12 | outputs[1].dim[q_nd - 2] = inputs[0].dim[q_nd - 3]; |
63 | 12 | outputs[1].dim[q_nd - 1] = 0; |
64 | 12 | outputs[1].datatype = CCV_32F; |
65 | 12 | outputs[2] = inputs[0]; |
66 | 12 | outputs[2].dim[q_nd - 1] = inputs[2].dim[v_nd - 1]; // sequence length matches query, embedding size matches value. |
67 | 12 | } else { |
68 | 7 | outputs[0] = inputs[0]; |
69 | 7 | outputs[0].dim[q_nd - 1] = inputs[2].dim[v_nd - 1]; // sequence length matches query, embedding size matches value. |
70 | 7 | if (output_size == 1) |
71 | 3 | return; |
72 | 7 | assert(output_size > 1); |
73 | | // This is saved softmax_lse, which would be in 32F if exists. |
74 | 4 | outputs[1] = inputs[0]; |
75 | 4 | outputs[1].dim[q_nd - 3] = inputs[0].dim[q_nd - 2]; |
76 | 4 | outputs[1].dim[q_nd - 2] = inputs[0].dim[q_nd - 3]; |
77 | 4 | outputs[1].dim[q_nd - 1] = 0; |
78 | 4 | outputs[1].datatype = CCV_32F; |
79 | 4 | } |
80 | 19 | } |
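
The shape inference above is easier to see with concrete numbers. A stand-alone sketch (dimensions invented; it assumes the easy helpers from nnc/ccv_nnc_easy.h and the CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD macro defined at the bottom of this file):

#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>

int main(void)
{
	ccv_nnc_init();
	const ccv_nnc_tensor_param_t q = CPU_TENSOR_NHWC(16F, 2, 128, 8, 64); // batch, seq_q, heads, dim_qk
	const ccv_nnc_tensor_param_t k = CPU_TENSOR_NHWC(16F, 2, 256, 8, 64); // batch, seq_kv, heads, dim_qk
	const ccv_nnc_tensor_param_t v = CPU_TENSOR_NHWC(16F, 2, 256, 8, 96); // batch, seq_kv, heads, dim_v
	const ccv_nnc_tensor_param_t inputs[] = { q, k, v };
	ccv_nnc_tensor_param_t outputs[2];
	ccv_nnc_hint_tensor_auto(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(0.125, 0), inputs, 3, ccv_nnc_no_hint, outputs, 2);
	// outputs[0] (y): 2 x 128 x 8 x 96, same 16F datatype as the query.
	// outputs[1] (softmax_lse): 2 x 8 x 128, forced to CCV_32F.
	return 0;
}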
81 | | |
82 | | static void _ccv_nnc_scaled_dot_product_attention_tensor_auto_back(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size) |
83 | 1 | { |
84 | 1 | assert(input_size >= 6); |
85 | 1 | assert(output_size >= 3); |
86 | 1 | int i; |
87 | 4 | for (i = 0; i < output_size; i++) |
88 | 3 | outputs[i] = inputs[3 + i]; |
89 | 1 | } |
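
A matching sketch for the backward shape inference, which simply hands the q, k, v parameters through to the dquery, dkey, dvalue slots (zeroed params stand in for the two unused input slots; same assumptions as the forward sketch above):

#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>

int main(void)
{
	ccv_nnc_init();
	const ccv_nnc_tensor_param_t q = CPU_TENSOR_NHWC(32F, 2, 128, 8, 64);
	const ccv_nnc_tensor_param_t k = CPU_TENSOR_NHWC(32F, 2, 256, 8, 64);
	const ccv_nnc_tensor_param_t v = CPU_TENSOR_NHWC(32F, 2, 256, 8, 96);
	const ccv_nnc_tensor_param_t g = CPU_TENSOR_NHWC(32F, 2, 128, 8, 96); // gradient of y
	const ccv_nnc_tensor_param_t zero = {0};
	const ccv_nnc_tensor_param_t inputs[] = { g, zero, zero, q, k, v };
	ccv_nnc_tensor_param_t outputs[3];
	ccv_nnc_hint_tensor_auto(CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(0.125, 0), inputs, 6, ccv_nnc_no_hint, outputs, 3);
	// outputs[0..2] (dquery, dkey, dvalue) now carry the shapes of q, k and v.
	return 0;
}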
90 | | |
91 | | REGISTER_COMMAND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD)(ccv_nnc_cmd_registry_t* const registry) |
92 | | FIND_BACKEND(ccv_nnc_scaled_dot_product_attention_cpu_ref.c, mps/ccv_nnc_scaled_dot_product_attention_mps.m, gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu) |
93 | 1 | { |
94 | 1 | registry->bitmask = _ccv_nnc_scaled_dot_product_attention_forw_bitmask; |
95 | 1 | registry->tensor_auto = _ccv_nnc_scaled_dot_product_attention_tensor_auto_forw; |
96 | 1 | registry->allow_inplace = _ccv_nnc_allow_query_inplace; |
97 | 1 | } |
98 | | |
99 | | REGISTER_COMMAND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) |
100 | | FIND_BACKEND(ccv_nnc_scaled_dot_product_attention_cpu_ref.c, mps/ccv_nnc_scaled_dot_product_attention_mps.m, gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu) |
101 | 1 | { |
102 | 1 | registry->bitmask = _ccv_nnc_scaled_dot_product_attention_back_bitmask; |
103 | 1 | registry->tensor_auto = _ccv_nnc_scaled_dot_product_attention_tensor_auto_back; |
104 | 1 | } |
105 | | |
106 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD) |
107 | | #define CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(_scale, _is_causal) ccv_nnc_cmd(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.scaled_dot_product_attention={.scale=_scale,.is_causal=_is_causal}}), 0) |
108 | | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD) |
109 | | #define CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(_scale, _is_causal) ccv_nnc_cmd(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.scaled_dot_product_attention={.scale=_scale,.is_causal=_is_causal}}), 0) |
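
A minimal usage sketch for the forward macro (not from this file: shapes, the scale and the CPU reference backend are assumptions, and the optional attn_mask and softmax_lse tensors are left out):

#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>

int main(void)
{
	ccv_nnc_init();
	// (batch, sequence, heads, dim) query, key, value and output, 32-bit float on the CPU.
	ccv_nnc_tensor_t* const q = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 32, 4, 64), 0);
	ccv_nnc_tensor_t* const k = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 32, 4, 64), 0);
	ccv_nnc_tensor_t* const v = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 32, 4, 64), 0);
	ccv_nnc_tensor_t* const y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 32, 4, 64), 0);
	// Fill q, k, v with real data here; scale = 1 / sqrt(64), causal masking on.
	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(0.125, 1), ccv_nnc_no_hint, 0,
		TENSOR_LIST(q, k, v), TENSOR_LIST(y), 0);
	ccv_nnc_tensor_free(q);
	ccv_nnc_tensor_free(k);
	ccv_nnc_tensor_free(v);
	ccv_nnc_tensor_free(y);
	return 0;
}

Because _ccv_nnc_allow_query_inplace is registered above, graph compilation is also allowed to write y in place over the query buffer when their shapes match.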