/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #ifdef USE_OPENMP |
7 | | #include <omp.h> |
8 | | #endif |
9 | | #ifdef USE_DISPATCH |
10 | | #include <dispatch/dispatch.h> |
11 | | #endif |
12 | | |
13 | | // Shared methods. |
14 | | #include "../_ccv_nnc_cpu_ref.h" |
15 | | |
16 | | static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
17 | 9 | { |
18 | 9 | assert(input_size >= 3); |
19 | 9 | assert(output_size >= 1); |
20 | 9 | ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[0]; |
21 | 9 | ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[1]; |
22 | 9 | ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[2]; |
23 | 9 | ccv_nnc_tensor_view_t* const attn_mask = input_size > 3 ? (ccv_nnc_tensor_view_t*)inputs[3] : 0; |
24 | 9 | ccv_nnc_tensor_view_t* const w = input_size > 4 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0; |
25 | 9 | ccv_nnc_tensor_view_t* const bias = input_size > 5 ? (ccv_nnc_tensor_view_t*)inputs[5] : 0; |
26 | 9 | if (bias) // bias always requires a weight matrix. |
27 | 3 | { assert(w); } |
28 | 9 | ccv_nnc_tensor_view_t* const c = (w) ? (ccv_nnc_tensor_view_t*)outputs[2] : (ccv_nnc_tensor_view_t*)outputs[0]; |
29 | 9 | const int q_nd = ccv_nnc_tensor_nd(q->info.dim); |
30 | 9 | assert(q_nd == 3 || q_nd == 4); |
31 | 9 | const int k_nd = ccv_nnc_tensor_nd(k->info.dim); |
32 | 9 | assert(k_nd == 3 || k_nd == 4); |
33 | 9 | const int v_nd = ccv_nnc_tensor_nd(v->info.dim); |
34 | 9 | assert(v_nd == 3 || v_nd == 4); |
35 | 9 | const int c_nd = ccv_nnc_tensor_nd(c->info.dim); |
36 | 9 | assert(c_nd == 3 || c_nd == 4); |
37 | 9 | assert(q_nd == k_nd && k_nd == v_nd && v_nd == c_nd); |
38 | | // Assuming this is float 32. |
39 | 9 | int qdim[CCV_NNC_MAX_DIM_ALLOC]; |
40 | 9 | int kdim[CCV_NNC_MAX_DIM_ALLOC]; |
41 | 9 | int vdim[CCV_NNC_MAX_DIM_ALLOC]; |
42 | 9 | int cdim[CCV_NNC_MAX_DIM_ALLOC]; |
43 | 9 | int amdim[CCV_NNC_MAX_DIM_ALLOC]; |
44 | 9 | ccv_nnc_tensor_view_get_dim(q, qdim); |
45 | 9 | ccv_nnc_tensor_view_get_dim(k, kdim); |
46 | 9 | ccv_nnc_tensor_view_get_dim(v, vdim); |
47 | 9 | ccv_nnc_tensor_view_get_dim(c, cdim); |
48 | 9 | if (q_nd == 3) |
49 | 0 | { |
50 | 0 | qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1; |
51 | 0 | kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1; |
52 | 0 | vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1; |
53 | 0 | cdim[0] = cdim[1], cdim[1] = cdim[2], cdim[2] = 1; |
54 | 0 | } |
55 | 9 | assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == cdim[0]); |
56 | 9 | assert(qdim[2] == cdim[2]); |
57 | 9 | assert(kdim[2] == vdim[2]); |
58 | 9 | assert(qdim[2] % kdim[2] == 0); |
59 | 9 | assert(qdim[2] >= kdim[2]); |
60 | 9 | assert(qdim[3] == kdim[3]); |
61 | 9 | assert(kdim[1] == vdim[1]); |
62 | 9 | assert(cdim[1] == qdim[1]); |
63 | 9 | assert(cdim[3] == vdim[3]); |
64 | 9 | assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number. |
65 | 9 | int qstride[CCV_NNC_MAX_DIM_ALLOC]; |
66 | 9 | int kstride[CCV_NNC_MAX_DIM_ALLOC]; |
67 | 9 | int vstride[CCV_NNC_MAX_DIM_ALLOC]; |
68 | 9 | int cstride[CCV_NNC_MAX_DIM_ALLOC]; |
69 | 9 | int amstride[CCV_NNC_MAX_DIM_ALLOC]; |
70 | 9 | ccv_nnc_tensor_view_get_stride(q, qstride); |
71 | 9 | ccv_nnc_tensor_view_get_stride(k, kstride); |
72 | 9 | ccv_nnc_tensor_view_get_stride(v, vstride); |
73 | 9 | ccv_nnc_tensor_view_get_stride(c, cstride); |
74 | 9 | if (q_nd == 3) |
75 | 0 | { |
76 | 0 | qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3]; |
77 | 0 | kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3]; |
78 | 0 | vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3]; |
79 | 0 | cstride[0] = cstride[1], cstride[1] = cstride[2], cstride[2] = cstride[3]; |
80 | 0 | } |
81 | 9 | if (attn_mask) |
82 | 2 | { |
83 | 2 | ccv_nnc_tensor_view_get_dim(attn_mask, amdim); |
84 | 2 | ccv_nnc_tensor_view_get_stride(attn_mask, amstride); |
85 | 2 | assert(amdim[0] == qdim[0] || amdim[0] == 1); |
86 | 2 | assert(amdim[1] == qdim[2] || amdim[1] == 1); |
87 | 2 | assert(amdim[2] == qdim[1]); |
88 | 2 | assert(amdim[3] == kdim[1]); |
89 | 2 | } |
90 | 9 | int i[CCV_NNC_MAX_DIM + 2]; |
91 | 9 | float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * qdim[1] * kdim[1], CCV_TENSOR_CPU_MEMORY); |
92 | 9 | const float* const qp = q->data.f32; |
93 | 9 | const float* const kp = k->data.f32; |
94 | 9 | const float* const vp = v->data.f32; |
95 | 9 | const float* const amp = attn_mask ? attn_mask->data.f32 : 0; |
96 | 9 | float* const cp = c->data.f32; |
97 | 9 | const float scale = cmd.info.scaled_dot_product_attention.scale; |
98 | 9 | const int is_causal = cmd.info.scaled_dot_product_attention.is_causal; |
99 | 9 | const int h_h_k_ratio = qdim[2] / kdim[2]; |
100 | 9 | assert(kdim[2] == vdim[2]); |
101 | 9 | assert(qdim[2] >= kdim[2]); |
102 | 9 | assert(qdim[2] % kdim[2] == 0); |
103 | 297 | for (i[0] = 0; i[0] < qdim[0]; i[0]++)
104 | 288 | { |
105 | 288 | const float* const qp0 = qp + i[0] * qstride[0]; |
106 | 288 | const float* const kp0 = kp + i[0] * kstride[0]; |
107 | 288 | const float* const vp0 = vp + i[0] * vstride[0]; |
108 | 288 | const float* const amp0 = amp && amdim[0] > 1 ? amp + i[0] * amstride[0] : amp; |
109 | 288 | float* const cp0 = cp + i[0] * cstride[0]; |
110 | 2.59k | for (i[1] = 0; i[1] < qdim[2]; i[1]++)
111 | 2.30k | { |
112 | 2.30k | const float* const qp1 = qp0 + i[1] * qstride[2]; |
113 | 2.30k | const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2]; |
114 | 2.30k | const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2]; |
115 | 2.30k | const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0; |
116 | 2.30k | float* const cp1 = cp0 + i[1] * cstride[2]; |
117 | | // Compute Q @ K^T |
118 | 294k | parallel_for(x, qdim[1]) {
119 | 294k | int y, k; |
120 | 294k | const float* const qp2 = qp1 + x * qstride[1]; |
121 | 294k | float* const cp2 = cp1 + x * cstride[1]; |
122 | 294k | float* const qk0 = qk + x * kdim[1]; |
123 | 294k | const float* const amp2 = amp1 ? amp1 + x * amstride[2] : 0; |
124 | 294k | if (attn_mask) |
125 | 65.5k | { |
126 | 8.45M | for (y = 0; y < kdim[1]; y++)
127 | 8.38M | { |
128 | 8.38M | const float* const kp2 = kp1 + y * kstride[1]; |
129 | 8.38M | float v = 0; |
130 | 545M | for (k = 0; k < qdim[3]; k++)
131 | 536M | v += qp2[k * qstride[3]] * kp2[k * kstride[3]]; |
132 | 8.38M | qk0[y] = scale * v + amp2[y * amstride[3]]; |
133 | 8.38M | } |
134 | 229k | } else { |
135 | 29.5M | for (y = 0; y < kdim[1]; y++)
136 | 29.3M | { |
137 | 29.3M | const float* const kp2 = kp1 + y * kstride[1]; |
138 | 29.3M | float v = 0; |
139 | 1.90G | for (k = 0; k < qdim[3]; k++)
140 | 1.87G | v += qp2[k * qstride[3]] * kp2[k * kstride[3]]; |
141 | 29.3M | qk0[y] = scale * v; |
142 | 29.3M | } |
143 | 229k | } |
144 | | // Compute softmax on qk. |
145 | 294k | if (is_causal) |
146 | 0 | { |
147 | 0 | const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0); |
148 | 0 | for (y = x_end; y < kdim[1]; y++) |
149 | 0 | qk0[y] = 0; |
150 | 0 | double maxval = qk0[0]; |
151 | 0 | for (y = 1; y < x_end; y++) |
152 | 0 | if (qk0[y] > maxval) |
153 | 0 | maxval = qk0[y]; |
154 | 0 | double sumval = 0; |
155 | 0 | for (y = 0; y < x_end; y++) |
156 | 0 | sumval += (qk0[y] = expf(qk0[y] - maxval)); |
157 | 0 | sumval = 1.0 / sumval; |
158 | 0 | for (y = 0; y < x_end; y++) |
159 | 0 | qk0[y] *= sumval; |
160 | 294k | } else { |
161 | 294k | double maxval = qk0[0]; |
162 | 37.7M | for (y = 1; y < kdim[1]; y++)
163 | 37.4M | if (qk0[y] > maxval) |
164 | 1.25M | maxval = qk0[y]; |
165 | 294k | double sumval = 0; |
166 | 38.0M | for (y = 0; y < kdim[1]; y++)
167 | 37.7M | sumval += (qk0[y] = expf(qk0[y] - maxval)); |
168 | 294k | sumval = 1.0 / sumval; |
169 | 38.0M | for (y = 0; y < kdim[1]; y++)
170 | 37.7M | qk0[y] *= sumval; |
171 | 294k | } |
172 | 28.6M | for (k = 0; k < vdim[3]; k++)
173 | 28.3M | cp2[k * cstride[3]] = 0; |
174 | 38.0M | for (y = 0; y < kdim[1]; y++)
175 | 37.7M | { |
176 | 37.7M | const float* const vp2 = vp1 + y * vstride[1]; |
177 | 37.7M | const float v = qk0[y]; |
178 | 3.66G | for (k = 0; k < vdim[3]; k++)
179 | 3.62G | cp2[k * cstride[3]] += v * vp2[k * vstride[3]]; |
180 | 37.7M | } |
181 | 294k | } parallel_endfor |
182 | 2.30k | } |
183 | 288 | } |
184 | 9 | if (w) |
185 | 3 | { |
186 | 3 | const int num_heads = cdim[2]; |
187 | 3 | ccv_nnc_tensor_view_t* const d = (ccv_nnc_tensor_view_t*)outputs[0]; |
188 | 3 | const int w_nd = ccv_nnc_tensor_nd(w->info.dim); |
189 | 3 | assert(w_nd == 2); |
190 | 3 | assert(CCV_IS_TENSOR_CONTIGUOUS(w)); |
191 | 3 | const int d_nd = ccv_nnc_tensor_nd(d->info.dim); |
192 | 3 | assert(d_nd == 3); |
193 | 3 | int ddim[CCV_NNC_MAX_DIM_ALLOC]; |
194 | 3 | int dstride[CCV_NNC_MAX_DIM_ALLOC]; |
195 | 3 | ccv_nnc_tensor_view_get_dim(d, ddim); |
196 | 3 | ccv_nnc_tensor_view_get_stride(d, dstride); |
197 | 3 | assert(ddim[2] == cdim[1]); |
198 | 3 | assert(ddim[3] == num_heads * cdim[3]); |
199 | 3 | assert(w->info.dim[1] == ddim[3]); |
200 | 3 | assert(w->info.dim[0] == ddim[3]); |
201 | 3 | float* const dp = d->data.f32; |
202 | 3 | const float* const wp = w->data.f32; |
203 | 3 | const float* const cp = c->data.f32; |
204 | 3 | if (bias) |
205 | 3 | { |
206 | 3 | assert(ccv_nnc_tensor_count(bias->info) == ddim[3]); |
207 | 3 | assert(CCV_IS_TENSOR_CONTIGUOUS(bias)); |
208 | 3 | const float* const biasp = bias->data.f32; |
209 | 99 | for (i[0] = 0; i[0] < ddim[1]; i[0]++)
210 | 96 | { |
211 | 96 | const float* const cp0 = cp + i[0] * cstride[0]; |
212 | 96 | float* const dp0 = dp + i[0] * dstride[1]; |
213 | 12.2k | parallel_for(y, ddim[2]) {
214 | 12.2k | int x, j, k; |
215 | 12.2k | const float* const cp1 = cp0 + y * cstride[1]; |
216 | 12.2k | float* const dp1 = dp0 + y * dstride[2]; |
217 | 9.44M | for (x = 0; x < ddim[3]; x++)
218 | 9.43M | { |
219 | 9.43M | const float* const wp0 = wp + x * ddim[3]; |
220 | 9.43M | float v = biasp[x]; |
221 | 84.9M | for (j = 0; j < num_heads; j++)
222 | 75.4M | { |
223 | 75.4M | const float* const cp2 = cp1 + j * cstride[2]; |
224 | 7.32G | for (k = 0; k < cdim[3]; k++)
225 | 7.24G | v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]]; |
226 | 75.4M | } |
227 | 9.43M | dp1[x * dstride[3]] = v; |
228 | 9.43M | } |
229 | 12.2k | } parallel_endfor |
230 | 96 | } |
231 | 3 | } else { |
232 | 0 | for (i[0] = 0; i[0] < ddim[1]; i[0]++) |
233 | 0 | { |
234 | 0 | const float* const cp0 = cp + i[0] * cstride[0]; |
235 | 0 | float* const dp0 = dp + i[0] * dstride[1]; |
236 | 0 | parallel_for(y, ddim[2]) { |
237 | 0 | int x, j, k; |
238 | 0 | const float* const cp1 = cp0 + y * cstride[1]; |
239 | 0 | float* const dp1 = dp0 + y * dstride[2]; |
240 | 0 | for (x = 0; x < ddim[3]; x++) |
241 | 0 | { |
242 | 0 | const float* const wp0 = wp + x * ddim[3]; |
243 | 0 | float v = 0; |
244 | 0 | for (j = 0; j < num_heads; j++) |
245 | 0 | { |
246 | 0 | const float* const cp2 = cp1 + j * cstride[2]; |
247 | 0 | for (k = 0; k < cdim[3]; k++) |
248 | 0 | v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]]; |
249 | 0 | } |
250 | 0 | dp1[x * dstride[3]] = v; |
251 | 0 | } |
252 | 0 | } parallel_endfor |
253 | 0 | } |
254 | 0 | } |
255 | 3 | } |
256 | 9 | return CCV_NNC_EXEC_SUCCESS; |
257 | 9 | } |
258 | | |
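Note: per (batch, head) pair, the loop nest above computes out = softmax(scale * Q K^T + mask) V, with grouped/multi-query heads handled by indexing K and V through i[1] / h_h_k_ratio. A minimal, illustrative sketch of the same computation over contiguous row-major buffers follows; the function name, signature, and layout are assumptions chosen for clarity, not the ccv API (the kernel above additionally handles strided tensor views, the causal path, and the optional output projection).

#include <math.h>

/* Sketch only: one (batch, head) slice. q: [Tq x d], k: [Tk x d], v: [Tk x dv],
 * optional additive mask: [Tq x Tk], out: [Tq x dv], scratch: [Tk]. */
static void sdpa_single_head_ref(const float* q, const float* k, const float* v,
	const float* mask, float* out, float* scratch,
	int Tq, int Tk, int d, int dv, float scale)
{
	int x, y, i;
	for (x = 0; x < Tq; x++)
	{
		/* Logits: scratch[y] = scale * <q_x, k_y> (+ mask). */
		for (y = 0; y < Tk; y++)
		{
			float dot = 0;
			for (i = 0; i < d; i++)
				dot += q[x * d + i] * k[y * d + i];
			scratch[y] = scale * dot + (mask ? mask[x * Tk + y] : 0);
		}
		/* Numerically stable softmax over the row. */
		float maxval = scratch[0];
		for (y = 1; y < Tk; y++)
			if (scratch[y] > maxval)
				maxval = scratch[y];
		double sumval = 0;
		for (y = 0; y < Tk; y++)
			sumval += (scratch[y] = expf(scratch[y] - maxval));
		const double inv = 1.0 / sumval;
		for (y = 0; y < Tk; y++)
			scratch[y] *= inv;
		/* out_x = sum_y softmax(logits)_y * v_y. */
		for (i = 0; i < dv; i++)
			out[x * dv + i] = 0;
		for (y = 0; y < Tk; y++)
			for (i = 0; i < dv; i++)
				out[x * dv + i] += scratch[y] * v[y * dv + i];
	}
}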
259 | | static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) |
260 | 1 | { |
261 | | // Assuming no saved_softmax, we need to recompute from q, k, v. |
262 | | // We cannot do this with masks (yet). |
263 | 1 | assert(input_size >= 6); |
264 | 1 | ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0]; |
265 | 1 | ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[3]; |
266 | 1 | ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[4]; |
267 | 1 | ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[5]; |
268 | 1 | ccv_nnc_tensor_view_t* const dq = (ccv_nnc_tensor_view_t*)outputs[0]; |
269 | 1 | ccv_nnc_tensor_view_t* const dk = (ccv_nnc_tensor_view_t*)outputs[1]; |
270 | 1 | ccv_nnc_tensor_view_t* const dv = (ccv_nnc_tensor_view_t*)outputs[2]; |
271 | 1 | const int q_nd = ccv_nnc_tensor_nd(q->info.dim); |
272 | 1 | assert(q_nd == 3 || q_nd == 4); |
273 | 1 | const int k_nd = ccv_nnc_tensor_nd(k->info.dim); |
274 | 1 | assert(k_nd == 3 || k_nd == 4); |
275 | 1 | const int v_nd = ccv_nnc_tensor_nd(v->info.dim); |
276 | 1 | assert(v_nd == 3 || v_nd == 4); |
277 | 1 | const int g_nd = ccv_nnc_tensor_nd(g->info.dim); |
278 | 1 | assert(g_nd == 3 || g_nd == 4); |
279 | 1 | const int dq_nd = ccv_nnc_tensor_nd(dq->info.dim); |
280 | 1 | assert(dq_nd == 3 || dq_nd == 4); |
281 | 1 | assert(dq_nd == q_nd); |
282 | 1 | const int dk_nd = ccv_nnc_tensor_nd(dk->info.dim); |
283 | 1 | assert(dk_nd == 3 || dk_nd == 4); |
284 | 1 | assert(dk_nd == k_nd); |
285 | 1 | const int dv_nd = ccv_nnc_tensor_nd(dv->info.dim); |
286 | 1 | assert(dv_nd == 3 || dv_nd == 4); |
287 | 1 | assert(dv_nd == v_nd); |
288 | 1 | assert(q_nd == k_nd && k_nd == v_nd && v_nd == g_nd); |
289 | | // Assuming this is float 32. |
290 | 1 | int qdim[CCV_NNC_MAX_DIM_ALLOC]; |
291 | 1 | int kdim[CCV_NNC_MAX_DIM_ALLOC]; |
292 | 1 | int vdim[CCV_NNC_MAX_DIM_ALLOC]; |
293 | 1 | int gdim[CCV_NNC_MAX_DIM_ALLOC]; |
294 | 1 | int dqdim[CCV_NNC_MAX_DIM_ALLOC]; |
295 | 1 | int dkdim[CCV_NNC_MAX_DIM_ALLOC]; |
296 | 1 | int dvdim[CCV_NNC_MAX_DIM_ALLOC]; |
297 | 1 | ccv_nnc_tensor_view_get_dim(q, qdim); |
298 | 1 | ccv_nnc_tensor_view_get_dim(k, kdim); |
299 | 1 | ccv_nnc_tensor_view_get_dim(v, vdim); |
300 | 1 | ccv_nnc_tensor_view_get_dim(g, gdim); |
301 | 1 | ccv_nnc_tensor_view_get_dim(dq, dqdim); |
302 | 1 | ccv_nnc_tensor_view_get_dim(dk, dkdim); |
303 | 1 | ccv_nnc_tensor_view_get_dim(dv, dvdim); |
304 | 1 | if (q_nd == 3) |
305 | 0 | { |
306 | 0 | qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1; |
307 | 0 | kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1; |
308 | 0 | vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1; |
309 | 0 | gdim[0] = gdim[1], gdim[1] = gdim[2], gdim[2] = 1; |
310 | 0 | dqdim[0] = dqdim[1], dqdim[1] = dqdim[2], dqdim[2] = 1; |
311 | 0 | dkdim[0] = dkdim[1], dkdim[1] = dkdim[2], dkdim[2] = 1; |
312 | 0 | dvdim[0] = dvdim[1], dvdim[1] = dvdim[2], dvdim[2] = 1; |
313 | 0 | } |
314 | 1 | assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == gdim[0]); |
315 | 1 | assert(qdim[2] == gdim[2]); |
316 | 1 | assert(kdim[2] == vdim[2]); |
317 | 1 | assert(qdim[2] % kdim[2] == 0); |
318 | 1 | assert(qdim[2] >= kdim[2]); |
319 | 1 | assert(qdim[3] == kdim[3]); |
320 | 1 | assert(kdim[1] == vdim[1]); |
321 | 1 | assert(gdim[1] == qdim[1]); |
322 | 1 | assert(gdim[3] == vdim[3]); |
323 | 1 | assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number. |
324 | 1 | int qstride[CCV_NNC_MAX_DIM_ALLOC]; |
325 | 1 | int kstride[CCV_NNC_MAX_DIM_ALLOC]; |
326 | 1 | int vstride[CCV_NNC_MAX_DIM_ALLOC]; |
327 | 1 | int gstride[CCV_NNC_MAX_DIM_ALLOC]; |
328 | 1 | int dqstride[CCV_NNC_MAX_DIM_ALLOC]; |
329 | 1 | int dkstride[CCV_NNC_MAX_DIM_ALLOC]; |
330 | 1 | int dvstride[CCV_NNC_MAX_DIM_ALLOC]; |
331 | 1 | ccv_nnc_tensor_view_get_stride(q, qstride); |
332 | 1 | ccv_nnc_tensor_view_get_stride(k, kstride); |
333 | 1 | ccv_nnc_tensor_view_get_stride(v, vstride); |
334 | 1 | ccv_nnc_tensor_view_get_stride(g, gstride); |
335 | 1 | ccv_nnc_tensor_view_get_stride(dq, dqstride); |
336 | 1 | ccv_nnc_tensor_view_get_stride(dk, dkstride); |
337 | 1 | ccv_nnc_tensor_view_get_stride(dv, dvstride); |
338 | 1 | if (q_nd == 3) |
339 | 0 | { |
340 | 0 | qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3]; |
341 | 0 | kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3]; |
342 | 0 | vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3]; |
343 | 0 | gstride[0] = gstride[1], gstride[1] = gstride[2], gstride[2] = gstride[3]; |
344 | 0 | dqstride[0] = dqstride[1], dqstride[1] = dqstride[2], dqstride[2] = dqstride[3]; |
345 | 0 | dkstride[0] = dkstride[1], dkstride[1] = dkstride[2], dkstride[2] = dkstride[3]; |
346 | 0 | dvstride[0] = dvstride[1], dvstride[1] = dvstride[2], dvstride[2] = dvstride[3]; |
347 | 0 | } |
348 | 1 | int i[CCV_NNC_MAX_DIM + 2]; |
349 | 1 | float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * 2 * kdim[1], CCV_TENSOR_CPU_MEMORY); |
350 | 1 | const float* const qp = q->data.f32; |
351 | 1 | const float* const kp = k->data.f32; |
352 | 1 | const float* const vp = v->data.f32; |
353 | 1 | const float* const gp = g->data.f32; |
354 | 1 | float* const dqp = dq->data.f32; |
355 | 1 | float* const dkp = dk->data.f32; |
356 | 1 | float* const dvp = dv->data.f32; |
357 | 1 | const float scale = cmd.info.scaled_dot_product_attention.scale; |
358 | 1 | const int is_causal = cmd.info.scaled_dot_product_attention.is_causal; |
359 | 1 | const int h_h_k_ratio = qdim[2] / kdim[2]; |
360 | 33 | for (i[0] = 0; i[0] < qdim[0]; i[0]++)
361 | 32 | { |
362 | 32 | const float* const qp0 = qp + i[0] * qstride[0]; |
363 | 32 | const float* const kp0 = kp + i[0] * kstride[0]; |
364 | 32 | const float* const vp0 = vp + i[0] * vstride[0]; |
365 | 32 | const float* const gp0 = gp + i[0] * gstride[0]; |
366 | 32 | float* const dqp0 = dqp + i[0] * dqstride[0]; |
367 | 32 | float* const dkp0 = dkp + i[0] * dkstride[0]; |
368 | 32 | float* const dvp0 = dvp + i[0] * dvstride[0]; |
369 | 288 | for (i[1] = 0; i[1] < qdim[2]; i[1]++)
370 | 256 | { |
371 | 256 | const float* const qp1 = qp0 + i[1] * qstride[2]; |
372 | 256 | const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2]; |
373 | 256 | const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2]; |
374 | 256 | const float* const gp1 = gp0 + i[1] * gstride[2]; |
375 | 256 | float* const dqp1 = dqp0 + i[1] * dqstride[2]; |
376 | 256 | float* const dkp1 = dkp0 + (i[1] / h_h_k_ratio) * dkstride[2]; |
377 | 256 | float* const dvp1 = dvp0 + (i[1] / h_h_k_ratio) * dvstride[2]; |
378 | | // Compute Q @ K^T |
379 | 256 | int x, y, k; |
380 | 33.0k | for (x = 0; x < qdim[1]; x++)
381 | 32.7k | { |
382 | 32.7k | float* const dqp2 = dqp1 + x * dqstride[1]; |
383 | 2.12M | for (k = 0; k < qdim[3]; k++)
384 | 2.09M | dqp2[k * dqstride[3]] = 0; |
385 | 32.7k | } |
386 | | // Only zero out when it is at 0-index. |
387 | 256 | if (i[1] % h_h_k_ratio == 0) |
388 | 33.0k | for (y = 0; y < kdim[1]; y++)
389 | 32.7k | { |
390 | 32.7k | float* const dkp2 = dkp1 + y * dkstride[1]; |
391 | 2.12M | for (k = 0; k < qdim[3]; k++)
392 | 2.09M | dkp2[k * dkstride[3]] = 0; |
393 | 32.7k | } |
394 | | // Only zero out when it is at 0-index. |
395 | 256 | if (i[1] % h_h_k_ratio == 0) |
396 | 33.0k | for (y = 0; y < kdim[1]; y++)
397 | 32.7k | { |
398 | 32.7k | float* const dvp2 = dvp1 + y * dvstride[1]; |
399 | 3.17M | for (k = 0; k < vdim[3]; k++)
400 | 3.14M | dvp2[k * dvstride[3]] = 0; |
401 | 32.7k | } |
402 | 33.0k | for (x = 0; x < qdim[1]; x++)
403 | 32.7k | { |
404 | 32.7k | const float* const qp2 = qp1 + x * qstride[1]; |
405 | 32.7k | const float* const gp2 = gp1 + x * gstride[1]; |
406 | 32.7k | float* const qk0 = qk; |
407 | 32.7k | float* const qks0 = qk + kdim[1]; |
408 | 4.22M | for (y = 0; y < kdim[1]; y++)
409 | 4.19M | { |
410 | 4.19M | const float* const kp2 = kp1 + y * kstride[1]; |
411 | 4.19M | float v = 0; |
412 | 272M | for (k = 0; k < qdim[3]; k++)
413 | 268M | v += qp2[k * qstride[3]] * kp2[k * kstride[3]]; |
414 | 4.19M | qk0[y] = scale * v; |
415 | 4.19M | } |
416 | | // Compute softmax on qk. |
417 | 32.7k | if (is_causal) |
418 | 0 | { |
419 | 0 | const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0); |
420 | 0 | for (y = x_end; y < kdim[1]; y++) |
421 | 0 | qk0[y] = 0; |
422 | 0 | double maxval = qk0[0]; |
423 | 0 | for (y = 1; y < x_end; y++) |
424 | 0 | if (qk0[y] > maxval) |
425 | 0 | maxval = qk0[y]; |
426 | 0 | double sumval = 0; |
427 | 0 | for (y = 0; y < x_end; y++) |
428 | 0 | sumval += (qk0[y] = expf(qk0[y] - maxval)); |
429 | 0 | sumval = 1.0 / sumval; |
430 | 0 | for (y = 0; y < x_end; y++) |
431 | 0 | qk0[y] *= sumval; |
432 | 32.7k | } else { |
433 | 32.7k | double maxval = qk0[0]; |
434 | 4.19M | for (y = 1; y < kdim[1]; y++)
435 | 4.16M | if (qk0[y] > maxval) |
436 | 146k | maxval = qk0[y]; |
437 | 32.7k | double sumval = 0; |
438 | 4.22M | for (y = 0; y < kdim[1]; y++)
439 | 4.19M | sumval += (qk0[y] = expf(qk0[y] - maxval)); |
440 | 32.7k | sumval = 1.0 / sumval; |
441 | 4.22M | for (y = 0; y < kdim[1]; y++)
442 | 4.19M | qk0[y] *= sumval; |
443 | 32.7k | } |
444 | 4.22M | for (y = 0; y < kdim[1]; y++)
445 | 4.19M | { |
446 | 4.19M | float* const dvp2 = dvp1 + y * dvstride[1]; |
447 | 4.19M | const float v = qk0[y]; |
448 | 406M | for (k = 0; k < vdim[3]; k++)
449 | 402M | dvp2[k * dvstride[3]] += v * gp2[k * gstride[3]]; |
450 | 4.19M | } |
451 | 32.7k | double sumval = 0; |
452 | 4.22M | for (y = 0; y < kdim[1]; y++)
453 | 4.19M | { |
454 | 4.19M | const float* const vp2 = vp1 + y * vstride[1]; |
455 | 4.19M | float v = 0; |
456 | 406M | for (k = 0; k < vdim[3]; k++)
457 | 402M | v += gp2[k * gstride[3]] * vp2[k * vstride[3]]; |
458 | 4.19M | qks0[y] = v; |
459 | 4.19M | sumval += v * qk0[y]; |
460 | 4.19M | } |
461 | 4.22M | for (y = 0; y < kdim[1]; y++)
462 | 4.19M | qk0[y] = (qks0[y] - sumval) * qk0[y]; |
463 | 32.7k | float* const dqp2 = dqp1 + x * dqstride[1]; |
464 | 4.22M | for (y = 0; y < kdim[1]; y++)
465 | 4.19M | { |
466 | 4.19M | const float* const kp2 = kp1 + y * kstride[1]; |
467 | 4.19M | float* const dkp2 = dkp1 + y * dkstride[1]; |
468 | 4.19M | const float v = scale * qk0[y]; |
469 | 272M | for (k = 0; k < qdim[3]; k++)
470 | 268M | { |
471 | 268M | dqp2[k * dqstride[3]] += v * kp2[k * kstride[3]]; |
472 | 268M | dkp2[k * dkstride[3]] += v * qp2[k * qstride[3]]; |
473 | 268M | } |
474 | 4.19M | } |
475 | 32.7k | } |
476 | 256 | } |
477 | 32 | } |
478 | 1 | return CCV_NNC_EXEC_SUCCESS; |
479 | 1 | } |
480 | | |
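Note: the backward loop above recomputes the softmax row p = softmax(s), with s_y = scale * q_x^T k_y, rather than reading a saved softmax, and then applies the standard softmax Jacobian. With o = sum_y p_y v_y and upstream gradient g = dL/do, the accumulations it performs are, restated in LaTeX (in the code, qks0[y] holds g^T v_y and sumval holds sum_j p_j g^T v_j):

\frac{\partial L}{\partial v_y} \mathrel{+}= p_y\, g, \qquad
\frac{\partial L}{\partial s_y} = p_y \Big( g^\top v_y - \sum_j p_j\, g^\top v_j \Big), \qquad
\frac{\partial L}{\partial q_x} \mathrel{+}= \mathrm{scale} \sum_y \frac{\partial L}{\partial s_y}\, k_y, \qquad
\frac{\partial L}{\partial k_y} \mathrel{+}= \mathrm{scale}\, \frac{\partial L}{\partial s_y}\, q_x.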
481 | | REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
482 | 1 | { |
483 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC; |
484 | 1 | registry->tensor_datatypes = CCV_32F; |
485 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
486 | 1 | registry->algorithms = 1; |
487 | 1 | registry->exec = _ccv_nnc_scaled_dot_product_attention_forw; |
488 | 1 | } |
489 | | |
490 | | REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry) |
491 | 1 | { |
492 | 1 | registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC; |
493 | 1 | registry->tensor_datatypes = CCV_32F; |
494 | 1 | registry->tensor_memory = CCV_TENSOR_CPU_MEMORY; |
495 | 1 | registry->algorithms = 1; |
496 | 1 | registry->exec = _ccv_nnc_scaled_dot_product_attention_back; |
497 | 1 | } |
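Note on the optional projection path in the forward kernel: when a weight w (and optional bias) is supplied, the raw attention output is written to outputs[2] and outputs[0] is produced as d = c * W^T + bias, with the per-head outputs concatenated along the last axis (the asserts above require W to be square, so out_dim == heads * dv). A minimal sketch of that contraction on contiguous buffers follows; names and layout are illustrative assumptions, not the ccv API (the kernel itself works on strided views and parallelizes over rows).

/* c: [tokens x in_dim] attention outputs with heads concatenated (in_dim = heads * dv),
 * w: [out_dim x in_dim] row-major, bias: [out_dim] or NULL, d: [tokens x out_dim]. */
static void output_projection_ref(const float* c, const float* w, const float* bias,
	float* d, int tokens, int in_dim, int out_dim)
{
	int t, x, k;
	for (t = 0; t < tokens; t++)
		for (x = 0; x < out_dim; x++)
		{
			float v = bias ? bias[x] : 0;
			for (k = 0; k < in_dim; k++)
				v += w[x * in_dim + k] * c[t * in_dim + k];
			d[t * out_dim + x] = v;
		}
}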