Coverage Report

Created: 2025-02-24 17:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/liu/actions-runner/_work/ccv/ccv/test/unit/nnc/attention.tests.c
Line
Count
Source
1
#include "case.h"
2
#include "ccv_case.h"
3
#include "ccv_nnc_case.h"
4
#include <ccv.h>
5
#include <nnc/ccv_nnc.h>
6
#include <nnc/ccv_nnc_easy.h>
7
#include "3rdparty/dsfmt/dSFMT.h"
8
9
TEST_SETUP()
10
{
11
  ccv_nnc_init();
12
}
13
14
TEST_CASE("implement scaled dot product attention with fine-grained symbolic graph")
15
1
{
16
1
  ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new();
17
1
  ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
18
1
  ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q");
19
1
  ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
20
1
  ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k");
21
1
  ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
22
1
  ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v");
23
1
  ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk");
24
1
  ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq");
25
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q");
26
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k");
27
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v");
28
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q");
29
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k");
30
1
  ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks");
31
1
  ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s");
32
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax");
33
1
  ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa");
34
1
  ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "f");
35
1
  ccv_nnc_tensor_symbol_t tr = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "f");
36
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(r), "final");
37
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(r), TENSOR_SYMBOL_LIST(tr), "final");
38
1
  ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
39
1
  SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH);
40
1
  ccv_nnc_graph_t* graph = 0;
41
1
  ccv_nnc_tensor_arena_t* tensor_arena = 0;
42
1
  ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0;
43
1
  ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena);
44
1
  GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH);
45
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q);
46
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k);
47
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v);
48
1
  dsfmt_t dsfmt;
49
1
  int i;
50
1
  dsfmt_init_gen_rand(&dsfmt, 1);
51
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
52
2.09M
    q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
53
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
54
2.09M
    k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
55
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
56
3.14M
    v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
57
1
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
58
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
59
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
60
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
61
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r");
62
1
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention");
63
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
64
1
  ccv_nnc_graph_t* sdp_graph = 0;
65
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
66
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
67
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
68
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
69
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
70
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
71
1
  memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
72
1
  memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
73
1
  memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
74
1
  ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0);
75
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
76
1
  ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, tr);
77
1
  ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br);
78
1
  REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result");
79
1
  ccv_nnc_symbolic_graph_free(symbolic_graph);
80
1
  ccv_nnc_tensor_arena_free(tensor_arena);
81
1
  ccv_nnc_graph_exec_arena_free(graph_exec_arena);
82
1
  ccv_nnc_graph_free(graph);
83
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
84
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
85
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
86
1
  ccv_nnc_graph_free(sdp_graph);
87
1
}
88
89
TEST_CASE("implement scaled dot product attention + unify head output with fine-grained symbolic graph")
90
1
{
91
1
  ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new();
92
1
  ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
93
1
  ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
94
1
  ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
95
1
  ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q");
96
1
  ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k");
97
1
  ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v");
98
1
  ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk");
99
1
  ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq");
100
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q");
101
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k");
102
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v");
103
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q");
104
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k");
105
1
  ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks");
106
1
  ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s");
107
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax");
108
1
  ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa");
109
1
  ccv_nnc_tensor_symbol_t c = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "c");
110
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(c), "c");
111
1
  ccv_nnc_tensor_symbol_t ct = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "ct");
112
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(c), TENSOR_SYMBOL_LIST(ct), "ct");
113
1
  ccv_nnc_tensor_symbol_t cta = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, ct, DIM_ALLOC(), DIM_ALLOC(128 * 768, 768, 1), CPU_TENSOR_NHWC(32F, 32, 128, 768), "ct");
114
1
  ccv_nnc_tensor_symbol_t w = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w");
115
1
  ccv_nnc_tensor_symbol_t bias = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias");
116
1
  ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8 * 96), "r");
117
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), TENSOR_SYMBOL_LIST(cta, w, bias), TENSOR_SYMBOL_LIST(r), "final");
118
1
  ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
119
1
  SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH);
120
1
  ccv_nnc_graph_t* graph = 0;
121
1
  ccv_nnc_tensor_arena_t* tensor_arena = 0;
122
1
  ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0;
123
1
  ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena);
124
1
  GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH);
125
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q);
126
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k);
127
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v);
128
1
  ccv_nnc_tensor_t* const w_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, w);
129
1
  ccv_nnc_tensor_t* const bias_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, bias);
130
1
  dsfmt_t dsfmt;
