/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/blas/cpu_sys/_ccv_nnc_gemm_cpu_sys.c
Line | Count | Source |
1 | | #include "ccv.h" |
2 | | #include "ccv_internal.h" |
3 | | #include "nnc/ccv_nnc.h" |
4 | | #include "nnc/ccv_nnc_easy.h" |
5 | | #include "nnc/ccv_nnc_internal.h" |
6 | | #include "../_ccv_nnc_gemm_cpu_opt.h" |
7 | | #if HAVE_ACCELERATE_FRAMEWORK |
8 | | #include <Accelerate/Accelerate.h> |
9 | | #elif HAVE_CBLAS |
10 | | #include <cblas.h> |
11 | | #endif |
12 | | |
/*
 * NOTE: this text appears to be a coverage listing ("Line | Count | Source");
 * numbers fused into some expressions below (e.g. "i++764") look like
 * execution counts from the report, not part of the C source.
 *
 * _ccv_nnc_gbmm_and_bias: batched matrix multiply with a bias term, written
 * into b through column-major cblas_sgemm calls.
 *
 * b_nd <= 3: for each of the b_batch_size batches two sgemm calls are issued.
 *   1. beta = 0.0, inner dimension 1: b is overwritten with bias * ones,
 *      i.e. the bias vector is replicated b_rows times.
 *   2. beta = 1.0: the product of w (first operand, lda_inc) and a (second
 *      operand, ldb_inc) is accumulated on top of that.
 * b_nd > 3: the leading dimension of b is peeled off and the function calls
 *   itself once per index. a / w / bias are advanced by their own leading
 *   stride (and their nd/dim/stride shifted by one) only when that operand
 *   also has more than 3 dimensions; otherwise they are passed unchanged.
 *
 * ones must hold at least b_rows floats (it is read as a 1 x b_rows matrix).
 * Nothing is returned; the result is stored through b.
 *
 * NOTE(review): the asserts allow an operand's leading dim to be 1 while
 * dim > 1, yet the recursion still offsets that operand by i * stride[0].
 * Confirm callers never hit that case or that the stride is 0 there.
 */
13 | | static inline void _ccv_nnc_gbmm_and_bias(const float* const ones, const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, const float* const bias, const int bias_nd, const int* const biasdim, const int* const biasstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int transa, const int transb, const int lda_inc, const int ldb_inc, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int bias_rows_inc, const int b_rows_inc) |
14 | 752 | { |
15 | 752 | int i; |
// Base case: at most 3 dimensions left, loop over the flat batch.
16 | 752 | if (b_nd <= 3) |
17 | 750 | { |
18 | 1.51k | for (i = 0; i < b_batch_size; i++764 ) |
19 | 764 | { |
// b = bias * ones (K = 1, beta = 0.0): seeds every row of b with the bias.
20 | 764 | cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, b_cols, b_rows, 1, 1.0, bias + i * bias_batch_inc, bias_rows_inc, ones, 1, 0.0, b + i * b_batch_inc, b_rows_inc); |
// b += op(w) * op(a) (beta = 1.0 keeps the bias written above).
21 | 764 | cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w + i * w_batch_inc, lda_inc, a + i * a_batch_inc, ldb_inc, 1.0, b + i * b_batch_inc, b_rows_inc); |
22 | 764 | } |
23 | 750 | return; |
24 | 750 | } |
// Recursive case: iterate over b's leading dimension.
25 | 2 | const int dim = bdim[0]; |
// Operands with more than 3 dims must either match b's leading dim or be 1.
26 | 2 | if (a_nd > 3) |
27 | 2 | { assert(adim[0] == 1 || dim == adim[0]); } |
28 | 2 | if (w_nd > 3) |
29 | 1 | { assert(wdim[0] == 1 || dim == wdim[0]); } |
30 | 2 | if (bias_nd > 3) |
31 | 0 | { assert(biasdim[0] == 1 || dim == biasdim[0]); } |
32 | 6 | for (i = 0; 2 i < dim; i++4 ) |
33 | 4 | { |
34 | 4 | _ccv_nnc_gbmm_and_bias(ones, |
35 | 4 | a_nd > 3 ? a + i * astride[0] : a0 , a_nd > 3 ? a_nd - 1 : a_nd0 , a_nd > 3 ? adim + 1 : adim0 , a_nd > 3 ? astride + 1 : astride0 , |
36 | 4 | w_nd > 3 ? w + i * wstride[0]2 : w2 , w_nd > 3 ? w_nd - 12 : w_nd2 , w_nd > 3 ? wdim + 12 : wdim2 , w_nd > 3 ? wstride + 12 : wstride2 , |
37 | 4 | bias_nd > 3 ? bias + i * biasstride[0]0 : bias, bias_nd > 3 ? bias_nd - 10 : bias_nd, bias_nd > 3 ? biasdim + 10 : biasdim, bias_nd > 3 ? biasstride + 10 : biasstride, |
38 | 4 | b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, bias_rows_inc, b_rows_inc); |
39 | 4 | } |
40 | 2 | } |
41 | | |
/*
 * _ccv_nnc_gbmm: batched matrix multiply without bias, written into b
 * through column-major cblas_sgemm calls.
 *
 * b_nd <= 3: one sgemm per batch (b_batch_size of them) with beta = 0.0, so
 *   b is overwritten with the product of w (first operand, lda_inc) and
 *   a (second operand, ldb_inc).
 * b_nd > 3: the leading dimension of b is peeled off and the function calls
 *   itself once per index. a / w are advanced by their own leading stride
 *   (and their nd/dim/stride shifted by one) only when that operand also has
 *   more than 3 dimensions; otherwise they are passed unchanged.
 *
 * Nothing is returned; the result is stored through b.
 *
 * NOTE(review): the asserts allow an operand's leading dim to be 1 while
 * dim > 1, yet the recursion still offsets that operand by i * stride[0].
 * Confirm callers never hit that case or that the stride is 0 there.
 */
