Coverage Report

Created: 2025-02-24 17:43

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
Line
Count
Source
1
#include "ccv.h"
2
#include "ccv_internal.h"
3
#include "nnc/ccv_nnc.h"
4
#include "nnc/ccv_nnc_easy.h"
5
#include "nnc/ccv_nnc_internal.h"
6
#ifdef USE_OPENMP
7
#include <omp.h>
8
#endif
9
#ifdef USE_DISPATCH
10
#include <dispatch/dispatch.h>
11
#endif
12
13
// Shared methods.
14
#include "../_ccv_nnc_cpu_ref.h"
15
16
static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
17
9
{
18
9
  assert(input_size >= 3);
19
9
  assert(output_size >= 1);
20
9
  ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[0];
21
9
  ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[1];
22
9
  ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[2];
23
9
  ccv_nnc_tensor_view_t* const attn_mask = input_size > 3 ? (ccv_nnc_tensor_view_t*)inputs[3] : 0;
24
9
  ccv_nnc_tensor_view_t* const w = input_size > 4 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
25
9
  ccv_nnc_tensor_view_t* const bias = input_size > 5 ? (ccv_nnc_tensor_view_t*)inputs[5] : 0;
26
9
  if (bias) // bias always requires a weight matrix.
27
3
    { assert(w); }
28
9
  ccv_nnc_tensor_view_t* const c = (w) ? (ccv_nnc_tensor_view_t*)outputs[2] : (ccv_nnc_tensor_view_t*)outputs[0];
29
9
  const int q_nd = ccv_nnc_tensor_nd(q->info.dim);
30
9
  assert(q_nd == 3 || q_nd == 4);
31
9
  const int k_nd = ccv_nnc_tensor_nd(k->info.dim);
32
9
  assert(k_nd == 3 || k_nd == 4);
33
9
  const int v_nd = ccv_nnc_tensor_nd(v->info.dim);
34
9
  assert(v_nd == 3 || v_nd == 4);
35
9
  const int c_nd = ccv_nnc_tensor_nd(c->info.dim);
36
9
  assert(c_nd == 3 || c_nd == 4);
37
9
  assert(q_nd == k_nd && k_nd == v_nd && v_nd == c_nd);
38
  // Assuming this is float 32.
39
9
  int qdim[CCV_NNC_MAX_DIM_ALLOC];
40
9
  int kdim[CCV_NNC_MAX_DIM_ALLOC];
41
9
  int vdim[CCV_NNC_MAX_DIM_ALLOC];
42
9
  int cdim[CCV_NNC_MAX_DIM_ALLOC];
43
9
  int amdim[CCV_NNC_MAX_DIM_ALLOC];
44
9
  ccv_nnc_tensor_view_get_dim(q, qdim);
45
9
  ccv_nnc_tensor_view_get_dim(k, kdim);
46
9
  ccv_nnc_tensor_view_get_dim(v, vdim);
47
9
  ccv_nnc_tensor_view_get_dim(c, cdim);
48
9
  if (q_nd == 3)
49
0
  {
50
0
    qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1;
51
0
    kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1;
52
0
    vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1;
53
0
    cdim[0] = cdim[1], cdim[1] = cdim[2], cdim[2] = 1;
54
0
  }
55
9
  assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == cdim[0]);
56
9
  assert(qdim[2] == cdim[2]);
57
9
  assert(kdim[2] == vdim[2]);
58
9
  assert(qdim[2] % kdim[2] == 0);
59
9
  assert(qdim[2] >= kdim[2]);
60
9
  assert(qdim[3] == kdim[3]);
61
9
  assert(kdim[1] == vdim[1]);
62
9
  assert(cdim[1] == qdim[1]);
63
9
  assert(cdim[3] == vdim[3]);
64
9
  assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
65
9
  int qstride[CCV_NNC_MAX_DIM_ALLOC];
66
9
  int kstride[CCV_NNC_MAX_DIM_ALLOC];
67
9
  int vstride[CCV_NNC_MAX_DIM_ALLOC];
68
9
  int cstride[CCV_NNC_MAX_DIM_ALLOC];
69
9
  int amstride[CCV_NNC_MAX_DIM_ALLOC];
70
9
  ccv_nnc_tensor_view_get_stride(q, qstride);
71
9
  ccv_nnc_tensor_view_get_stride(k, kstride);
72
9
  ccv_nnc_tensor_view_get_stride(v, vstride);
73
9
  ccv_nnc_tensor_view_get_stride(c, cstride);
74
9
  if (q_nd == 3)
75
0
  {
76
0
    qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3];
77
0
    kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3];
78
0
    vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3];
79
0
    cstride[0] = cstride[1], cstride[1] = cstride[2], cstride[2] = cstride[3];
80
0
  }
81
9
  if (attn_mask)
82
2
  {
83
2
    ccv_nnc_tensor_view_get_dim(attn_mask, amdim);
84
2
    ccv_nnc_tensor_view_get_stride(attn_mask, amstride);
85
2
    assert(amdim[0] == qdim[0] || amdim[0] == 1);
86
2
    assert(amdim[1] == qdim[2] || amdim[1] == 1);
87
2
    assert(amdim[2] == qdim[1]);
88
2
    assert(amdim[3] == kdim[1]);
89
2
  }
90
9
  int i[CCV_NNC_MAX_DIM + 2];
91
9
  float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * qdim[1] * kdim[1], CCV_TENSOR_CPU_MEMORY);
92
9
  const float* const qp = q->data.f32;
93
9
  const float* const kp = k->data.f32;
94
9
  const float* const vp = v->data.f32;
95
9
  const float* const amp = attn_mask ? attn_mask->data.f32 : 0;
96
9
  float* const cp = c->data.f32;
97
9
  const float scale = cmd.info.scaled_dot_product_attention.scale;
98
9
  const int is_causal = cmd.info.scaled_dot_product_attention.is_causal;
99
9
  const int h_h_k_ratio = qdim[2] / kdim[2];
100
9
  assert(kdim[2] == vdim[2]);
101
9
  assert(qdim[2] >= kdim[2]);
102
9
  assert(qdim[2] % kdim[2] == 0);
103
297
  for (i[0] = 0; i[0] < qdim[0]; i[0]++)
104
288
  {
105
288
    const float* const qp0 = qp + i[0] * qstride[0];
106
288
    const float* const kp0 = kp + i[0] * kstride[0];
107
288
    const float* const vp0 = vp + i[0] * vstride[0];
108
288
    const float* const amp0 = amp && amdim[0] > 1 ? amp + i[0] * amstride[0] : amp;
109
288
    float* const cp0 = cp + i[0] * cstride[0];
110
2.59k
    for (i[1] = 0; i[1] < qdim[2]; i[1]++)
111
2.30k
    {
112
2.30k
      const float* const qp1 = qp0 + i[1] * qstride[2];
113
2.30k
      const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2];
114
2.30k
      const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2];
115
2.30k
      const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0;
116
2.30k
      float* const cp1 = cp0 + i[1] * cstride[2];
117
      // Compute Q @ K^T
118
294k
      parallel_for(x, qdim[1]) {
119
294k
        int y, k;
120
294k
        const float* const qp2 = qp1 + x * qstride[1];
121
294k
        float* const cp2 = cp1 + x * cstride[1];
122
294k
        float* const qk0 = qk + x * kdim[1];
123
294k
        const float* const amp2 = amp1 ? amp1 + x * amstride[2] : 0;
124
294k
        if (attn_mask)
125
65.5k
        {
126
8.45M
          for (y = 0; y < kdim[1]; y++)
127
8.38M
          {
128
8.38M
            const float* const kp2 = kp1 + y * kstride[1];
129
8.38M
            float v = 0;
130
545M
            for (k = 0; k < qdim[3]; k++)
131
536M
              v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
132
8.38M
            qk0[y] = scale * v + amp2[y * amstride[3]];
133
8.38M
          }
134
229k
        } else {
135
29.5M
          for (y = 0; y < kdim[1]; y++)
136
29.3M
          {
137
29.3M
            const float* const kp2 = kp1 + y * kstride[1];
138
29.3M
            float v = 0;
139
1.90G
            for (k = 0; k < qdim[3]; k++)
140
1.87G
              v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
141
29.3M
            qk0[y] = scale * v;
142
29.3M
          }
143
229k
        }
144
        // Compute softmax on qk.
145
294k
        if (is_causal)
146
0
        {
147
0
          const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0);
148
0
          for (y = x_end; y < kdim[1]; y++)
149
0
            qk0[y] = 0;
150
0
          double maxval = qk0[0];
151
0
          for (y = 1; y < x_end; y++)
152
0
            if (qk0[y] > maxval)
153
0
              maxval = qk0[y];
154
0
          double sumval = 0;
155
0
          for (y = 0; y < x_end; y++)
156
0
            sumval += (qk0[y] = expf(qk0[y] - maxval));
157
0
          sumval = 1.0 / sumval;
158
0
          for (y = 0; y < x_end; y++)
159
0
            qk0[y] *= sumval;
160
294k
        } else {
161
294k
          double maxval = qk0[0];
162
37.7M
          for (y = 1; y < kdim[1]; y++)
163
37.4M
            if (qk0[y] > maxval)
164
1.25M
              maxval = qk0[y];
165
294k
          double sumval = 0;
166
38.0M
          for (y = 0; y < kdim[1]; y++)
167
37.7M
            sumval += (qk0[y] = expf(qk0[y] - maxval));
168
294k
          sumval = 1.0 / sumval;
169
38.0M
          for (y = 0; y < kdim[1]; y++)
170
37.7M
            qk0[y] *= sumval;
171
294k
        }
172
28.6M
        for (k = 0; k < vdim[3]; k++)
173
28.3M
          cp2[k * cstride[3]] = 0;
174
38.0M
        for (y = 0; y < kdim[1]; y++)
175
37.7M
        {
176
37.7M
          const float* const vp2 = vp1 + y * vstride[1];
177
37.7M
          const float v = qk0[y];
178
3.66G
          for (k = 0; k < vdim[3]; k++)
179
3.62G
            cp2[k * cstride[3]] += v * vp2[k * vstride[3]];
180
37.7M
        }
181
294k
      } parallel_endfor
182
2.30k
    }
183
288
  }
184
9
  if (w)
185
3
  {
186
3
    const int num_heads = cdim[2];
187
3
    ccv_nnc_tensor_view_t* const d = (ccv_nnc_tensor_view_t*)outputs[0];
188
3
    const int w_nd = ccv_nnc_tensor_nd(w->info.dim);
189
3
    assert(w_nd == 2);
190
3
    assert(CCV_IS_TENSOR_CONTIGUOUS(w));
191
3
    const int d_nd = ccv_nnc_tensor_nd(d->info.dim);
192
3
    assert(d_nd == 3);
193
3
    int ddim[CCV_NNC_MAX_DIM_ALLOC];
194
3
    int dstride[CCV_NNC_MAX_DIM_ALLOC];
195
3
    ccv_nnc_tensor_view_get_dim(d, ddim);
196
3
    ccv_nnc_tensor_view_get_stride(d, dstride);
197
3
    assert(ddim[2] == cdim[1]);
198
3
    assert(ddim[3] == num_heads * cdim[3]);
199
3
    assert(w->info.dim[1] == ddim[3]);
200
3
    assert(w->info.dim[0] == ddim[3]);
201
3
    float* const dp = d->data.f32;
202
3
    const float* const wp = w->data.f32;
203
3
    const float* const cp = c->data.f32;
204
3
    if (bias)
205
3
    {
206
3
      assert(ccv_nnc_tensor_count(bias->info) == ddim[3]);
207
3
      assert(CCV_IS_TENSOR_CONTIGUOUS(bias));
208
3
      const float* const biasp = bias->data.f32;
209
99
      for (i[0] = 0; i[0] < ddim[1]; i[0]++)
210
96
      {
211
96
        const float* const cp0 = cp + i[0] * cstride[0];
212
96
        float* const dp0 = dp + i[0] * dstride[1];
213
12.2k
        parallel_for(y, ddim[2]) {
214
12.2k
          int x, j, k;
215
12.2k
          const float* const cp1 = cp0 + y * cstride[1];
216
12.2k
          float* const dp1 = dp0 + y * dstride[2];
217
9.44M
          for (x = 0; x < ddim[3]; x++)
218
9.43M
          {
219
9.43M
            const float* const wp0 = wp + x * ddim[3];
220
9.43M
            float v = biasp[x];
221
84.9M
            for (j = 0; j < num_heads; j++)
222
75.4M
            {
223
75.4M
              const float* const cp2 = cp1 + j * cstride[2];
224
7.32G
              for (k = 0; k < cdim[3]; k++)
225
7.24G
                v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]];
226
75.4M
            }
227
9.43M
            dp1[x * dstride[3]] = v;
228
9.43M
          }
229
12.2k
        } parallel_endfor
230
96
      }
231
3
    } else {
232
0
      for (i[0] = 0; i[0] < ddim[1]; i[0]++)
233
0
      {
234
0
        const float* const cp0 = cp + i[0] * cstride[0];
235
0
        float* const dp0 = dp + i[0] * dstride[1];
236
0
        parallel_for(y, ddim[2]) {
237
0
          int x, j, k;
238
0
          const float* const cp1 = cp0 + y * cstride[1];
239
0
          float* const dp1 = dp0 + y * dstride[2];
240
0
          for (x = 0; x < ddim[3]; x++)
241
0
          {
242
0
            const float* const wp0 = wp + x * ddim[3];
243
0
            float v = 0;
244
0
            for (j = 0; j < num_heads; j++)
245
0
            {
246
0
              const float* const cp2 = cp1 + j * cstride[2];
247
0
              for (k = 0; k < cdim[3]; k++)
248
0
                v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]];
249
0
            }
250
0
            dp1[x * dstride[3]] = v;
251
0
          }
252
0
        } parallel_endfor
253
0
      }
254
0
    }
255
3
  }
256
9
  return CCV_NNC_EXEC_SUCCESS;
257
9
}
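
A note on what the forward kernel computes (derived from the loops above; the notation is mine, not from the source). For each batch element and query head, with scale \alpha = cmd.info.scaled_dot_product_attention.scale, optional additive mask M, and K/V heads shared across groups of qdim[2] / kdim[2] query heads:

  C = \mathrm{softmax}(\alpha\, Q K^\top + M)\, V

When a weight matrix W (plus optional bias b) is supplied, the per-head outputs go to outputs[2], are flattened to num_heads * cdim[3] features, and the fused projection writes outputs[0]:

  D = C_{\mathrm{flat}} W^\top + b

The softmax subtracts the row maximum before expf for numerical stability (lines 161-170); the causal branch (count 0, uncovered in this run) zeroes scores past the diagonal before normalizing.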
258
259
static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
260
1
{
261
  // Assuming no saved_softmax, we need to recompute from q, k, v.
262
  // We cannot do this with masks (yet).
263
1
  assert(input_size >= 6);
264
1
  ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
265
1
  ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[3];
266
1
  ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[4];
267
1
  ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[5];
268
1
  ccv_nnc_tensor_view_t* const dq = (ccv_nnc_tensor_view_t*)outputs[0];
269
1
  ccv_nnc_tensor_view_t* const dk = (ccv_nnc_tensor_view_t*)outputs[1];
270
1
  ccv_nnc_tensor_view_t* const dv = (ccv_nnc_tensor_view_t*)outputs[2];
271
1
  const int q_nd = ccv_nnc_tensor_nd(q->info.dim);
272
1
  assert(q_nd == 3 || q_nd == 4);
273
1
  const int k_nd = ccv_nnc_tensor_nd(k->info.dim);
274
1
  assert(k_nd == 3 || k_nd == 4);
275
1
  const int v_nd = ccv_nnc_tensor_nd(v->info.dim);
276
1
  assert(v_nd == 3 || v_nd == 4);
277
1
  const int g_nd = ccv_nnc_tensor_nd(g->info.dim);
278
1
  assert(g_nd == 3 || g_nd == 4);
279
1
  const int dq_nd = ccv_nnc_tensor_nd(dq->info.dim);
280
1
  assert(dq_nd == 3 || dq_nd == 4);
281
1
  assert(dq_nd == q_nd);
282
1
  const int dk_nd = ccv_nnc_tensor_nd(dk->info.dim);
283
1
  assert(dk_nd == 3 || dk_nd == 4);
284
1
  assert(dk_nd == k_nd);
285
1
  const int dv_nd = ccv_nnc_tensor_nd(dv->info.dim);
286
1
  assert(dv_nd == 3 || dv_nd == 4);
287
1
  assert(dv_nd == v_nd);
288
1
  assert(q_nd == k_nd && k_nd == v_nd && v_nd == g_nd);
289
  // Assuming this is float 32.
290
1
  int qdim[CCV_NNC_MAX_DIM_ALLOC];
291
1
  int kdim[CCV_NNC_MAX_DIM_ALLOC];
292
1
  int vdim[CCV_NNC_MAX_DIM_ALLOC];
293
1
  int gdim[CCV_NNC_MAX_DIM_ALLOC];
294
1
  int dqdim[CCV_NNC_MAX_DIM_ALLOC];
295
1
  int dkdim[CCV_NNC_MAX_DIM_ALLOC];
296
1
  int dvdim[CCV_NNC_MAX_DIM_ALLOC];
297
1
  ccv_nnc_tensor_view_get_dim(q, qdim);
298
1
  ccv_nnc_tensor_view_get_dim(k, kdim);
299
1
  ccv_nnc_tensor_view_get_dim(v, vdim);
300
1
  ccv_nnc_tensor_view_get_dim(g, gdim);
301
1
  ccv_nnc_tensor_view_get_dim(dq, dqdim);
302
1
  ccv_nnc_tensor_view_get_dim(dk, dkdim);
303
1
  ccv_nnc_tensor_view_get_dim(dv, dvdim);
304
1
  if (q_nd == 3)
305
0
  {
306
0
    qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1;
307
0
    kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1;
308
0
    vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1;
309
0
    gdim[0] = gdim[1], gdim[1] = gdim[2], gdim[2] = 1;
310
0
    dqdim[0] = dqdim[1], dqdim[1] = dqdim[2], dqdim[2] = 1;
311
0
    dkdim[0] = dkdim[1], dkdim[1] = dkdim[2], dkdim[2] = 1;
312
0
    dvdim[0] = dvdim[1], dvdim[1] = dvdim[2], dvdim[2] = 1;
313
0
  }
314
1
  assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == gdim[0]);
315
1
  assert(qdim[2] == gdim[2]);
316
1
  assert(kdim[2] == vdim[2]);
317
1
  assert(qdim[2] % kdim[2] == 0);
318
1
  assert(qdim[2] >= kdim[2]);
319
1
  assert(qdim[3] == kdim[3]);
320
1
  assert(kdim[1] == vdim[1]);
321
1
  assert(gdim[1] == qdim[1]);
322
1
  assert(gdim[3] == vdim[3]);
323
1
  assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
324
1
  int qstride[CCV_NNC_MAX_DIM_ALLOC];
325
1
  int kstride[CCV_NNC_MAX_DIM_ALLOC];
326
1
  int vstride[CCV_NNC_MAX_DIM_ALLOC];
327
1
  int gstride[CCV_NNC_MAX_DIM_ALLOC];
328
1
  int dqstride[CCV_NNC_MAX_DIM_ALLOC];
329
1
  int dkstride[CCV_NNC_MAX_DIM_ALLOC];
330
1
  int dvstride[CCV_NNC_MAX_DIM_ALLOC];
331
1
  ccv_nnc_tensor_view_get_stride(q, qstride);
332
1
  ccv_nnc_tensor_view_get_stride(k, kstride);
333
1
  ccv_nnc_tensor_view_get_stride(v, vstride);
334
1
  ccv_nnc_tensor_view_get_stride(g, gstride);
335
1
  ccv_nnc_tensor_view_get_stride(dq, dqstride);
336
1
  ccv_nnc_tensor_view_get_stride(dk, dkstride);
337
1
  ccv_nnc_tensor_view_get_stride(dv, dvstride);
338
1
  if (q_nd == 3)
339
0
  {
340
0
    qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3];
341
0
    kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3];
342
0
    vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3];
343
0
    gstride[0] = gstride[1], gstride[1] = gstride[2], gstride[2] = gstride[3];
344
0
    dqstride[0] = dqstride[1], dqstride[1] = dqstride[2], dqstride[2] = dqstride[3];
345
0
    dkstride[0] = dkstride[1], dkstride[1] = dkstride[2], dkstride[2] = dkstride[3];
346
0
    dvstride[0] = dvstride[1], dvstride[1] = dvstride[2], dvstride[2] = dvstride[3];
347
0
  }
348
1
  int i[CCV_NNC_MAX_DIM + 2];
349
1
  float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * 2 * kdim[1], CCV_TENSOR_CPU_MEMORY);
350
1
  const float* const qp = q->data.f32;
351
1
  const float* const kp = k->data.f32;
352
1
  const float* const vp = v->data.f32;
353
1
  const float* const gp = g->data.f32;
354
1
  float* const dqp = dq->data.f32;
355
1
  float* const dkp = dk->data.f32;
356
1
  float* const dvp = dv->data.f32;
357
1
  const float scale = cmd.info.scaled_dot_product_attention.scale;
358
1
  const int is_causal = cmd.info.scaled_dot_product_attention.is_causal;
359
1
  const int h_h_k_ratio = qdim[2] / kdim[2];
360
33
  for (i[0] = 0; i[0] < qdim[0]; i[0]++)
361
32
  {
362
32
    const float* const qp0 = qp + i[0] * qstride[0];
363
32
    const float* const kp0 = kp + i[0] * kstride[0];
364
32
    const float* const vp0 = vp + i[0] * vstride[0];
365
32
    const float* const gp0 = gp + i[0] * gstride[0];
366
32
    float* const dqp0 = dqp + i[0] * dqstride[0];
367
32
    float* const dkp0 = dkp + i[0] * dkstride[0];
368
32
    float* const dvp0 = dvp + i[0] * dvstride[0];
369
288
    for (i[1] = 0; i[1] < qdim[2]; i[1]++)
370
256
    {
371
256
      const float* const qp1 = qp0 + i[1] * qstride[2];
372
256
      const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2];
373
256
      const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2];
374
256
      const float* const gp1 = gp0 + i[1] * gstride[2];
375
256
      float* const dqp1 = dqp0 + i[1] * dqstride[2];
376
256
      float* const dkp1 = dkp0 + (i[1] / h_h_k_ratio) * dkstride[2];
377
256
      float* const dvp1 = dvp0 + (i[1] / h_h_k_ratio) * dvstride[2];
378
      // Compute Q @ K^T
379
256
      int x, y, k;
380
33.0k
      for (x = 0; x < qdim[1]; x++)
381
32.7k
      {
382
32.7k
        float* const dqp2 = dqp1 + x * dqstride[1];
383
2.12M
        for (k = 0; k < qdim[3]; k++)
384
2.09M
          dqp2[k * dqstride[3]] = 0;
385
32.7k
      }
386
      // Only zero out when it is at 0-index.
387
256
      if (i[1] % h_h_k_ratio == 0)
388
33.0k
        for (y = 0; y < kdim[1]; y++)
389
32.7k
        {
390
32.7k
          float* const dkp2 = dkp1 + y * dkstride[1];
391
2.12M
          for (k = 0; k < qdim[3]; k++)
392
2.09M
            dkp2[k * dkstride[3]] = 0;
393
32.7k
        }
394
      // Only zero out when it is at 0-index.
395
256
      if (i[1] % h_h_k_ratio == 0)
396
33.0k
        for (y = 0; y < kdim[1]; y++)
397
32.7k
        {
398
32.7k
          float* const dvp2 = dvp1 + y * dvstride[1];
399
3.17M
          for (k = 0; k < vdim[3]; k++)
400
3.14M
            dvp2[k * dvstride[3]] = 0;
401
32.7k
        }
402
33.0k
      for (x = 0; x < qdim[1]; x++)
403
32.7k
      {
404
32.7k
        const float* const qp2 = qp1 + x * qstride[1];
405
32.7k
        const float* const gp2 = gp1 + x * gstride[1];
406
32.7k
        float* const qk0 = qk;
407
32.7k
        float* const qks0 = qk + kdim[1];
408
4.22M
        for (y = 0; y < kdim[1]; y++)
409
4.19M
        {
410
4.19M
          const float* const kp2 = kp1 + y * kstride[1];
411
4.19M
          float v = 0;
412
272M
          for (k = 0; k < qdim[3]; k++)
413
268M
            v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
414
4.19M
          qk0[y] = scale * v;
415
4.19M
        }
416
        // Compute softmax on qk.
417
32.7k
        if (is_causal)
418
0
        {
419
0
          const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0);
420
0
          for (y = x_end; y < kdim[1]; y++)
421
0
            qk0[y] = 0;
422
0
          double maxval = qk0[0];
423
0
          for (y = 1; y < x_end; y++)
424
0
            if (qk0[y] > maxval)
425
0
              maxval = qk0[y];
426
0
          double sumval = 0;
427
0
          for (y = 0; y < x_end; y++)
428
0
            sumval += (qk0[y] = expf(qk0[y] - maxval));
429
0
          sumval = 1.0 / sumval;
430
0
          for (y = 0; y < x_end; y++)
431
0
            qk0[y] *= sumval;
432
32.7k
        } else {
433
32.7k
          double maxval = qk0[0];
434
4.19M
          for (y = 1; y < kdim[1]; y++)
435
4.16M
            if (qk0[y] > maxval)
436
146k
              maxval = qk0[y];
437
32.7k
          double sumval = 0;
438
4.22M
          for (y = 0; y < kdim[1]; y++)
439
4.19M
            sumval += (qk0[y] = expf(qk0[y] - maxval));
440
32.7k
          sumval = 1.0 / sumval;
441
4.22M
          for (y = 0; y < kdim[1]; y++)
442
4.19M
            qk0[y] *= sumval;
443
32.7k
        }
444
4.22M
        for (y = 0; y < kdim[1]; y++)
445
4.19M
        {
446
4.19M
          float* const dvp2 = dvp1 + y * dvstride[1];
447
4.19M
          const float v = qk0[y];
448
406M
          for (k = 0; k < vdim[3]; k++)
449
402M
            dvp2[k * dvstride[3]] += v * gp2[k * gstride[3]];
450
4.19M
        }
451
32.7k
        double sumval = 0;
452
4.22M
        for (y = 0; y < kdim[1]; y++)
453
4.19M
        {
454
4.19M
          const float* const vp2 = vp1 + y * vstride[1];
455
4.19M
          float v = 0;
456
406M
          for (k = 0; k < vdim[3]; k++)
457
402M
            v += gp2[k * gstride[3]] * vp2[k * vstride[3]];
458
4.19M
          qks0[y] = v;
459
4.19M
          sumval += v * qk0[y];
460
4.19M
        }
461
4.22M
        for (y = 0; y < kdim[1]; y++)
462
4.19M
          qk0[y] = (qks0[y] - sumval) * qk0[y];
463
32.7k
        float* const dqp2 = dqp1 + x * dqstride[1];
464
4.22M
        for (y = 0; y < kdim[1]; y++)
465
4.19M
        {
466
4.19M
          const float* const kp2 = kp1 + y * kstride[1];
467
4.19M
          float* const dkp2 = dkp1 + y * dkstride[1];
468
4.19M
          const float v = scale * qk0[y];
469
272M
          for (k = 0; k < qdim[3]; k++)
470
268M
          {
471
268M
            dqp2[k * dqstride[3]] += v * kp2[k * kstride[3]];
472
268M
            dkp2[k * dkstride[3]] += v * qp2[k * qstride[3]];
473
268M
          }
474
4.19M
        }
475
32.7k
      }
476
256
    }
477
32
  }
478
1
  return CCV_NNC_EXEC_SUCCESS;
479
1
}
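
The backward kernel above recomputes p = \mathrm{softmax}(\alpha\, q k^\top) for each query row instead of reading a saved softmax (per the comment at line 261), then applies the standard softmax Jacobian. Writing g for the upstream gradient row and t_y = g^\top v_y (kept as qks0, the second half of the 2 * kdim[1] float workspace), the loops at lines 444-474 implement:

  \frac{\partial L}{\partial v_y} \mathrel{+}= p_y\, g, \qquad
  \frac{\partial L}{\partial s_y} = p_y \Big( t_y - \sum_j p_j\, t_j \Big), \qquad
  \frac{\partial L}{\partial q} \mathrel{+}= \alpha \frac{\partial L}{\partial s_y}\, k_y, \qquad
  \frac{\partial L}{\partial k_y} \mathrel{+}= \alpha \frac{\partial L}{\partial s_y}\, q

where sumval accumulates \sum_j p_j t_j and the scale \alpha is folded in at line 468. Note dk and dv are zeroed only once per group of query heads sharing a key/value head (the i[1] % h_h_k_ratio == 0 checks), so grouped-query gradients accumulate across the group.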
480
481
REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
482
1
{
483
1
  registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC;
484
1
  registry->tensor_datatypes = CCV_32F;
485
1
  registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
486
1
  registry->algorithms = 1;
487
1
  registry->exec = _ccv_nnc_scaled_dot_product_attention_forw;
488
1
}
489
490
REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
491
1
{
492
1
  registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC;
493
1
  registry->tensor_datatypes = CCV_32F;
494
1
  registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
495
1
  registry->algorithms = 1;
496
1
  registry->exec = _ccv_nnc_scaled_dot_product_attention_back;
497
1
}
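
Since these registrations only attach the reference kernels to the CPU backend table, invoking them goes through the generic command API. The following is a minimal, illustrative sketch, not code from this file: it assumes the CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal) convenience macro and the NHWC tensor helpers from ccv_nnc_easy.h, whose exact arity and argument order should be checked against the headers.

#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"

int main(void)
{
	// Hypothetical shapes: batch 2, sequence 8, 4 heads, head dim 16 (NHWC, float32).
	ccv_nnc_tensor_t* const q = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 8, 4, 16), 0);
	ccv_nnc_tensor_t* const k = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 8, 4, 16), 0);
	ccv_nnc_tensor_t* const v = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 8, 4, 16), 0);
	ccv_nnc_tensor_t* const c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 8, 4, 16), 0);
	// ... fill q, k, v with data ...
	// With only q, k, v as inputs, attn_mask / w / bias stay 0 and c is outputs[0].
	// scale = 1/sqrt(head dim) = 0.25; is_causal = 0 (macro arity is an assumption).
	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(0.25, 0), ccv_nnc_no_hint, 0,
		TENSOR_LIST(q, k, v), TENSOR_LIST(c), 0);
	ccv_nnc_tensor_free(q);
	ccv_nnc_tensor_free(k);
	ccv_nnc_tensor_free(v);
	ccv_nnc_tensor_free(c);
	return 0;
}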