/home/liu/actions-runner/_work/ccv/ccv/test/unit/nnc/attention.tests.c
Line | Count | Source |
1 | | #include "case.h" |
2 | | #include "ccv_case.h" |
3 | | #include "ccv_nnc_case.h" |
4 | | #include <ccv.h> |
5 | | #include <nnc/ccv_nnc.h> |
6 | | #include <nnc/ccv_nnc_easy.h> |
7 | | #include "3rdparty/dsfmt/dSFMT.h" |
8 | | |
9 | | TEST_SETUP() |
10 | | { |
11 | | ccv_nnc_init(); |
12 | | } |
13 | | |
14 | | TEST_CASE("implement scaled dot product attention with fine-grained symbolic graph") |
15 | 1 | { |
16 | 1 | ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new(); |
17 | 1 | ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
18 | 1 | ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q"); |
19 | 1 | ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
20 | 1 | ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k"); |
21 | 1 | ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
22 | 1 | ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v"); |
23 | 1 | ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk"); |
24 | 1 | ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq"); |
25 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q"); |
26 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k"); |
27 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v"); |
28 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q"); |
29 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k"); |
30 | 1 | ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks"); |
31 | 1 | ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s"); |
32 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax"); |
33 | 1 | ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa"); |
34 | 1 | ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "f"); |
35 | 1 | ccv_nnc_tensor_symbol_t tr = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "f"); |
36 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(r), "final"); |
37 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(r), TENSOR_SYMBOL_LIST(tr), "final"); |
38 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
39 | 1 | SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
40 | 1 | ccv_nnc_graph_t* graph = 0; |
41 | 1 | ccv_nnc_tensor_arena_t* tensor_arena = 0; |
42 | 1 | ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0; |
43 | 1 | ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena); |
44 | 1 | GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH); |
45 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q); |
46 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k); |
47 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v); |
48 | 1 | dsfmt_t dsfmt; |
49 | 1 | int i; |
50 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
51 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
52 | 2.09M | q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
53 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
54 | 2.09M | k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
55 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
56 | 3.14M | v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
57 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
58 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
59 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
60 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
61 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r"); |
62 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention"); |
63 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
64 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
65 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
66 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
67 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
68 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
69 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
70 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
71 | 1 | memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
72 | 1 | memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
73 | 1 | memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
74 | 1 | ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0); |
75 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
76 | 1 | ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, tr); |
77 | 1 | ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br); |
78 | 1 | REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result"); |
79 | 1 | ccv_nnc_symbolic_graph_free(symbolic_graph); |
80 | 1 | ccv_nnc_tensor_arena_free(tensor_arena); |
81 | 1 | ccv_nnc_graph_exec_arena_free(graph_exec_arena); |
82 | 1 | ccv_nnc_graph_free(graph); |
83 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
84 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
85 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
86 | 1 | ccv_nnc_graph_free(sdp_graph); |
87 | 1 | } |
88 | | |
89 | | TEST_CASE("implement scaled dot product attention + unify head output with fine-grained symbolic graph") |
90 | 1 | { |
91 | 1 | ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new(); |
92 | 1 | ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
93 | 1 | ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
94 | 1 | ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
95 | 1 | ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q"); |
96 | 1 | ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k"); |
97 | 1 | ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v"); |
98 | 1 | ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk"); |
99 | 1 | ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq"); |
100 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q"); |
101 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k"); |
102 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v"); |
103 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q"); |
104 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k"); |
105 | 1 | ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks"); |
106 | 1 | ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s"); |
107 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax"); |
108 | 1 | ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa"); |
109 | 1 | ccv_nnc_tensor_symbol_t c = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "c"); |
110 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(c), "c"); |
111 | 1 | ccv_nnc_tensor_symbol_t ct = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "ct"); |
112 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(c), TENSOR_SYMBOL_LIST(ct), "ct"); |
113 | 1 | ccv_nnc_tensor_symbol_t cta = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, ct, DIM_ALLOC(), DIM_ALLOC(128 * 768, 768, 1), CPU_TENSOR_NHWC(32F, 32, 128, 768), "ct"); |
114 | 1 | ccv_nnc_tensor_symbol_t w = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w"); |
115 | 1 | ccv_nnc_tensor_symbol_t bias = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias"); |
116 | 1 | ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8 * 96), "r"); |
117 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(0, 1)), TENSOR_SYMBOL_LIST(cta, w, bias), TENSOR_SYMBOL_LIST(r), "final"); |
118 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
119 | 1 | SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
120 | 1 | ccv_nnc_graph_t* graph = 0; |
121 | 1 | ccv_nnc_tensor_arena_t* tensor_arena = 0; |
122 | 1 | ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0; |
123 | 1 | ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena); |
124 | 1 | GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH); |
125 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q); |
126 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k); |
127 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v); |
128 | 1 | ccv_nnc_tensor_t* const w_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, w); |
129 | 1 | ccv_nnc_tensor_t* const bias_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, bias); |
130 | 1 | dsfmt_t dsfmt; |
131 | 1 | int i; |
132 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
133 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
134 | 2.09M | q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
135 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
136 | 2.09M | k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
137 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
138 | 3.14M | v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
139 | 589k | for (i = 0; i < 768 * 768; i++589k ) |
140 | 589k | w_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
141 | 769 | for (i = 0; i < 768; i++768 ) |
142 | 768 | bias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
143 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
144 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
145 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
146 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
147 | 1 | ccv_nnc_tensor_symbol_t bw = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w"); |
148 | 1 | ccv_nnc_tensor_symbol_t bbias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias"); |
149 | 1 | ccv_nnc_tensor_symbol_t bc = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "c"); |
150 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 768), "r"); |
151 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, NO_TENSOR_SYMBOL, bw, bbias), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, bc), "scaled_dot_product_attention"); |
152 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
153 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
154 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
155 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
156 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
157 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
158 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
159 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
160 | 1 | ccv_nnc_tensor_t* const bw_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bw); |
161 | 1 | ccv_nnc_tensor_t* const bbias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bbias); |
162 | 1 | memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
163 | 1 | memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
164 | 1 | memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
165 | 1 | memcpy(bw_tensor->data.f32, w_tensor->data.f32, sizeof(float) * 768 * 768); |
166 | 1 | memcpy(bbias_tensor->data.f32, bias_tensor->data.f32, sizeof(float) * 768); |
167 | 1 | ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0); |
168 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
169 | 1 | ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, r); |
170 | 1 | ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br); |
171 | 1 | REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result"); |
172 | 1 | ccv_nnc_symbolic_graph_free(symbolic_graph); |
173 | 1 | ccv_nnc_tensor_arena_free(tensor_arena); |
174 | 1 | ccv_nnc_graph_exec_arena_free(graph_exec_arena); |
175 | 1 | ccv_nnc_graph_free(graph); |
176 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
177 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
178 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
179 | 1 | ccv_nnc_graph_free(sdp_graph); |
180 | 1 | } |
181 | | |
182 | | TEST_CASE("run scaled dot product attention with cnnp model") |
183 | 1 | { |
184 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
185 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
186 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
187 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
188 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r"); |
189 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention"); |
190 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
191 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
192 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
193 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
194 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
195 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
196 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
197 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
198 | 1 | int i; |
199 | 1 | dsfmt_t dsfmt; |
200 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
201 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
202 | 2.09M | bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
203 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
204 | 2.09M | bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
205 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
206 | 3.14M | bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
207 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
208 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
209 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0); |
210 | 1 | memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
211 | 1 | memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
212 | 1 | memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
213 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
214 | 1 | ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br); |
215 | 1 | ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0); |
216 | 1 | ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 0, 0, 0, 0, 0, "scaled_dot_product_attention"); |
217 | 1 | ccv_nnc_tensor_param_t qkv[3]; |
218 | 1 | qkv[0] = q_tensor->info; |
219 | 1 | qkv[1] = k_tensor->info; |
220 | 1 | qkv[2] = v_tensor->info; |
221 | 1 | ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 3, CMD_NOOP(), CMD_NOOP()); |
222 | 1 | ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(r_tensor), 0, 0); |
223 | 1 | CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH); |
224 | 1 | REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result"); |
225 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
226 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
227 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
228 | 1 | ccv_nnc_graph_free(sdp_graph); |
229 | 1 | ccv_nnc_tensor_free(q_tensor); |
230 | 1 | ccv_nnc_tensor_free(k_tensor); |
231 | 1 | ccv_nnc_tensor_free(v_tensor); |
232 | 1 | ccv_nnc_tensor_free(r_tensor); |
233 | 1 | ccv_cnnp_model_free(scaled_dot_product_attention); |
234 | 1 | } |
235 | | |
236 | | TEST_CASE("run scaled dot product attention + unify head output with cnnp model") |
237 | 1 | { |
238 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
239 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
240 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
241 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
242 | 1 | ccv_nnc_tensor_symbol_t bw = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768, 768), "w"); |
243 | 1 | ccv_nnc_tensor_symbol_t bbias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 768), "bias"); |
244 | 1 | ccv_nnc_tensor_symbol_t bc = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "c"); |
245 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 768), "r"); |
246 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, NO_TENSOR_SYMBOL, bw, bbias), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, bc), "scaled_dot_product_attention"); |
247 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
248 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
249 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
250 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
251 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
252 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
253 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
254 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
255 | 1 | ccv_nnc_tensor_t* const bw_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bw); |
256 | 1 | ccv_nnc_tensor_t* const bbias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bbias); |
257 | 1 | int i; |
258 | 1 | dsfmt_t dsfmt; |
259 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
260 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
261 | 2.09M | bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
262 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
263 | 2.09M | bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
264 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
265 | 3.14M | bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
266 | 589k | for (i = 0; i < 768 * 768; i++589k ) |
267 | 589k | bw_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
268 | 769 | for (i = 0; i < 768; i++768 ) |
269 | 768 | bbias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
270 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
271 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
272 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0); |
273 | 1 | memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
274 | 1 | memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
275 | 1 | memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
276 | 1 | ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br); |
277 | 1 | ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0); |
278 | 1 | ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 0, 0, 1, 0, 1, "scaled_dot_product_attention"); |
279 | 1 | ccv_nnc_tensor_param_t qkv[3]; |
280 | 1 | qkv[0] = q_tensor->info; |
281 | 1 | qkv[1] = k_tensor->info; |
282 | 1 | qkv[2] = v_tensor->info; |
283 | 1 | ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 3, CMD_NOOP(), CMD_NOOP()); |
284 | 1 | ccv_cnnp_model_set_parameter(scaled_dot_product_attention, ccv_cnnp_model_parameters(scaled_dot_product_attention, CCV_CNNP_PARAMETER_SELECT_WEIGHT, 0), bw_tensor); |
285 | 1 | ccv_cnnp_model_set_parameter(scaled_dot_product_attention, ccv_cnnp_model_parameters(scaled_dot_product_attention, CCV_CNNP_PARAMETER_SELECT_BIAS, 0), bbias_tensor); |
286 | 1 | ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(r_tensor), 0, 0); |
287 | 1 | CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH); |
288 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
289 | 1 | REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result"); |
290 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
291 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
292 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
293 | 1 | ccv_nnc_graph_free(sdp_graph); |
294 | 1 | ccv_nnc_tensor_free(q_tensor); |
295 | 1 | ccv_nnc_tensor_free(k_tensor); |
296 | 1 | ccv_nnc_tensor_free(v_tensor); |
297 | 1 | ccv_nnc_tensor_free(r_tensor); |
298 | 1 | ccv_cnnp_model_free(scaled_dot_product_attention); |
299 | 1 | } |
300 | | |
301 | | TEST_CASE("run scaled dot product attention + attention mask with cnnp model") |
302 | 1 | { |
303 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
304 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
305 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
306 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
307 | 1 | ccv_nnc_tensor_symbol_t battn_mask = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 1, 1, 128, 128), "attn_mask"); |
308 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r"); |
309 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv, battn_mask), TENSOR_SYMBOL_LIST(br), "scaled_dot_product_attention"); |
310 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
311 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
312 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
313 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
314 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
315 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
316 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
317 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
318 | 1 | ccv_nnc_tensor_t* const battn_mask_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, battn_mask); |
319 | 1 | int i, j; |
320 | 1 | dsfmt_t dsfmt; |
321 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
322 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
323 | 2.09M | bq_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
324 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
325 | 2.09M | bk_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
326 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
327 | 3.14M | bv_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
328 | 16.3k | for (i = 0; i < 128 * 128; i++16.3k ) |
329 | 16.3k | battn_mask_tensor->data.f32[i] = 0; |
330 | 128 | for (i = 0; i < 127; i++127 ) |
331 | 8.25k | for (j = i + 1; 127 j < 128; j++8.12k ) |
332 | 8.12k | battn_mask_tensor->data.f32[i * 128 + j] = -FLT_MAX; |
333 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
334 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); |
335 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), 0); |
336 | 1 | ccv_nnc_tensor_t* const attn_mask_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 1, 128, 128), 0); |
337 | 1 | memcpy(q_tensor->data.f32, bq_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
338 | 1 | memcpy(k_tensor->data.f32, bk_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
339 | 1 | memcpy(v_tensor->data.f32, bv_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
340 | 1 | memcpy(attn_mask_tensor->data.f32, battn_mask_tensor->data.f32, sizeof(float) * 128 * 128); |
341 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
342 | 1 | ccv_nnc_tensor_t* const br_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, br); |
343 | 1 | ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_new(0, br_tensor->info, 0); |
344 | 1 | ccv_cnnp_model_t* scaled_dot_product_attention = ccv_cnnp_scaled_dot_product_attention(1.0 / 8, 0, 1, 0, 0, 0, 0, "scaled_dot_product_attention"); |
345 | 1 | ccv_nnc_tensor_param_t qkv[4]; |
346 | 1 | qkv[0] = q_tensor->info; |
347 | 1 | qkv[1] = k_tensor->info; |
348 | 1 | qkv[2] = v_tensor->info; |
349 | 1 | qkv[3] = attn_mask_tensor->info; |
350 | 1 | ccv_cnnp_model_compile(scaled_dot_product_attention, qkv, 4, CMD_NOOP(), CMD_NOOP()); |
351 | 1 | ccv_cnnp_model_evaluate(scaled_dot_product_attention, (ccv_cnnp_evaluate_param_t){}, TENSOR_LIST(q_tensor, k_tensor, v_tensor, attn_mask_tensor), TENSOR_LIST(r_tensor), 0, 0); |
352 | 1 | CNNP_MODEL_GEN(scaled_dot_product_attention, CCV_NNC_LONG_DOT_GRAPH); |
353 | 1 | REQUIRE_TENSOR_EQ(r_tensor, br_tensor, "graph computed result should match scaled dot product attention op result"); |
354 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
355 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
356 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
357 | 1 | ccv_nnc_graph_free(sdp_graph); |
358 | 1 | ccv_nnc_tensor_free(q_tensor); |
359 | 1 | ccv_nnc_tensor_free(k_tensor); |
360 | 1 | ccv_nnc_tensor_free(v_tensor); |
361 | 1 | ccv_nnc_tensor_free(attn_mask_tensor); |
362 | 1 | ccv_nnc_tensor_free(r_tensor); |
363 | 1 | ccv_cnnp_model_free(scaled_dot_product_attention); |
364 | 1 | } |
365 | | |
366 | | TEST_CASE("implement gradient of scaled dot product attention with fine-grained symbolic graph") |
367 | 1 | { |
368 | 1 | ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new(); |
369 | 1 | ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
370 | 1 | ccv_nnc_tensor_symbol_t tq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "q"); |
371 | 1 | ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
372 | 1 | ccv_nnc_tensor_symbol_t tk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "k"); |
373 | 1 | ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
374 | 1 | ccv_nnc_tensor_symbol_t tv = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "v"); |
375 | 1 | ccv_nnc_tensor_symbol_t qk = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "qk"); |
376 | 1 | ccv_nnc_tensor_symbol_t sq = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 64), "sq"); |
377 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(q), TENSOR_SYMBOL_LIST(tq), "transpose q"); |
378 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(k), TENSOR_SYMBOL_LIST(tk), "transpose k"); |
379 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(v), TENSOR_SYMBOL_LIST(tv), "transpose v"); |
380 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SCALAR_MUL_FORWARD(1.0 / 8), TENSOR_SYMBOL_LIST(tq), TENSOR_SYMBOL_LIST(sq), "scaled_q"); |
381 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, TRANSPOSE(2, 3)), TENSOR_SYMBOL_LIST(sq, tk), TENSOR_SYMBOL_LIST(qk), "q @ k"); |
382 | 1 | ccv_nnc_tensor_symbol_t qks = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, qk, DIM_ALLOC(), DIM_ALLOC(128, 1), CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "qks"); |
383 | 1 | ccv_nnc_tensor_symbol_t s = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32 * 8 * 128, 128), "s"); |
384 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_SOFTMAX_FORWARD(), TENSOR_SYMBOL_LIST(qks), TENSOR_SYMBOL_LIST(s), "softmax"); |
385 | 1 | ccv_nnc_tensor_symbol_t sa = ccv_nnc_tensor_symbol_alias_new(symbolic_graph, s, DIM_ALLOC(), DIM_ALLOC(8 * 128 * 128, 128 * 128, 128, 1), CPU_TENSOR_NHWC(32F, 32, 8, 128, 128), "sa"); |
386 | 1 | ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 8, 128, 96), "f"); |
387 | 1 | ccv_nnc_tensor_symbol_t tr = ccv_nnc_tensor_symbol_new(symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "f"); |
388 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_GEMM_FORWARD(NO_TRANSPOSE, NO_TRANSPOSE), TENSOR_SYMBOL_LIST(sa, tv), TENSOR_SYMBOL_LIST(r), "final"); |
389 | 1 | ccv_nnc_graph_exec_symbol_new(symbolic_graph, CMD_TRANSPOSE_FORWARD(1, 2), TENSOR_SYMBOL_LIST(r), TENSOR_SYMBOL_LIST(tr), "final"); |
390 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
391 | 1 | ccv_nnc_symbolic_graph_backward(symbolic_graph, TENSOR_SYMBOL_LIST(tr), TENSOR_SYMBOL_LIST(q, k, v), SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph)); |
392 | 1 | ccv_nnc_graph_exec_symbol_autogen(symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
393 | 1 | SYMBOLIC_GRAPH_GEN(symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
394 | 1 | ccv_nnc_graph_t* graph = 0; |
395 | 1 | ccv_nnc_tensor_arena_t* tensor_arena = 0; |
396 | 1 | ccv_nnc_graph_exec_arena_t* graph_exec_arena = 0; |
397 | 1 | ccv_nnc_symbolic_graph_compile(symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(symbolic_graph), &graph, &tensor_arena, &graph_exec_arena); |
398 | 1 | GRAPH_GEN(graph, CCV_NNC_LONG_DOT_GRAPH); |
399 | 1 | ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, q); |
400 | 1 | ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, k); |
401 | 1 | ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, v); |
402 | 1 | dsfmt_t dsfmt; |
403 | 1 | int i; |
404 | 1 | dsfmt_init_gen_rand(&dsfmt, 1); |
405 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
406 | 2.09M | q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
407 | 2.09M | for (i = 0; i < 32 * 8 * 128 * 64; i++2.09M ) |
408 | 2.09M | k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
409 | 3.14M | for (i = 0; i < 32 * 8 * 128 * 96; i++3.14M ) |
410 | 3.14M | v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); |
411 | 1 | ccv_nnc_tensor_symbol_t dr = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, tr); |
412 | 1 | ccv_nnc_tensor_t* const dr_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dr); |
413 | 3.14M | for (i = 0; i < 32 * 128 * 8 * 96; i++3.14M ) |
414 | 3.14M | dr_tensor->data.f32[i] = 1; |
415 | 1 | ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); |
416 | 1 | ccv_nnc_tensor_symbol_t bq = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); |
417 | 1 | ccv_nnc_tensor_symbol_t bk = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); |
418 | 1 | ccv_nnc_tensor_symbol_t bv = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "v"); |
419 | 1 | ccv_nnc_tensor_symbol_t br = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 96), "r"); |
420 | 1 | ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(bq, bk, bv), TENSOR_SYMBOL_LIST(br, NO_TENSOR_SYMBOL, NO_TENSOR_SYMBOL), "scaled_dot_product_attention"); |
421 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
422 | 1 | ccv_nnc_symbolic_graph_backward(sdp_symbolic_graph, TENSOR_SYMBOL_LIST(br), TENSOR_SYMBOL_LIST(bq, bk, bv), SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph)); |
423 | 1 | ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); |
424 | 1 | SYMBOLIC_GRAPH_GEN(sdp_symbolic_graph, CCV_NNC_LONG_DOT_GRAPH); |
425 | 1 | ccv_nnc_graph_t* sdp_graph = 0; |
426 | 1 | ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; |
427 | 1 | ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; |
428 | 1 | ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); |
429 | 1 | GRAPH_GEN(sdp_graph, CCV_NNC_LONG_DOT_GRAPH); |
430 | 1 | ccv_nnc_tensor_t* const bq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bq); |
431 | 1 | ccv_nnc_tensor_t* const bk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bk); |
432 | 1 | ccv_nnc_tensor_t* const bv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bv); |
433 | 1 | memcpy(bq_tensor->data.f32, q_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
434 | 1 | memcpy(bk_tensor->data.f32, k_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 64); |
435 | 1 | memcpy(bv_tensor->data.f32, v_tensor->data.f32, sizeof(float) * 32 * 8 * 128 * 96); |
436 | 1 | ccv_nnc_tensor_symbol_t dbr = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, br); |
437 | 1 | ccv_nnc_tensor_t* const dbr_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbr); |
438 | 3.14M | for (i = 0; i < 32 * 128 * 8 * 96; i++3.14M ) |
439 | 3.14M | dbr_tensor->data.f32[i] = 1; |
440 | | // CCV_CLI_SET_OUTPUT_LEVEL_AND_ABOVE(CCV_CLI_VERBOSE); |
441 | 1 | ccv_nnc_graph_run(graph, 0, TRAVERSE_FULL, 0, 0); |
442 | 1 | ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); |
443 | 1 | ccv_nnc_tensor_symbol_t dq = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, q); |
444 | 1 | ccv_nnc_tensor_symbol_t dk = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, k); |
445 | 1 | ccv_nnc_tensor_symbol_t dv = ccv_nnc_tensor_symbol_for_backward(symbolic_graph, v); |
446 | 1 | ccv_nnc_tensor_t* const dq_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dq); |
447 | 1 | ccv_nnc_tensor_t* const dk_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dk); |
448 | 1 | ccv_nnc_tensor_t* const dv_tensor = ccv_nnc_tensor_from_symbol(tensor_arena, dv); |
449 | 1 | ccv_nnc_tensor_symbol_t dbq = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bq); |
450 | 1 | ccv_nnc_tensor_symbol_t dbk = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bk); |
451 | 1 | ccv_nnc_tensor_symbol_t dbv = ccv_nnc_tensor_symbol_for_backward(sdp_symbolic_graph, bv); |
452 | 1 | ccv_nnc_tensor_t* const dbq_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbq); |
453 | 1 | ccv_nnc_tensor_t* const dbk_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbk); |
454 | 1 | ccv_nnc_tensor_t* const dbv_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, dbv); |
455 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dq_tensor->data.f32, dbq_tensor->data.f32, 32 * 128 * 8 * 64, 1e-5, "graph computed gradient should match scaled dot product attention op gradient"); |
456 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dk_tensor->data.f32, dbk_tensor->data.f32, 32 * 128 * 8 * 64, 1e-5, "graph computed gradient should match scaled dot product attention op gradient"); |
457 | 1 | REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dv_tensor->data.f32, dbv_tensor->data.f32, 32 * 128 * 8 * 96, 1e-5, "graph computed gradient should match scaled dot product attention op gradient"); |
458 | 1 | ccv_nnc_symbolic_graph_free(symbolic_graph); |
459 | 1 | ccv_nnc_tensor_arena_free(tensor_arena); |
460 | 1 | ccv_nnc_graph_exec_arena_free(graph_exec_arena); |
461 | 1 | ccv_nnc_graph_free(graph); |
462 | 1 | ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); |
463 | 1 | ccv_nnc_tensor_arena_free(sdp_tensor_arena); |
464 | 1 | ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); |
465 | 1 | ccv_nnc_graph_free(sdp_graph); |
466 | 1 | } |
467 | | |
468 | | #include "case_main.h" |