42 | | static inline void _ccv_nnc_gbmm(const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int transa, const int transb, const int lda_inc, const int ldb_inc, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int b_rows_inc) |
43 | 1.53k | { |
44 | 1.53k | int i; |
// Base case: at most 3 dimensions left, loop over the flat batch.
45 | 1.53k | if (b_nd <= 3) |
46 | 1.53k | { |
47 | 3.08k | for (i = 0; i < b_batch_size; i++1.55k ) |
// b = op(w) * op(a); beta = 0.0 discards whatever b held before.
48 | 1.55k | cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w + i * w_batch_inc, lda_inc, a + i * a_batch_inc, ldb_inc, 0.0, b + i * b_batch_inc, b_rows_inc); |
49 | 1.53k | return; |
50 | 1.53k | } |
// Recursive case: iterate over b's leading dimension.
51 | 2 | const int dim = bdim[0]; |
// Operands with more than 3 dims must either match b's leading dim or be 1.
52 | 2 | if (a_nd > 3) |
53 | 2 | { assert(adim[0] == 1 || dim == adim[0]); } |
54 | 2 | if (w_nd > 3) |
55 | 1 | { assert(wdim[0] == 1 || dim == wdim[0]); } |
56 | 6 | for (i = 0; 2 i < dim; i++4 ) |
57 | 4 | { |
58 | 4 | _ccv_nnc_gbmm( |
59 | 4 | a_nd > 3 ? a + i * astride[0] : a0 , a_nd > 3 ? a_nd - 1 : a_nd0 , a_nd > 3 ? adim + 1 : adim0 , a_nd > 3 ? astride + 1 : astride0 , |
60 | 4 | w_nd > 3 ? w + i * wstride[0]2 : w2 , w_nd > 3 ? w_nd - 12 : w_nd2 , w_nd > 3 ? wdim + 12 : wdim2 , w_nd > 3 ? wstride + 12 : wstride2 , |
61 | 4 | b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, b_rows_inc); |
62 | 4 | } |
63 | 2 | } |
64 | | |
/*
 * _ccv_nnc_gemm_forw_cpu_sys: forward matrix multiply of a and w into b,
 * optionally adding bias, using cblas_sgemm.
 *
 * The body is only compiled when HAVE_CBLAS or HAVE_ACCELERATE_FRAMEWORK is
 * defined; otherwise the function returns CCV_NNC_EXEC_INVALID.
 *
 * transpose_a / transpose_b: forwarded to ccv_nnc_tensor_get_matrix_params
 *   and ccv_nnc_is_matrix_transpose for a and w respectively.
 * a, w: input tensors (only data.f32 is read).
 * bias: optional, may be NULL.
 * b: output tensor (written through data.f32).
 *
 * Shapes are asserted so that an (a_rows x a_cols) matrix times a
 * (w_rows x w_cols) matrix yields (b_rows x b_cols). w and bias may have a
 * batch size of 1, in which case their batch increment is forced to 0 so the
 * same data is reused for every batch of b.
 *
 * Returns CCV_NNC_EXEC_SUCCESS once the BLAS calls have been issued.
 */