131
1
  int i;
132
1
  dsfmt_init_gen_rand(&dsfmt, 1);
133
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
134
2.09M
    q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
135
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
136
2.09M
    k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
137
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
138
3.14M
    v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
139
589k
  for (i = 0; i < 768 * 768; 
i++589k
)
140
589k
    w_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
141
769
  for (i = 0; i < 768; 
i++768
)
142
768
    bias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
143
1
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
144
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
145
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
146
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
147
1
  ccv_nnc_tensor_symbol_t bw = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w");
148
1
  ccv_nnc_tensor_symbol_t bbias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias");
149
1
  ccv_nnc_tensor_symbol_t bc = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "c");
150
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 768), "r");
151
1
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, NO_TENSOR_SYMBOL, bw, bbias), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, bc), "scaled_dot_product_attention");
152
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
153
1
  ccv_nnc_graph_t* sdp_graph = 0;
154
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
155
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
156
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
157
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
158
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
159
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
160
1
  ccv_nnc_tensor_t* const bw_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bw);
161
1
  ccv_nnc_tensor_t* const bbias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bbias);
162
1
  memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
163
1
  memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
164
1
  memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
165
1
  memcpy(bw_tensor->data.f32, w_tensor->data.f32, sizeof(float) * 768 * 768);
166
1
  memcpy(bbias_tensor->data.f32, bias_tensor->data.f32, sizeof(float) * 768);
167
1
  ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0);
168
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
169
1
  ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, r);
170
1
  ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br);
171
1
  REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result");
172
1
  ccv_nnc_symbolic_graph_free(symbolic_graph);
173
1
  ccv_nnc_tensor_arena_free(tensor_arena);
174
1
  ccv_nnc_graph_exec_arena_free(graph_exec_arena);
175
1
  ccv_nnc_graph_free(graph);
176
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
177
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
178
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
179
1
  ccv_nnc_graph_free(sdp_graph);
180
1
}
181
182
TEST_CASE("run scaled dot product attention with cnnp model")
183
1
{
184
1
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
185
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
186
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
187
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
188
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r");
189
1
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention");
190
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
191
1
  ccv_nnc_graph_t* sdp_graph = 0;
192
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
193
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
194
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
195
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
196
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
197
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
198
1
  int i;
199
1
  dsfmt_t dsfmt;
200
1
  dsfmt_init_gen_rand(&dsfmt, 1);
201
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
202
2.09M
    bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
203
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
204
2.09M
    bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
205
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
206
3.14M
    bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
207
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
208
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
209
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0);
210
1
  memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
211
1
  memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
212
1
  memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
213
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
214
1
  ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br);
215
1
  ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0);
216
1
  ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 0, 0, 0, 0, 0, "scaled_dot_product_attention");
217
1
  ccv_nnc_tensor_param_t qkv[3];
218
1
  qkv[0] = q_tensor->info;
219
1
  qkv[1] = k_tensor->info;
220
1
  qkv[2] = v_tensor->info;
221
1
  ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 3, CMD_NOOP(), CMD_NOOP());
222
1
  ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(r_tensor), 0, 0);
223
1
  CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH);
224
1
  REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result");
225
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
226
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
227
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
228
1
  ccv_nnc_graph_free(sdp_graph);
229
1
  ccv_nnc_tensor_free(q_tensor);
230
1
  ccv_nnc_tensor_free(k_tensor);
231
1
  ccv_nnc_tensor_free(v_tensor);
232
1
  ccv_nnc_tensor_free(r_tensor);
233
1
  ccv_cnnp_model_free(scaled_dot_product_attention);
234
1
}
235
236
TEST_CASE("run scaled dot product attention + unify head output with cnnp model")
237
1
{
238
1
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
239
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
240
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
241
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
242
1
  ccv_nnc_tensor_symbol_t bw = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w");
243
1
  ccv_nnc_tensor_symbol_t bbias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias");
244
1
  ccv_nnc_tensor_symbol_t bc = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "c");
245
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 768), "r");
246
1
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, NO_TENSOR_SYMBOL, bw, bbias), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, bc), "scaled_dot_product_attention");
247
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
248
1
  ccv_nnc_graph_t* sdp_graph = 0;
249
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
250
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
251
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
252
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
253
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
254
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
255
1
  ccv_nnc_tensor_t* const bw_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bw);
256
1
  ccv_nnc_tensor_t* const bbias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bbias);
