File: nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
Warning: line 132, column 28: Array access (from variable 'amp2') results in a null pointer dereference
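Why the analyzer reports this: amp2 (file line 123) is the last link in the nullable chain amp -> amp0 -> amp1 -> amp2, every link of which is 0 when no attention mask is supplied, but the branch that dereferences amp2 (file line 124) tests attn_mask rather than amp2 itself. The chain starts at attn_mask->data.f32, which the analyzer treats as possibly null even when attn_mask is not, so it can construct a path where the attn_mask guard is taken while amp2 is still null. A minimal sketch of the same shape, with hypothetical names, that draws the identical class of report:

    #include <stddef.h>

    typedef struct {
        float* f32; /* may be null as far as the analyzer knows */
    } data_t;

    /* Sketch only: the guard tests the struct pointer, but the dereferenced
     * pointer comes from a field the analyzer treats as possibly null, so
     * the path mask != NULL && mask->f32 == NULL reaches a null dereference. */
    static float read_masked(const data_t* mask, int y)
    {
        const float* base = mask ? mask->f32 : NULL;
        const float* derived = base ? base + 4 : NULL; /* mirrors amp -> amp2 */
        if (mask)              /* mirrors `if (attn_mask)` at file line 124 */
            return derived[y]; /* mirrors the flagged access at file line 132 */
        return 0.0f;
    }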
  1  #include "ccv.h"
  2  #include "ccv_internal.h"
  3  #include "nnc/ccv_nnc.h"
  4  #include "nnc/ccv_nnc_easy.h"
  5  #include "nnc/ccv_nnc_internal.h"
  6  #ifdef USE_OPENMP
  7  #include <omp.h>
  8  #endif
  9  #ifdef USE_DISPATCH
 10  #include <dispatch/dispatch.h>
 11  #endif
 12
 13  // Shared methods.
 14  #include "../_ccv_nnc_cpu_ref.h"
 15
 16  static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
 17  {
 18    assert(input_size >= 3);
 19    assert(output_size >= 1);
 20    ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[0];
 21    ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[1];
 22    ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[2];
 23    ccv_nnc_tensor_view_t* const attn_mask = input_size > 3 ? (ccv_nnc_tensor_view_t*)inputs[3] : 0;
 24    ccv_nnc_tensor_view_t* const w = input_size > 4 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
 25    ccv_nnc_tensor_view_t* const bias = input_size > 5 ? (ccv_nnc_tensor_view_t*)inputs[5] : 0;
 26    if (bias) // bias always requires a weight matrix.
 27      { assert(w); }
 28    ccv_nnc_tensor_view_t* const c = (w) ? (ccv_nnc_tensor_view_t*)outputs[2] : (ccv_nnc_tensor_view_t*)outputs[0];
 29    const int q_nd = ccv_nnc_tensor_nd(q->info.dim);
 30    assert(q_nd == 3 || q_nd == 4);
 31    const int k_nd = ccv_nnc_tensor_nd(k->info.dim);
 32    assert(k_nd == 3 || k_nd == 4);
 33    const int v_nd = ccv_nnc_tensor_nd(v->info.dim);
 34    assert(v_nd == 3 || v_nd == 4);
 35    const int c_nd = ccv_nnc_tensor_nd(c->info.dim);
 36    assert(c_nd == 3 || c_nd == 4);
 37    assert(q_nd == k_nd && k_nd == v_nd && v_nd == c_nd);
 38    // Assuming this is float 32.
 39    int qdim[CCV_NNC_MAX_DIM_ALLOC];
 40    int kdim[CCV_NNC_MAX_DIM_ALLOC];
 41    int vdim[CCV_NNC_MAX_DIM_ALLOC];
 42    int cdim[CCV_NNC_MAX_DIM_ALLOC];
 43    int amdim[CCV_NNC_MAX_DIM_ALLOC];
 44    ccv_nnc_tensor_view_get_dim(q, qdim);
 45    ccv_nnc_tensor_view_get_dim(k, kdim);
 46    ccv_nnc_tensor_view_get_dim(v, vdim);
 47    ccv_nnc_tensor_view_get_dim(c, cdim);
 48    if (q_nd == 3)
 49    {
 50      qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1;
 51      kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1;
 52      vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1;
 53      cdim[0] = cdim[1], cdim[1] = cdim[2], cdim[2] = 1;
 54    }
 55    assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == cdim[0]);
 56    assert(qdim[2] == cdim[2]);
 57    assert(kdim[2] == vdim[2]);
 58    assert(qdim[2] % kdim[2] == 0);
 59    assert(qdim[2] >= kdim[2]);
 60    assert(qdim[3] == kdim[3]);
 61    assert(kdim[1] == vdim[1]);
 62    assert(cdim[1] == qdim[1]);
 63    assert(cdim[3] == vdim[3]);
 64    assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
 65    int qstride[CCV_NNC_MAX_DIM_ALLOC];
 66    int kstride[CCV_NNC_MAX_DIM_ALLOC];
 67    int vstride[CCV_NNC_MAX_DIM_ALLOC];
 68    int cstride[CCV_NNC_MAX_DIM_ALLOC];
 69    int amstride[CCV_NNC_MAX_DIM_ALLOC];
 70    ccv_nnc_tensor_view_get_stride(q, qstride);
 71    ccv_nnc_tensor_view_get_stride(k, kstride);
 72    ccv_nnc_tensor_view_get_stride(v, vstride);
 73    ccv_nnc_tensor_view_get_stride(c, cstride);
 74    if (q_nd == 3)
 75    {
 76      qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3];
 77      kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3];
 78      vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3];
 79      cstride[0] = cstride[1], cstride[1] = cstride[2], cstride[2] = cstride[3];
 80    }
 81    if (attn_mask)
 82    {
 83      ccv_nnc_tensor_view_get_dim(attn_mask, amdim);
 84      ccv_nnc_tensor_view_get_stride(attn_mask, amstride);
 85      assert(amdim[0] == qdim[0] || amdim[0] == 1);
 86      assert(amdim[1] == qdim[2] || amdim[1] == 1);
 87      assert(amdim[2] == qdim[1]);
 88      assert(amdim[3] == kdim[1]);
 89    }
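Note on the mask checks above (file lines 85-88): after ccv_nnc_tensor_view_get_dim, amdim is laid out as [batch, heads, q_len, k_len]. A mask may broadcast over the batch (amdim[0] == 1) and/or over the heads (amdim[1] == 1), while its last two extents must match the query length qdim[1] and the key length kdim[1] exactly. For example, a single [1, 1, q_len, k_len] causal-style mask shared by every batch entry and head is accepted; the broadcast is realized below by reusing the same amp0/amp1 pointer instead of advancing it.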
 90    int i[CCV_NNC_MAX_DIM + 2];
 91    float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * qdim[1] * kdim[1], CCV_TENSOR_CPU_MEMORY);
 92    const float* const qp = q->data.f32;
 93    const float* const kp = k->data.f32;
 94    const float* const vp = v->data.f32;
 95    const float* const amp = attn_mask ? attn_mask->data.f32 : 0;
 96    float* const cp = c->data.f32;
 97    const float scale = cmd.info.scaled_dot_product_attention.scale;
 98    const int is_causal = cmd.info.scaled_dot_product_attention.is_causal;
 99    const int h_h_k_ratio = qdim[2] / kdim[2];
100    assert(kdim[2] == vdim[2]);
101    assert(qdim[2] >= kdim[2]);
102    assert(qdim[2] % kdim[2] == 0);
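h_h_k_ratio implements grouped-query attention: with Hq = qdim[2] query heads and Hk = kdim[2] key/value heads, query head i reads key/value head i / (Hq / Hk) (file lines 113-114), so consecutive groups of Hq / Hk query heads share one K/V head; Hk == Hq recovers standard multi-head attention, and Hk == 1 is multi-query attention.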
103    for (i[0] = 0; i[0] < qdim[0]; i[0]++)
104    {
105      const float* const qp0 = qp + i[0] * qstride[0];
106      const float* const kp0 = kp + i[0] * kstride[0];
107      const float* const vp0 = vp + i[0] * vstride[0];
108      const float* const amp0 = amp && amdim[0] > 1 ? amp + i[0] * amstride[0] : amp;
109      float* const cp0 = cp + i[0] * cstride[0];
110      for (i[1] = 0; i[1] < qdim[2]; i[1]++)
111      {
112        const float* const qp1 = qp0 + i[1] * qstride[2];
113        const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2];
114        const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2];
115        const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0;
116        float* const cp1 = cp0 + i[1] * cstride[2];
117        // Compute Q @ K^T
118        parallel_for(x, qdim[1]) {
119          int y, k;
120          const float* const qp2 = qp1 + x * qstride[1];
121          float* const cp2 = cp1 + x * cstride[1];
122          float* const qk0 = qk + x * kdim[1];
123          const float* const amp2 = amp1 ? amp1 + x * amstride[2] : 0;
124          if (attn_mask)
125          {
126            for (y = 0; y < kdim[1]; y++)
127            {
128              const float* const kp2 = kp1 + y * kstride[1];
129              float v = 0;
130              for (k = 0; k < qdim[3]; k++)
131                v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
132              qk0[y] = scale * v + amp2[y * amstride[3]];
133            }
134          } else {
135            for (y = 0; y < kdim[1]; y++)
136            {
137              const float* const kp2 = kp1 + y * kstride[1];
138              float v = 0;
139              for (k = 0; k < qdim[3]; k++)
140                v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
141              qk0[y] = scale * v;
142            }
143          }
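The report is easiest to silence, without changing behavior, by guarding on the pointer the branch actually dereferences. A sketch, not the author's patch, assuming the mask tensor's data pointer is non-null whenever attn_mask is:

    /* file line 124, sketched fix: amp2 is non-null exactly when a mask was
     * supplied, and it is what the loop above indexes at file line 132. */
    if (amp2)

Alternatively, an assert(amp2); placed just inside the if (attn_mask) branch records the same invariant in a form the analyzer can consume.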
144          // Compute softmax on qk.
145          if (is_causal)
146          {
147            const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0);
148            for (y = x_end; y < kdim[1]; y++)
149              qk0[y] = 0;
150            double maxval = qk0[0];
151            for (y = 1; y < x_end; y++)
152              if (qk0[y] > maxval)
153                maxval = qk0[y];
154            double sumval = 0;
155            for (y = 0; y < x_end; y++)
156              sumval += (qk0[y] = expf(qk0[y] - maxval));
157            sumval = 1.0 / sumval;
158            for (y = 0; y < x_end; y++)
159              qk0[y] *= sumval;
160          } else {
161            double maxval = qk0[0];
162            for (y = 1; y < kdim[1]; y++)
163              if (qk0[y] > maxval)
164                maxval = qk0[y];
165            double sumval = 0;
166            for (y = 0; y < kdim[1]; y++)
167              sumval += (qk0[y] = expf(qk0[y] - maxval));
168            sumval = 1.0 / sumval;
169            for (y = 0; y < kdim[1]; y++)
170              qk0[y] *= sumval;
171          }
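Both branches above are the standard numerically stable softmax: for a row of logits z (qk0), the code computes

    softmax(z)[y] = exp(z[y] - max(z)) / sum_j exp(z[j] - max(z))

where subtracting the row maximum before expf keeps every exponential in (0, 1] so the sum cannot overflow. The causal branch restricts the row to the first x_end visible keys and zeroes the rest, leaving those positions with probability 0.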
172          for (k = 0; k < vdim[3]; k++)
173            cp2[k * cstride[3]] = 0;
174          for (y = 0; y < kdim[1]; y++)
175          {
176            const float* const vp2 = vp1 + y * vstride[1];
177            const float v = qk0[y];
178            for (k = 0; k < vdim[3]; k++)
179              cp2[k * cstride[3]] += v * vp2[k * vstride[3]];
180          }
181        } parallel_endfor
182      }
183    }
184    if (w)
185    {
186      const int num_heads = cdim[2];
187      ccv_nnc_tensor_view_t* const d = (ccv_nnc_tensor_view_t*)outputs[0];
188      const int w_nd = ccv_nnc_tensor_nd(w->info.dim);
189      assert(w_nd == 2);
190      assert(CCV_IS_TENSOR_CONTIGUOUS(w));
191      const int d_nd = ccv_nnc_tensor_nd(d->info.dim);
192      assert(d_nd == 3);
193      int ddim[CCV_NNC_MAX_DIM_ALLOC];
194      int dstride[CCV_NNC_MAX_DIM_ALLOC];
195      ccv_nnc_tensor_view_get_dim(d, ddim);
196      ccv_nnc_tensor_view_get_stride(d, dstride);
197      assert(ddim[2] == cdim[1]);
198      assert(ddim[3] == num_heads * cdim[3]);
199      assert(w->info.dim[1] == ddim[3]);
200      assert(w->info.dim[0] == ddim[3]);
201      float* const dp = d->data.f32;
202      const float* const wp = w->data.f32;
203      const float* const cp = c->data.f32;
204      if (bias)
205      {
206        assert(ccv_nnc_tensor_count(bias->info) == ddim[3]);
207        assert(CCV_IS_TENSOR_CONTIGUOUS(bias));
208        const float* const biasp = bias->data.f32;
209        for (i[0] = 0; i[0] < ddim[1]; i[0]++)
210        {
211          const float* const cp0 = cp + i[0] * cstride[0];
212          float* const dp0 = dp + i[0] * dstride[1];
213          parallel_for(y, ddim[2]) {
214            int x, j, k;
215            const float* const cp1 = cp0 + y * cstride[1];
216            float* const dp1 = dp0 + y * dstride[2];
217            for (x = 0; x < ddim[3]; x++)
218            {
219              const float* const wp0 = wp + x * ddim[3];
220              float v = biasp[x];
221              for (j = 0; j < num_heads; j++)
222              {
223                const float* const cp2 = cp1 + j * cstride[2];
224                for (k = 0; k < cdim[3]; k++)
225                  v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]];
226              }
227              dp1[x * dstride[3]] = v;
228            }
229          } parallel_endfor
230        }
231      } else {
232        for (i[0] = 0; i[0] < ddim[1]; i[0]++)
233        {
234          const float* const cp0 = cp + i[0] * cstride[0];
235          float* const dp0 = dp + i[0] * dstride[1];
236          parallel_for(y, ddim[2]) {
237            int x, j, k;
238            const float* const cp1 = cp0 + y * cstride[1];
239            float* const dp1 = dp0 + y * dstride[2];
240            for (x = 0; x < ddim[3]; x++)
241            {
242              const float* const wp0 = wp + x * ddim[3];
243              float v = 0;
244              for (j = 0; j < num_heads; j++)
245              {
246                const float* const cp2 = cp1 + j * cstride[2];
247                for (k = 0; k < cdim[3]; k++)
248                  v += wp0[j * cdim[3] + k] * cp2[k * cstride[3]];
249              }
250              dp1[x * dstride[3]] = v;
251            }
252          } parallel_endfor
253        }
254      }
255    }
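The optional block above (file lines 184-255) fuses an output projection into the kernel: the per-head attention results c are flattened across heads and pushed through a dense layer, d[x] = bias[x] + sum over j, k of wp[x * ddim[3] + j * cdim[3] + k] * c[j][k], with the bias-free variant differing only in the initial value of the accumulator. This is also why, when a weight matrix w is given, the raw attention result goes to outputs[2] and the projected result d to outputs[0] (file line 28).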
256    return CCV_NNC_EXEC_SUCCESS;
257  }
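For reference, a sketch of exercising this kernel through the public API; the CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD parameter list (scale, is_causal) is an assumption read off this file's cmd.info fields, not verified against the headers:

    // Sketch: one batch in the 3-dim [seq, heads, dim] form accepted above.
    ccv_nnc_tensor_t* const q = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 8, 64), 0);
    ccv_nnc_tensor_t* const k = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 8, 64), 0);
    ccv_nnc_tensor_t* const v = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 8, 64), 0);
    ccv_nnc_tensor_t* const o = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 128, 8, 64), 0);
    // ... fill q, k, v ...
    // scale = 1/sqrt(64) for 64-dim heads; is_causal = 1.
    ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(0.125, 1), ccv_nnc_no_hint, 0,
        TENSOR_LIST(q, k, v), TENSOR_LIST(o), 0);
    ccv_nnc_tensor_free(q); ccv_nnc_tensor_free(k); ccv_nnc_tensor_free(v); ccv_nnc_tensor_free(o);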
258
259  static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
260  {
261    // Assuming no saved_softmax, we need to recompute from q, k, v.
262    // We cannot do this with masks (yet).
263    assert(input_size >= 6);
264    ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
265    ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[3];
266    ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[4];
267    ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[5];
268    ccv_nnc_tensor_view_t* const dq = (ccv_nnc_tensor_view_t*)outputs[0];
269    ccv_nnc_tensor_view_t* const dk = (ccv_nnc_tensor_view_t*)outputs[1];
270    ccv_nnc_tensor_view_t* const dv = (ccv_nnc_tensor_view_t*)outputs[2];
271    const int q_nd = ccv_nnc_tensor_nd(q->info.dim);
272    assert(q_nd == 3 || q_nd == 4);
273    const int k_nd = ccv_nnc_tensor_nd(k->info.dim);
274    assert(k_nd == 3 || k_nd == 4);
275    const int v_nd = ccv_nnc_tensor_nd(v->info.dim);
276    assert(v_nd == 3 || v_nd == 4);
277    const int g_nd = ccv_nnc_tensor_nd(g->info.dim);
278    assert(g_nd == 3 || g_nd == 4);
279    const int dq_nd = ccv_nnc_tensor_nd(dq->info.dim);
280    assert(dq_nd == 3 || dq_nd == 4);
281    assert(dq_nd == q_nd);
282    const int dk_nd = ccv_nnc_tensor_nd(dk->info.dim);
283    assert(dk_nd == 3 || dk_nd == 4);
284    assert(dk_nd == k_nd);
285    const int dv_nd = ccv_nnc_tensor_nd(dv->info.dim);
286    assert(dv_nd == 3 || dv_nd == 4);
287    assert(dv_nd == v_nd);
288    assert(q_nd == k_nd && k_nd == v_nd && v_nd == g_nd);
289    // Assuming this is float 32.
290    int qdim[CCV_NNC_MAX_DIM_ALLOC];
291    int kdim[CCV_NNC_MAX_DIM_ALLOC];
292    int vdim[CCV_NNC_MAX_DIM_ALLOC];
293    int gdim[CCV_NNC_MAX_DIM_ALLOC];
294    int dqdim[CCV_NNC_MAX_DIM_ALLOC];
295    int dkdim[CCV_NNC_MAX_DIM_ALLOC];
296    int dvdim[CCV_NNC_MAX_DIM_ALLOC];
297    ccv_nnc_tensor_view_get_dim(q, qdim);
298    ccv_nnc_tensor_view_get_dim(k, kdim);
299    ccv_nnc_tensor_view_get_dim(v, vdim);
300    ccv_nnc_tensor_view_get_dim(g, gdim);
301    ccv_nnc_tensor_view_get_dim(dq, dqdim);
302    ccv_nnc_tensor_view_get_dim(dk, dkdim);
303    ccv_nnc_tensor_view_get_dim(dv, dvdim);
304    if (q_nd == 3)
305    {
306      qdim[0] = qdim[1], qdim[1] = qdim[2], qdim[2] = 1;
307      kdim[0] = kdim[1], kdim[1] = kdim[2], kdim[2] = 1;
308      vdim[0] = vdim[1], vdim[1] = vdim[2], vdim[2] = 1;
309      gdim[0] = gdim[1], gdim[1] = gdim[2], gdim[2] = 1;
310      dqdim[0] = dqdim[1], dqdim[1] = dqdim[2], dqdim[2] = 1;
311      dkdim[0] = dkdim[1], dkdim[1] = dkdim[2], dkdim[2] = 1;
312      dvdim[0] = dvdim[1], dvdim[1] = dvdim[2], dvdim[2] = 1;
313    }
314    assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == gdim[0]);
315    assert(qdim[2] == gdim[2]);
316    assert(kdim[2] == vdim[2]);
317    assert(qdim[2] % kdim[2] == 0);
318    assert(qdim[2] >= kdim[2]);
319    assert(qdim[3] == kdim[3]);
320    assert(kdim[1] == vdim[1]);
321    assert(gdim[1] == qdim[1]);
322    assert(gdim[3] == vdim[3]);
323    assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
324    int qstride[CCV_NNC_MAX_DIM_ALLOC];
325    int kstride[CCV_NNC_MAX_DIM_ALLOC];
326    int vstride[CCV_NNC_MAX_DIM_ALLOC];
327    int gstride[CCV_NNC_MAX_DIM_ALLOC];
328    int dqstride[CCV_NNC_MAX_DIM_ALLOC];
329    int dkstride[CCV_NNC_MAX_DIM_ALLOC];
330    int dvstride[CCV_NNC_MAX_DIM_ALLOC];
331    ccv_nnc_tensor_view_get_stride(q, qstride);
332    ccv_nnc_tensor_view_get_stride(k, kstride);
333    ccv_nnc_tensor_view_get_stride(v, vstride);
334    ccv_nnc_tensor_view_get_stride(g, gstride);
335    ccv_nnc_tensor_view_get_stride(dq, dqstride);
336    ccv_nnc_tensor_view_get_stride(dk, dkstride);
337    ccv_nnc_tensor_view_get_stride(dv, dvstride);
338    if (q_nd == 3)
339    {
340      qstride[0] = qstride[1], qstride[1] = qstride[2], qstride[2] = qstride[3];
341      kstride[0] = kstride[1], kstride[1] = kstride[2], kstride[2] = kstride[3];
342      vstride[0] = vstride[1], vstride[1] = vstride[2], vstride[2] = vstride[3];
343      gstride[0] = gstride[1], gstride[1] = gstride[2], gstride[2] = gstride[3];
344      dqstride[0] = dqstride[1], dqstride[1] = dqstride[2], dqstride[2] = dqstride[3];
345      dkstride[0] = dkstride[1], dkstride[1] = dkstride[2], dkstride[2] = dkstride[3];
346      dvstride[0] = dvstride[1], dvstride[1] = dvstride[2], dvstride[2] = dvstride[3];
347    }
348    int i[CCV_NNC_MAX_DIM + 2];
349    float* qk = ccv_nnc_stream_context_get_workspace(stream_context, sizeof(float) * 2 * kdim[1], CCV_TENSOR_CPU_MEMORY);
350    const float* const qp = q->data.f32;
351    const float* const kp = k->data.f32;
352    const float* const vp = v->data.f32;
353    const float* const gp = g->data.f32;
354    float* const dqp = dq->data.f32;
355    float* const dkp = dk->data.f32;
356    float* const dvp = dv->data.f32;
357    const float scale = cmd.info.scaled_dot_product_attention.scale;
358    const int is_causal = cmd.info.scaled_dot_product_attention.is_causal;
359    const int h_h_k_ratio = qdim[2] / kdim[2];
360    for (i[0] = 0; i[0] < qdim[0]; i[0]++)
361    {
362      const float* const qp0 = qp + i[0] * qstride[0];
363      const float* const kp0 = kp + i[0] * kstride[0];
364      const float* const vp0 = vp + i[0] * vstride[0];
365      const float* const gp0 = gp + i[0] * gstride[0];
366      float* const dqp0 = dqp + i[0] * dqstride[0];
367      float* const dkp0 = dkp + i[0] * dkstride[0];
368      float* const dvp0 = dvp + i[0] * dvstride[0];
369      for (i[1] = 0; i[1] < qdim[2]; i[1]++)
370      {
371        const float* const qp1 = qp0 + i[1] * qstride[2];
372        const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2];
373        const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2];
374        const float* const gp1 = gp0 + i[1] * gstride[2];
375        float* const dqp1 = dqp0 + i[1] * dqstride[2];
376        float* const dkp1 = dkp0 + (i[1] / h_h_k_ratio) * dkstride[2];
377        float* const dvp1 = dvp0 + (i[1] / h_h_k_ratio) * dvstride[2];
378        // Compute Q @ K^T
379        int x, y, k;
380        for (x = 0; x < qdim[1]; x++)
381        {
382          float* const dqp2 = dqp1 + x * dqstride[1];
383          for (k = 0; k < qdim[3]; k++)
384            dqp2[k * dqstride[3]] = 0;
385        }
386        // Only zero out when it is at 0-index.
387        if (i[1] % h_h_k_ratio == 0)
388          for (y = 0; y < kdim[1]; y++)
389          {
390            float* const dkp2 = dkp1 + y * dkstride[1];
391            for (k = 0; k < qdim[3]; k++)
392              dkp2[k * dkstride[3]] = 0;
393          }
394        // Only zero out when it is at 0-index.
395        if (i[1] % h_h_k_ratio == 0)
396          for (y = 0; y < kdim[1]; y++)
397          {
398            float* const dvp2 = dvp1 + y * dvstride[1];
399            for (k = 0; k < vdim[3]; k++)
400              dvp2[k * dvstride[3]] = 0;
401          }
402        for (x = 0; x < qdim[1]; x++)
403        {
404          const float* const qp2 = qp1 + x * qstride[1];
405          const float* const gp2 = gp1 + x * gstride[1];
406          float* const qk0 = qk;
407          float* const qks0 = qk + kdim[1];
408          for (y = 0; y < kdim[1]; y++)
409          {
410            const float* const kp2 = kp1 + y * kstride[1];
411            float v = 0;
412            for (k = 0; k < qdim[3]; k++)
413              v += qp2[k * qstride[3]] * kp2[k * kstride[3]];
414            qk0[y] = scale * v;
415          }
416          // Compute softmax on qk.
417          if (is_causal)
418          {
419            const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0);
420            for (y = x_end; y < kdim[1]; y++)
421              qk0[y] = 0;
422            double maxval = qk0[0];
423            for (y = 1; y < x_end; y++)
424              if (qk0[y] > maxval)
425                maxval = qk0[y];
426            double sumval = 0;
427            for (y = 0; y < x_end; y++)
428              sumval += (qk0[y] = expf(qk0[y] - maxval));
429            sumval = 1.0 / sumval;
430            for (y = 0; y < x_end; y++)
431              qk0[y] *= sumval;
432          } else {
433            double maxval = qk0[0];
434            for (y = 1; y < kdim[1]; y++)
435              if (qk0[y] > maxval)
436                maxval = qk0[y];
437            double sumval = 0;
438            for (y = 0; y < kdim[1]; y++)
439              sumval += (qk0[y] = expf(qk0[y] - maxval));
440            sumval = 1.0 / sumval;
441            for (y = 0; y < kdim[1]; y++)
442              qk0[y] *= sumval;
443          }
444          for (y = 0; y < kdim[1]; y++)
445          {
446            float* const dvp2 = dvp1 + y * dvstride[1];
447            const float v = qk0[y];
448            for (k = 0; k < vdim[3]; k++)
449              dvp2[k * dvstride[3]] += v * gp2[k * gstride[3]];
450          }
451          double sumval = 0;
452          for (y = 0; y < kdim[1]; y++)
453          {
454            const float* const vp2 = vp1 + y * vstride[1];
455            float v = 0;
456            for (k = 0; k < vdim[3]; k++)
457              v += gp2[k * gstride[3]] * vp2[k * vstride[3]];
458            qks0[y] = v;
459            sumval += v * qk0[y];
460          }
461          for (y = 0; y < kdim[1]; y++)
462            qk0[y] = (qks0[y] - sumval) * qk0[y];
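File lines 451-462 above are the softmax backward fused into the pass: with P the softmax row (qk0 after the forward recompute) and dP[y] = sum_k g[k] * V[y][k] (stored in qks0), the Jacobian-vector product of softmax is

    dS[y] = P[y] * (dP[y] - sum_j P[j] * dP[j])

which is exactly qk0[y] = (qks0[y] - sumval) * qk0[y]; the scale factor is folded in at file line 468 when dS is propagated into dQ and dK.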
463          float* const dqp2 = dqp1 + x * dqstride[1];
464          for (y = 0; y < kdim[1]; y++)
465          {
466            const float* const kp2 = kp1 + y * kstride[1];
467            float* const dkp2 = dkp1 + y * dkstride[1];
468            const float v = scale * qk0[y];
469            for (k = 0; k < qdim[3]; k++)
470            {
471              dqp2[k * dqstride[3]] += v * kp2[k * kstride[3]];
472              dkp2[k * dkstride[3]] += v * qp2[k * qstride[3]];
473            }
474          }
475        }
476      }
477    }
478    return CCV_NNC_EXEC_SUCCESS;
479  }
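Because nothing from the forward pass is saved (see the comment at file lines 261-262), the backward kernel recomputes Q @ K^T and its softmax from q, k, v before accumulating gradients; that trades a second O(q_len * k_len * dim) pass for not storing the [q_len, k_len] attention matrix per head, and it is also why this reference backward cannot honor an attention mask yet. The final loop realizes dQ = scale * dS @ K and dK = scale * dS^T @ Q row by row (file lines 464-474), while dV accumulates P^T @ g (file lines 444-450).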
480
481  REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
482  {
483    registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC;
484    registry->tensor_datatypes = CCV_32F;
485    registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
486    registry->algorithms = 1;
487    registry->exec = _ccv_nnc_scaled_dot_product_attention_forw;
488  }
489
490  REGISTER_COMMAND_BACKEND(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
491  {
492    registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC;
493    registry->tensor_datatypes = CCV_32F;
494    registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
495    registry->algorithms = 1;
496    registry->exec = _ccv_nnc_scaled_dot_product_attention_back;
497  }