65 | | int _ccv_nnc_gemm_forw_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b) |
66 | 2.28k | { |
67 | 2.28k | #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK) |
68 | 2.28k | assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // If bias is given, at least one of dim[1], dim[2], dim[3] must be 0. |
69 | 2.28k | int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc; |
70 | 2.28k | int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc; |
71 | 2.28k | int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc; |
72 | 2.28k | const static int no_transpose[2] = {}; |
// Stride is passed only for tensor views; plain tensors pass 0 instead.
73 | 2.28k | ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride5 : 02.27k , a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc); |
74 | 2.28k | ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride3 : 02.28k , w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc); |
75 | 2.28k | ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->stride9 : 02.27k , b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc); |
// NOTE(review): the strict equality assert below makes the "|| a_batch_size == 1"
// alternative of the following assert, and the a_batch_inc = 0 branch after it,
// unreachable when asserts are enabled (this listing shows count 0 for that
// branch). Confirm whether broadcasting a over the batch is meant to be allowed.
76 | 2.28k | assert(a_batch_size == b_batch_size); |
77 | 2.28k | assert(a_batch_size == b_batch_size || a_batch_size == 1); |
78 | 2.28k | if (a_batch_size == 1 && b_batch_size > 12.27k ) |
79 | 0 | a_batch_inc = 0; |
80 | 2.28k | assert(w_batch_size == a_batch_size || w_batch_size == 1); |
// A single w is shared by every batch: do not advance it between batches.
81 | 2.28k | if (w_batch_size == 1 && b_batch_size > 12.27k ) |
82 | 4 | w_batch_inc = 0; |
83 | 2.28k | assert(a_rows == b_rows); |
84 | 2.28k | assert(a_cols == w_rows); |
85 | 2.28k | assert(w_cols == b_cols); |
86 | 2.28k | const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a); |
87 | 2.28k | const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b); |
// For each tensor pick the view's own stride, or derive one from its dims.
88 | 2.28k | int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
89 | 2.28k | int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
90 | 2.28k | int bstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
91 | 2.28k | const int* astride; |
92 | 2.28k | if (CCV_IS_TENSOR_VIEW(a)) |
93 | 5 | astride = a->stride; |
94 | 2.27k | else { |
95 | 2.27k | ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim); |
96 | 2.27k | astride = astride_from_dim; |
97 | 2.27k | } |
98 | 2.28k | const int* wstride; |
99 | 2.28k | if (CCV_IS_TENSOR_VIEW(w)) |
100 | 3 | wstride = w->stride; |
101 | 2.28k | else { |
102 | 2.28k | ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim); |
103 | 2.28k | wstride = wstride_from_dim; |
104 | 2.28k | } |
105 | 2.28k | const int* bstride; |
106 | 2.28k | if (CCV_IS_TENSOR_VIEW(b)) |
107 | 9 | bstride = b->stride; |
108 | 2.27k | else { |
109 | 2.27k | ccv_nnc_tensor_get_stride(b->info.dim, bstride_from_dim); |
110 | 2.27k | bstride = bstride_from_dim; |
111 | 2.27k | } |
// The helpers pass w as sgemm's first operand and a as its second, so
// transa / lda_inc are derived from w and transb / ldb_inc from a.
112 | 2.28k | const int transa = is_transpose_w ? CblasTrans2.27k : CblasNoTrans10 ; |
113 | 2.28k | const int transb = is_transpose_a ? CblasTrans2 : CblasNoTrans2.28k ; |
114 | 2.28k | const int lda_inc = is_transpose_w ? w_cols_inc2.27k : w_rows_inc10 ; |
115 | 2.28k | const int ldb_inc = is_transpose_a ? a_cols_inc2 : a_rows_inc2.28k ; |
116 | 2.28k | if (bias) |
117 | 748 | { |
// Vector of b_rows ones, used by the helper to replicate bias over the rows.
// NOTE(review): the ccmalloc result is used without a NULL check.
118 | 748 | float* const ones = (float*)ccmalloc(sizeof(float) * b_rows); |
119 | 748 | int i; |
120 | 1.55k | for (i = 0; i < b_rows; i++802 ) |
121 | 802 | ones[i] = 1; |
122 | 748 | int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc; |
123 | 748 | ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride0 : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc); |
124 | 748 | assert(bias_batch_size == b_batch_size || bias_batch_size == 1); |
// A single bias is shared by every batch.
125 | 748 | if (bias_batch_size == 1 && b_batch_size > 1747 ) |
126 | 3 | bias_batch_inc = 0; |
127 | 748 | assert(bias_cols == b_cols); |
128 | 748 | const int* biasstride; |
129 | 748 | int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC]; |
130 | 748 | if (CCV_IS_TENSOR_VIEW(bias)) |
131 | 0 | biasstride = bias->stride; |
132 | 748 | else { |
133 | 748 | ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim); |
134 | 748 | biasstride = biasstride_from_dim; |
135 | 748 | } |
136 | 748 | _ccv_nnc_gbmm_and_bias(ones, a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, bias_rows_inc, b_rows_inc); |
137 | 748 | ccfree(ones); |
138 | 1.53k | } else { |
139 | 1.53k | _ccv_nnc_gbmm(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, b_rows_inc); |
140 | 1.53k | } |
141 | 2.28k | return CCV_NNC_EXEC_SUCCESS; |
142 | | #else |
143 | | return CCV_NNC_EXEC_INVALID; |
144 | | #endif |
145 | 2.28k | } |
146 | | |
/*
 * _ccv_nnc_gemm_back_cpu_sys: backward pass of the matrix multiply, using
 * cblas_sgemm.
 *
 * The body is only compiled when HAVE_CBLAS or HAVE_ACCELERATE_FRAMEWORK is
 * defined; otherwise the function returns CCV_NNC_EXEC_INVALID.
 *
 * g: incoming gradient (always read).
 * a: forward input, dereferenced only when dw is non-NULL.
 * w: forward weights, dereferenced only when h is non-NULL.
 * dw, bias, h: optional outputs; each section is skipped when its pointer
 *   is NULL.
 * flags: not referenced anywhere in this function.
 *
 * Each of the three sections follows the same pattern:
 *   - output batch size equals g's (and is > 1): one sgemm per batch with
 *     beta = 0.0, each writing its own slice of the output;
 *   - otherwise: the first batch is written with beta = 0.0 and the
 *     remaining batches of g are accumulated into the same output with
 *     beta = 1.0.
 * Existing contents of bias / dw / h are therefore overwritten, not added to.
 *
 * Returns CCV_NNC_EXEC_SUCCESS once the BLAS calls have been issued.
 */