257
1
  int i;
258
1
  dsfmt_t dsfmt;
259
1
  dsfmt_init_gen_rand(&dsfmt, 1);
260
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
261
2.09M
    bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
262
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
263
2.09M
    bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
264
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
265
3.14M
    bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
266
589k
  for (i = 0; i < 768 * 768; 
i++589k
)
267
589k
    bw_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
268
769
  for (i = 0; i < 768; 
i++768
)
269
768
    bbias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
270
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
271
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
272
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0);
273
1
  memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
274
1
  memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
275
1
  memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
276
1
  ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br);
277
1
  ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0);
278
1
  ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 0, 0, 1, 0, 1, "scaled_dot_product_attention");
279
1
  ccv_nnc_tensor_param_t qkv[3];
280
1
  qkv[0] = q_tensor->info;
281
1
  qkv[1] = k_tensor->info;
282
1
  qkv[2] = v_tensor->info;
283
1
  ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 3, CMD_NOOP(), CMD_NOOP());
284
1
  ccv_cnnp_model_set_parameter(scaled_dot_product_attention, ccv_cnnp_model_parameters(scaled_dot_product_attention, CCV_CNNP_PARAMETER_SELECT_WEIGHT, 0), bw_tensor);
285
1
  ccv_cnnp_model_set_parameter(scaled_dot_product_attention, ccv_cnnp_model_parameters(scaled_dot_product_attention, CCV_CNNP_PARAMETER_SELECT_BIAS, 0), bbias_tensor);
286
1
  ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(r_tensor), 0, 0);
287
1
  CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH);
288
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
289
1
  REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result");
290
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
291
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
292
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
293
1
  ccv_nnc_graph_free(sdp_graph);
294
1
  ccv_nnc_tensor_free(q_tensor);
295
1
  ccv_nnc_tensor_free(k_tensor);
296
1
  ccv_nnc_tensor_free(v_tensor);
297
1
  ccv_nnc_tensor_free(r_tensor);
298
1
  ccv_cnnp_model_free(scaled_dot_product_attention);
299
1
}
300
301
TEST_CASE("run scaled dot product attention + attention mask with cnnp model")
302
1
{
303
1
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
304
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
305
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
306
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
307
1
  ccv_nnc_tensor_symbol_t battn_mask = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 1, 1, 128, 128), "attn_mask");
308
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r");
309
1
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, battn_mask), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention");
310
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
311
1
  ccv_nnc_graph_t* sdp_graph = 0;
312
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
313
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
314
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
315
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
316
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
317
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
318
1
  ccv_nnc_tensor_t* const battn_mask_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, battn_mask);
319
1
  int i, j;
320
1
  dsfmt_t dsfmt;
321
1
  dsfmt_init_gen_rand(&dsfmt, 1);
322
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
323
2.09M
    bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
324
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
325
2.09M
    bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
326
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
327
3.14M
    bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
328
16.3k
  for (i = 0; i < 128 * 128; 
i++16.3k
)
329
16.3k
    battn_mask_tensor->data.f32[i] = 0;
330
128
  for (i = 0; i < 127; 
i++127
)
331
8.25k
    
for (j = i + 1; 127
j < 128;
j++8.12k
)
332
8.12k
      battn_mask_tensor->data.f32[i * 128 + j] = -FLT_MAX;
333
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
334
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
335
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0);
336
1
  ccv_nnc_tensor_t* const attn_mask_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 1, 128, 128), 0);
337
1
  memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
338
1
  memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
339
1
  memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
340
1
  memcpy(attn_mask_tensor->data.f32, battn_mask_tensor->data.f32, sizeof(float) * 128 * 128);
341
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
342
1
  ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br);
343
1
  ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0);
344
1
  ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 1, 0, 0, 0, 0, "scaled_dot_product_attention");
345
1
  ccv_nnc_tensor_param_t qkv[4];
346
1
  qkv[0] = q_tensor->info;
347
1
  qkv[1] = k_tensor->info;
348
1
  qkv[2] = v_tensor->info;
349
1
  qkv[3] = attn_mask_tensor->info;
350
1
  ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 4, CMD_NOOP(), CMD_NOOP());
351
1
  ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor, attn_mask_tensor), TENSOR_LIST(r_tensor), 0, 0);
352
1
  CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH);
353
1
  REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result");
354
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
355
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
356
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
357
1
  ccv_nnc_graph_free(sdp_graph);
