Coverage Report

Created: 2024-12-16 17:02

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention.c
Line |  Count | Source   (a bracketed [n] marks the execution count of the preceding sub-line region)
   1 |        | #include "ccv.h"
   2 |        | #include "nnc/ccv_nnc.h"
   3 |        | #include "nnc/ccv_nnc_internal.h"
   4 |        | 
   5 |        | static int _ccv_nnc_scaled_dot_product_attention_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
   6 |     14 | {
   7 |        |   // 6 inputs (query, key, value, [attn_mask], [unify head weight], [unify head bias])
   8 |        |   // 3 outputs (y, softmax_lse, [qkv])
   9 |     14 |   if (input_size == 6 && (input_bitmasks[0] & 55u) == 55u && (output_bitmasks[0] & 7u) == 7u [6])
  10 |      4 |     return 1;
  11 |     10 |   if (input_size == 5 && (input_bitmasks[0] & 23u) == 23u [0] && (output_bitmasks[0] & 7u) == 7u [0])
  12 |      0 |     return 1;
  13 |     10 |   if ((input_bitmasks[0] & 55u) == 7u && (output_bitmasks[0] & 3u) == 3u [8])
  14 |      6 |     return 1;
  15 |      4 |   return 0;
  16 |     10 | }
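
The constants above are presence bitmasks: bit i of input_bitmasks[0] (or output_bitmasks[0]) is set when tensor i is bound. Against the slot comment on lines 7-8, 55u = query|key|value|weight|bias with the optional attn_mask (bit 3) left unchecked, and 7u requires all three outputs; 23u = query|key|value|weight covers the bias-less projection (note its `return 1` on line 12 has a count of 0, so the test suite never exercises that path); the final pattern accepts plain q/k/v attention with y and softmax_lse (3u). A minimal sketch verifying the decompositions (the variable names are illustrative, not library identifiers):

    #include <assert.h>
    #include <stdint.h>

    int main(void)
    {
        /* Forward input slots per the comment above:
           0:query 1:key 2:value 3:attn_mask 4:weight 5:bias */
        const uint64_t q = 1u << 0, k = 1u << 1, v = 1u << 2, w = 1u << 4, b = 1u << 5;
        assert((q | k | v | w | b) == 55u); /* projection with bias */
        assert((q | k | v | w) == 23u);     /* projection without bias */
        assert((q | k | v) == 7u);          /* plain attention */
        return 0;
    }
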
  17 |        | 
  18 |        | 
  19 |        | static int _ccv_nnc_allow_query_inplace(const ccv_nnc_cmd_param_t cmd, const int input_idx, const int input_size, const int output_idx, const int output_size)
  20 |     49 | {
  21 |     49 |   if (input_idx == 0 && output_idx == 0 [12])
  22 |      6 |     return 1;
  23 |     43 |   return 0;
  24 |     49 | }
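
This hook lets the graph compiler reuse buffers: only output 0 (y) may alias input 0 (query). The region count shows `output_idx == 0` was evaluated 12 times and the pair was accepted 6 times. A standalone restatement of the rule (hypothetical harness, not library code):

    #include <stdio.h>

    /* Only the (input 0 = query, output 0 = y) pair may alias. */
    static int allow_inplace(const int input_idx, const int output_idx)
    {
        return input_idx == 0 && output_idx == 0;
    }

    int main(void)
    {
        printf("query->y: %d, key->y: %d, query->lse: %d\n",
            allow_inplace(0, 0), allow_inplace(1, 0), allow_inplace(0, 1));
        return 0;
    }
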
  25 |        | 
  26 |        | static int _ccv_nnc_scaled_dot_product_attention_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
  27 |      1 | {
  28 |        |   // 1, 0, 0, 8, 16, 32, 64?, 128?, 256?, 512, 1024, 2048?
  29 |        |   // 1, 2, 4, 8, 16, 32
  30 |        |   // Inputs (gradient, 0, 0, q, k, v, [attn_mask], [head weight], [bias], y, saved softmax_lse, qkv)
  31 |        |   // Output (dquery, dkey, dvalue, [attn mask], dweight, dbias) [cannot diff against attn_mask]
  32 |      1 |   if ((input_bitmasks[0] & 4025u) == 4025u && (output_bitmasks[0] & 63u) == 55u [0])
  33 |      0 |     return 1;
  34 |      1 |   if ((input_bitmasks[0] & 3769u) == 3769u && (output_bitmasks[0] & 31u) == 23u [0])
  35 |      0 |     return 1;
  36 |      1 |   if ((input_bitmasks[0] & 1593u) == 1593u && (output_bitmasks[0] & 7u) == 7u [0])
  37 |      0 |     return 1;
  38 |      1 |   return 0;
  39 |      1 | }
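
The backward masks decode against the input slot comment (0:gradient, 3:q, 4:k, 5:v, 6:attn_mask, 7:weight, 8:bias, 9:y, 10:softmax_lse, 11:qkv). 4025u requires everything except attn_mask and produces dquery|dkey|dvalue|dweight|dbias (55u within the 63u window; bit 3 stays clear because attn_mask is not differentiable); 3769u drops the bias bit and 23u accordingly drops dbias; 1593u = gradient|q|k|v|y|softmax_lse covers plain attention with dquery|dkey|dvalue (7u). Coverage-wise, the function ran once and all three accepting `return 1` lines report 0, so only the fall-through `return 0` is exercised. A sketch verifying the arithmetic (illustrative names only):

    #include <assert.h>
    #include <stdint.h>

    int main(void)
    {
        /* Backward input slots per the comment above. */
        const uint64_t g = 1u << 0, q = 1u << 3, k = 1u << 4, v = 1u << 5;
        const uint64_t w = 1u << 7, b = 1u << 8, y = 1u << 9, lse = 1u << 10, qkv = 1u << 11;
        assert((g | q | k | v | w | b | y | lse | qkv) == 4025u); /* projection with bias */
        assert((g | q | k | v | w | y | lse | qkv) == 3769u);     /* projection without bias */
        assert((g | q | k | v | y | lse) == 1593u);               /* plain attention */
        return 0;
    }
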
  40 |        | 
  41 |        | static void _ccv_nnc_scaled_dot_product_attention_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size)
  42 |     19 | {
  43 |     19 |   assert(input_size >= 3);
  44 |     19 |   assert(output_size >= 1);
  45 |     19 |   const int q_nd = ccv_nnc_tensor_nd(inputs[0].dim);
  46 |     19 |   assert(q_nd == 3 || q_nd == 4);
  47 |     19 |   const int k_nd = ccv_nnc_tensor_nd(inputs[1].dim);
  48 |     19 |   assert(k_nd == 3 || k_nd == 4);
  49 |     19 |   const int v_nd = ccv_nnc_tensor_nd(inputs[2].dim);
  50 |     19 |   assert(v_nd == 3 || v_nd == 4);
  51 |     19 |   assert(q_nd == k_nd && k_nd == v_nd);
  52 |     19 |   if (input_size > 4)
  53 |     12 |   {
  54 |     12 |     assert(output_size >= 3);
  55 |     12 |     outputs[0] = inputs[0];
  56 |     12 |     outputs[0].dim[1] = inputs[0].dim[1]; // sequence length matches query, embedding size matches value * num_head.
  57 |     12 |     outputs[0].dim[2] = inputs[2].dim[v_nd - 1] * (q_nd == 4 ? inputs[0].dim[2] : 1 [0]);
  58 |     12 |     outputs[0].dim[3] = 0;
  59 |        |     // This is saved softmax_lse, which would be in 32F if exists.
  60 |     12 |     outputs[1] = inputs[0];
  61 |     12 |     outputs[1].dim[q_nd - 3] = inputs[0].dim[q_nd - 2];
  62 |     12 |     outputs[1].dim[q_nd - 2] = inputs[0].dim[q_nd - 3];
  63 |     12 |     outputs[1].dim[q_nd - 1] = 0;
  64 |     12 |     outputs[1].datatype = CCV_32F;
  65 |     12 |     outputs[2] = inputs[0];
  66 |     12 |     outputs[2].dim[q_nd - 1] = inputs[2].dim[v_nd - 1]; // sequence length matches query, embedding size matches value.
  67 |     12 |   } else {
  68 |      7 |     outputs[0] = inputs[0];
  69 |      7 |     outputs[0].dim[q_nd - 1] = inputs[2].dim[v_nd - 1]; // sequence length matches query, embedding size matches value.
  70 |      7 |     if (output_size == 1)
  71 |      3 |       return;
  72 |      7 |     assert(output_size > 1) [4];
  73 |        |     // This is saved softmax_lse, which would be in 32F if exists.
  74 |      4 |     outputs[1] = inputs[0];
  75 |      4 |     outputs[1].dim[q_nd - 3] = inputs[0].dim[q_nd - 2];
  76 |      4 |     outputs[1].dim[q_nd - 2] = inputs[0].dim[q_nd - 3];
  77 |      4 |     outputs[1].dim[q_nd - 1] = 0;
  78 |      4 |     outputs[1].datatype = CCV_32F;
  79 |      4 |   }
  80 |     19 | }
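
In the 4-d case the inference treats query/key/value as (batch, sequence, heads, head_dim). Without the unify-head weight, y keeps the query shape with its last dimension replaced by value's head_dim, and softmax_lse swaps the sequence and head axes to give (batch, heads, sequence) in CCV_32F. With the weight present (input_size > 4), y is flattened to (batch, sequence, heads * value_dim) and the raw per-head attention output becomes the third output. A minimal sketch of the weight-less 4-d path, with illustrative sizes not taken from the report:

    #include <stdio.h>
    #include <string.h>

    /* Mirrors the else-branch above: q is (batch, seq, heads, dk),
       v is (batch, seq, heads, dv). */
    static void infer_forw(const int q[4], const int v[4], int y[4], int lse[3])
    {
        memcpy(y, q, sizeof(int) * 4);
        y[3] = v[3];   /* embedding size follows value */
        lse[0] = q[0]; /* batch */
        lse[1] = q[2]; /* heads and sequence swap places, */
        lse[2] = q[1]; /* so softmax_lse is (batch, heads, seq) */
    }

    int main(void)
    {
        const int q[4] = {2, 128, 8, 64}, v[4] = {2, 128, 8, 64};
        int y[4], lse[3];
        infer_forw(q, v, y, lse);
        printf("y=(%d,%d,%d,%d) lse=(%d,%d,%d)\n",
            y[0], y[1], y[2], y[3], lse[0], lse[1], lse[2]);
        return 0;
    }
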
  81 |        | 
  82 |        | static void _ccv_nnc_scaled_dot_product_attention_tensor_auto_back(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size)
  83 |      1 | {
  84 |      1 |   assert(input_size >= 6);
  85 |      1 |   assert(output_size >= 3);
  86 |      1 |   int i;
  87 |      4 |   for (i = 0; i < output_size; i++ [3])
  88 |      3 |     outputs[i] = inputs[3 + i];
  89 |      1 | }
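
Backward shape inference is a straight copy: each gradient inherits the tensor parameters of its forward counterpart, outputs[0] (dquery) from inputs[3] (q), outputs[1] (dkey) from inputs[4] (k), and outputs[2] (dvalue) from inputs[5] (v); the i++ region count of 3 matches the three gradient outputs requested in the one covered run.
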
  90 |        | 
  91 |        | REGISTER_COMMAND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
  92 |        |   FIND_BACKEND(ccv_nnc_scaled_dot_product_attention_cpu_ref.c, mps/ccv_nnc_scaled_dot_product_attention_mps.m, gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu)
  93 |      1 | {
  94 |      1 |   registry->bitmask = _ccv_nnc_scaled_dot_product_attention_forw_bitmask;
  95 |      1 |   registry->tensor_auto = _ccv_nnc_scaled_dot_product_attention_tensor_auto_forw;
  96 |      1 |   registry->allow_inplace = _ccv_nnc_allow_query_inplace;
  97 |      1 | }
  98 |        | 
  99 |        | REGISTER_COMMAND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
 100 |        |   FIND_BACKEND(ccv_nnc_scaled_dot_product_attention_cpu_ref.c, mps/ccv_nnc_scaled_dot_product_attention_mps.m, gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu)
 101 |      1 | {
 102 |      1 |   registry->bitmask = _ccv_nnc_scaled_dot_product_attention_back_bitmask;
 103 |      1 |   registry->tensor_auto = _ccv_nnc_scaled_dot_product_attention_tensor_auto_back;
 104 |      1 | }
 105 |        | 
 106 |        | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD)
 107 |        | #define CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(_scale, _is_causal) ccv_nnc_cmd(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.scaled_dot_product_attention={.scale=_scale,.is_causal=_is_causal}}), 0)
 108 |        | //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD)
 109 |        | #define CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(_scale, _is_causal) ccv_nnc_cmd(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.scaled_dot_product_attention={.scale=_scale,.is_causal=_is_causal}}), 0)
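
The two easy-command macros wrap ccv_nnc_cmd() with the scale and causal flag baked into the command parameters. A hypothetical construction sketch (the 1/sqrt(head_dim) scaling convention and the head size are assumptions, not part of this file; the macro is assumed visible through the generated easy-command headers):

    #include <math.h>
    #include "nnc/ccv_nnc.h"
    #include "nnc/ccv_nnc_easy.h"

    int main(void)
    {
        const int head_dim = 64; /* illustrative, not from the report */
        const float scale = 1.0f / sqrtf((float)head_dim);
        /* Causal, scaled attention; a command like this would be passed to
           ccv_nnc_cmd_exec with (q, k, v, ...) inputs and (y, softmax_lse, ...)
           outputs per the forward bitmask rules above. */
        const ccv_nnc_cmd_t fwd = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 1);
        (void)fwd;
        return 0;
    }
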