147 | | int _ccv_nnc_gemm_back_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags) |
148 | 4.81k | { |
149 | 4.81k | #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK) |
150 | | // inputs: gradient, forw prop input, [w] |
151 | | // outputs: [output gradient], weight updates, bias updates |
152 | 4.81k | assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // If bias is given, at least one of dim[1], dim[2], dim[3] must be 0. |
153 | 4.81k | int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc; |
154 | 4.81k | const static int no_transpose[2] = {}; |
155 | 4.81k | ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->stride0 : 0, g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc); |
156 | 4.81k | int i; |
// Section 1: bias gradient = g multiplied by a vector of ones.
157 | 4.81k | if (bias) |
158 | 3.01k | { |
159 | 3.01k | int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc; |
160 | 3.01k | ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride0 : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc); |
161 | 3.01k | assert(bias_cols == g_cols); |
162 | 3.01k | assert(bias_batch_size == 1 || bias_batch_size == g_batch_size); |
163 | 3.01k | if (bias_batch_size == 1 && g_batch_size > 13.01k ) |
164 | 3 | bias_batch_inc = 0; |
// ones holds g_rows floats and is read as a (g_rows x bias_rows) matrix below.
// NOTE(review): that is only in bounds if bias_rows == 1 - confirm.
// NOTE(review): the ccmalloc result is used without a NULL check.
165 | 3.01k | float* const ones = (float*)ccmalloc(sizeof(float) * g_rows); |
166 | 6.12k | for (i = 0; i < g_rows; i++3.10k ) |
167 | 3.10k | ones[i] = 1; |
168 | 3.01k | if (g_batch_size > 1 && bias_batch_size == g_batch_size5 ) |
169 | 2 | { |
// Per-batch bias: every batch overwrites its own slice (beta = 0.0).
170 | 6 | for (i = 0; i < g_batch_size; i++4 ) |
171 | 4 | cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 0.0, bias->data.f32 + i * bias_batch_inc, bias_rows_inc); |
172 | 3.01k | } else { |
// Shared bias: first batch overwrites, the rest accumulate (beta = 1.0).
173 | 3.01k | cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32, g_rows_inc, ones, g_rows, 0.0, bias->data.f32, bias_rows_inc); |
174 | | // We cannot use strided batched alternative because on write, the data could race to the same position |
175 | 3.01k | for (i = 1; i < g_batch_size; i++3 ) |
176 | 3 | cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 1.0, bias->data.f32, bias_rows_inc); |
177 | 3.01k | } |
178 | 3.01k | ccfree(ones); |
179 | 3.01k | } |
// Section 2: weight gradient dw, computed from a and g.
180 | 4.81k | if (dw) |
181 | 4.70k | { |
182 | 4.70k | const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a); |
183 | 4.70k | const int is_transpose_w = ccv_nnc_is_matrix_transpose(dw->info, transpose_b); |
184 | 4.70k | int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc; |
185 | 4.70k | int dw_batch_size, dw_rows, dw_cols, dw_batch_inc, dw_rows_inc, dw_cols_inc; |
186 | 4.70k | ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride0 : 0, a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc); |
187 | 4.70k | ccv_nnc_tensor_get_matrix_params(dw->info, CCV_IS_TENSOR_VIEW(dw) ? dw->stride0 : 0, dw->info.dim, transpose_b, &dw_batch_size, &dw_rows, &dw_cols, &dw_batch_inc, &dw_rows_inc, &dw_cols_inc); |
188 | 4.70k | assert(a_rows == g_rows); |
189 | 4.70k | assert(a_cols == dw_rows); |
190 | 4.70k | assert(dw_cols == g_cols); |
191 | 4.70k | assert(a_batch_size == g_batch_size || a_batch_size == 1); |
192 | 4.70k | if (a_batch_size == 1 && g_batch_size > 14.70k ) |
193 | 0 | a_batch_inc = 0; |
194 | 4.70k | assert(dw_batch_size == g_batch_size || dw_batch_size == 1); |
195 | 4.70k | if (dw_batch_size == 1 && g_batch_size > 14.70k ) |
196 | 3 | dw_batch_inc = 0; |
197 | 4.70k | if (g_batch_size > 1 && g_batch_size == dw_batch_size5 ) |
198 | 2 | { |
// Per-batch dw: every batch overwrites its own slice (beta = 0.0).
199 | 2 | if (is_transpose_w) |
200 | 1 | { |
// Transposed dw: a is the first operand, g the second (always CblasTrans),
// and the output leading dimension is dw_cols_inc.
201 | 1 | const int transa = is_transpose_a ? CblasTrans0 : CblasNoTrans; |
202 | 1 | const int lda_inc = is_transpose_a ? a_cols_inc0 : a_rows_inc; |
203 | 3 | for (i = 0; i < g_batch_size; i++2 ) |
204 | 2 | cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_cols_inc); |
205 | 1 | } else { |
// Non-transposed dw: g is the first operand, a the second, and the output
// leading dimension is dw_rows_inc.
206 | 1 | const int transb = is_transpose_a ? CblasNoTrans0 : CblasTrans; |
207 | 1 | const int ldb_inc = is_transpose_a ? a_cols_inc0 : a_rows_inc; |
208 | 3 | for (i = 0; i < g_batch_size; i++2 ) |
209 | 2 | cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_rows_inc); |
210 | 1 | } |
211 | 4.70k | } else { |
// Shared dw: first batch overwrites, the rest accumulate (beta = 1.0).
212 | 4.70k | if (is_transpose_w) |
213 | 3.69k | { |
214 | 3.69k | const int transa = is_transpose_a ? CblasTrans2 : CblasNoTrans3.69k ; |
215 | 3.69k | const int lda_inc = is_transpose_a ? a_cols_inc2 : a_rows_inc3.69k ; |
216 | 3.69k | cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, dw->data.f32, dw_cols_inc); |
217 | 3.69k | for (i = 1; i < g_batch_size; i++1 ) |
218 | 1 | cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, dw->data.f32, dw_cols_inc); |
219 | 3.69k | } else { |
220 | 1.01k | const int transb = is_transpose_a ? CblasNoTrans2 : CblasTrans1.00k ; |
221 | 1.01k | const int ldb_inc = is_transpose_a ? a_cols_inc2 : a_rows_inc1.00k ; |
222 | 1.01k | cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32, g_rows_inc, a->data.f32, ldb_inc, 0.0, dw->data.f32, dw_rows_inc); |
223 | 1.01k | for (i = 1; i < g_batch_size; i++2 ) |
224 | 2 | cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 1.0, dw->data.f32, dw_rows_inc); |
225 | 1.01k | } |
226 | 4.70k | } |
227 | 4.70k | } |
// Section 3: input gradient h, computed from g and w.
228 | 4.81k | if (h) |
229 | 3.45k | { |
230 | 3.45k | const int is_transpose_h = ccv_nnc_is_matrix_transpose(h->info, transpose_a); |
231 | 3.45k | const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b); |
232 | 3.45k | int h_batch_size, h_rows, h_cols, h_batch_inc, h_rows_inc, h_cols_inc; |
233 | 3.45k | int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc; |
234 | 3.45k | ccv_nnc_tensor_get_matrix_params(h->info, CCV_IS_TENSOR_VIEW(h) ? h->stride0 : 0, h->info.dim, transpose_a, &h_batch_size, &h_rows, &h_cols, &h_batch_inc, &h_rows_inc, &h_cols_inc); |
235 | 3.45k | ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride0 : 0, w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc); |
236 | 3.45k | assert(h_rows == g_rows); |
237 | 3.45k | assert(h_cols == w_rows); |
238 | 3.45k | assert(w_cols == g_cols); |
239 | 3.45k | assert(h_batch_size == g_batch_size || h_batch_size == 1); |
240 | 3.45k | if (h_batch_size == 1 && g_batch_size > 13.44k ) |
241 | 0 | h_batch_inc = 0; |
242 | 3.45k | assert(w_batch_size == g_batch_size || w_batch_size == 1); |
243 | 3.45k | if (w_batch_size == 1 && g_batch_size > 13.45k ) |
244 | 3 | w_batch_inc = 0; |
245 | 3.45k | if (g_batch_size > 1 && g_batch_size == h_batch_size5 ) |
246 | 5 | { |
// Per-batch h: every batch overwrites its own slice (beta = 0.0).
247 | 5 | if (is_transpose_h) |
248 | 2 | { |
// Transposed h: g is the first operand (always CblasTrans), w the second,
// and the output leading dimension is h_cols_inc.
249 | 2 | const int transb = is_transpose_w ? CblasTrans1 : CblasNoTrans1 ; |
250 | 2 | const int ldb_inc = is_transpose_w ? w_cols_inc1 : w_rows_inc1 ; |
251 | 6 | for (i = 0; i < g_batch_size; i++4 ) |
252 | 4 | cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 0.0, h->data.f32 + i * h_batch_inc, h_cols_inc); |
253 | 3 | } else { |
// Non-transposed h: w is the first operand, g the second, and the output
// leading dimension is h_rows_inc.
254 | 3 | const int transa = is_transpose_w ? CblasNoTrans1 : CblasTrans2 ; |
255 | 3 | const int lda_inc = is_transpose_w ? w_cols_inc1 : w_rows_inc2 ; |
256 | 9 | for (i = 0; i < g_batch_size; i++6 ) |
257 | 6 | cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, h->data.f32 + i * h_batch_inc, h_rows_inc); |
258 | 3 | } |
259 | 3.44k | } else { |
// Shared h: first batch overwrites, the rest accumulate (beta = 1.0).
260 | 3.44k | if (is_transpose_h) |
261 | 2 | { |
262 | 2 | const int transb = is_transpose_w ? CblasTrans1 : CblasNoTrans1 ; |
263 | 2 | const int ldb_inc = is_transpose_w ? w_cols_inc1 : w_rows_inc1 ; |
264 | 2 | cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32, g_rows_inc, w->data.f32, ldb_inc, 0.0, h->data.f32, h_cols_inc); |
265 | 2 | for (i = 1; i < g_batch_size; i++0 ) |
266 | 0 | cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 1.0, h->data.f32, h_cols_inc); |
267 | 3.44k | } else { |
268 | 3.44k | const int transa = is_transpose_w ? CblasNoTrans2.43k : CblasTrans1.00k ; |
269 | 3.44k | const int lda_inc = is_transpose_w ? w_cols_inc2.43k : w_rows_inc1.00k ; |
270 | 3.44k | cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, h->data.f32, h_rows_inc); |
271 | 3.44k | for (i = 1; i < g_batch_size; i++0 ) |
272 | 0 | cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, h->data.f32, h_rows_inc); |
273 | 3.44k | } |
274 | 3.44k | } |
275 | 3.45k | } |
276 | 4.81k | return CCV_NNC_EXEC_SUCCESS; |
277 | | #else |
278 | | return CCV_NNC_EXEC_INVALID; |
279 | | #endif |
280 | 4.81k | } |