358
1
  ccv_nnc_tensor_free(q_tensor);
359
1
  ccv_nnc_tensor_free(k_tensor);
360
1
  ccv_nnc_tensor_free(v_tensor);
361
1
  ccv_nnc_tensor_free(attn_mask_tensor);
362
1
  ccv_nnc_tensor_free(r_tensor);
363
1
  ccv_cnnp_model_free(scaled_dot_product_attention);
364
1
}
365
366
// Compares the gradients produced by the fused scaled dot product attention op
// against those of a graph assembled from transpose, scalar-mul, GEMM and
// softmax ops. Both graphs receive the same q / k / v values and an all-ones
// output gradient; the resulting dq, dk and dv are required to agree within a
// 1e-5 tolerance.
TEST_CASE("implement gradient of scaled dot product attention with fine-grained symbolic graph")
367
1
{
368
1
  // Reference graph, built op by op.
  ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new();
369
1
  // q and k are declared as 32 x 128 x 8 x 64 and v as 32 x 128 x 8 x 96;
  // tq / tk / tv are the 32 x 8 x 128 x {64, 64, 96} outputs of the
  // CMD_TRANSPOSE_FORWARD(1, 2) nodes below.
  ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
370
1
  ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q");
371
1
  ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
372
1
  ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k");
373
1
  ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
374
1
  ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v");
375
1
  ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk");
376
1
  ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq");
377
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q");
378
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k");
379
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v");
380
1
  // sq = tq * (1.0 / 8), the same scale that is passed to the fused op further
  // down. NOTE(review): 1.0 / 8 equals 1 / sqrt(64) with 64 being q's last
  // dimension; presumably that is where the constant comes from.
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q");
381
1
  // qk = sq @ tk, with axes 2 and 3 of tk transposed for the product.
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k");
382
1
  // qks aliases qk as a 2-D (32 * 8 * 128) x 128 tensor, which is what the
  // softmax node consumes; its output s has the same 2-D shape.
  ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks");
383
1
  ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s");
384
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax");
385
1
  // sa aliases the softmax output back as 32 x 8 x 128 x 128.
  ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa");
386
1
  // r = sa @ tv (32 x 8 x 128 x 96); tr is r passed through
  // CMD_TRANSPOSE_FORWARD(1, 2), giving 32 x 128 x 8 x 96.
  ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "f");
387
1
  ccv_nnc_tensor_symbol_t tr = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "f");
388
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(r), "final");
389
1
  ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(r), TENSOR_SYMBOL_LIST(tr), "final");
390
1
  ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
391
1
  // Add the backward pass for tr with respect to q, k and v. Autogen is run
  // again afterwards with the same flags as before the backward call.
  ccv_nnc_symbolic_graph_backward(symbolic_graph, TENSOR_SYMBOL_LIST(tr), TENSOR_SYMBOL_LIST(q, k, v), SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph));
392
1
  ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
393
1
  SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH);
394
1
  // Compile the reference graph; the three out-pointers are filled by the call
  // and released at the end of the test.
  ccv_nnc_graph_t* graph = 0;
395
1
  ccv_nnc_tensor_arena_t* tensor_arena = 0;
396
1
  ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0;
397
1
  ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena);
398
1
  GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH);
399
1
  ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q);
400
1
  ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k);
401
1
  ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v);
402
1
  // Fill q, k and v with values drawn from a dSFMT generator seeded with 1.
  dsfmt_t dsfmt;
403
1
  int i;
404
1
  dsfmt_init_gen_rand(&dsfmt, 1);
405
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
406
2.09M
    q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
407
2.09M
  for (i = 0; i < 32 * 8 * 128 * 64; 
i++2.09M
)
408
2.09M
    k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
409
3.14M
  for (i = 0; i < 32 * 8 * 128 * 96; 
i++3.14M
)
410
3.14M
    v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
411
1
  // Set every element of the gradient tensor associated with tr to 1.
  ccv_nnc_tensor_symbol_t dr = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, tr);
412
1
  ccv_nnc_tensor_t* const dr_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dr);
413
3.14M
  for (i = 0; i < 32 * 128 * 8 * 96; 
i++3.14M
)
414
3.14M
    dr_tensor->data.f32[i] = 1;
415
1
  // Second graph: a single scaled dot product attention node with the same
  // 1.0 / 8 scale. Its inputs and output use the 32 x 128 x 8 x {64, 64, 96}
  // shapes directly, with no explicit transpose nodes.
  ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
416
1
  ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
417
1
  ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
418
1
  ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v");
419
1
  ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r");
420
1
  // Only the first output slot (br) is bound; the other two are passed as
  // NO_TENSOR_SYMBOL.
  ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, NO_TENSOR_SYMBOL), "scaled_dot_product_attention");
421
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
422
1
  // Same backward + autogen sequence as for the reference graph, here for br
  // with respect to bq, bk and bv.
  ccv_nnc_symbolic_graph_backward(sdp_symbolic_graph, TENSOR_SYMBOL_LIST(br), TENSOR_SYMBOL_LIST(bq, bk, bv), SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph));
423
1
  ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
424
1
  SYMBOLIC_GRAPH_GEN(sdp_symbolic_graph, CCV_NNC_LONG_DOT_GRAPH);
425
1
  ccv_nnc_graph_t* sdp_graph = 0;
426
1
  ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
427
1
  ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
428
1
  ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
429
1
  GRAPH_GEN(sdp_graph, CCV_NNC_LONG_DOT_GRAPH);
430
1
  ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq);
431
1
  ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk);
432
1
  ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv);
433
1
  // Copy the reference graph's q / k / v values into the fused-op graph so both
  // graphs see identical inputs.
  memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
434
1
  memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64);
435
1
  memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96);
436
1
  // Set the gradient tensor associated with br to all ones, mirroring dr_tensor.
  ccv_nnc_tensor_symbol_t dbr = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, br);
437
1
  ccv_nnc_tensor_t* const dbr_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbr);
438
3.14M
  for (i = 0; i < 32 * 128 * 8 * 96; 
i++3.14M
)
439
3.14M
    dbr_tensor->data.f32[i] = 1;
440
  // CCV_CLI_SET_OUTPUT_LEVEL_AND_ABOVE(CCV_CLI_VERBOSE);
441
1
  // Run both compiled graphs with TRAVERSE_FULL.
  ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0);
442
1
  ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
443
1
  // Look up the gradient tensors for q, k and v in the reference graph ...
  ccv_nnc_tensor_symbol_t dq = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, q);
444
1
  ccv_nnc_tensor_symbol_t dk = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, k);
445
1
  ccv_nnc_tensor_symbol_t dv = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, v);
446
1
  ccv_nnc_tensor_t* const dq_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dq);
447
1
  ccv_nnc_tensor_t* const dk_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dk);
448
1
  ccv_nnc_tensor_t* const dv_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dv);
449
1
  // ... and the corresponding ones in the fused-op graph.
  ccv_nnc_tensor_symbol_t dbq = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bq);
450
1
  ccv_nnc_tensor_symbol_t dbk = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bk);
451
1
  ccv_nnc_tensor_symbol_t dbv = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bv);
452
1
  ccv_nnc_tensor_t* const dbq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbq);
453
1
  ccv_nnc_tensor_t* const dbk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbk);
454
1
  ccv_nnc_tensor_t* const dbv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbv);
455
1
  // The gradients from the two graphs must agree over the full element count of
  // each tensor, using a 1e-5 tolerance.
  REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dq_tensor->data.f32, dbq_tensor->data.f32, 32 * 128 * 8 * 64, 1e-5, "graph computed gradient should match scaled dot product attention op gradient");
456
1
  REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dk_tensor->data.f32, dbk_tensor->data.f32, 32 * 128 * 8 * 64, 1e-5, "graph computed gradient should match scaled dot product attention op gradient");
457
1
  REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dv_tensor->data.f32, dbv_tensor->data.f32, 32 * 128 * 8 * 96, 1e-5, "graph computed gradient should match scaled dot product attention op gradient");
458
1
  // Release the symbolic graph, tensor arena, exec arena and concrete graph of
  // each of the two graphs. Tensors obtained via ccv_nnc_tensor_from_symbol are
  // not freed individually in this test.
  ccv_nnc_symbolic_graph_free(symbolic_graph);
459
1
  ccv_nnc_tensor_arena_free(tensor_arena);
460
1
  ccv_nnc_graph_exec_arena_free(graph_exec_arena);
461
1
  ccv_nnc_graph_free(graph);
462
1
  ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
463
1
  ccv_nnc_tensor_arena_free(sdp_tensor_arena);
464
1
  ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
465
1
  ccv_nnc_graph_free(sdp_graph);
466
1
}
467
468
#include "case_main.h"