/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/blas/ccv_nnc_segmented_gemm_cpu_ref.c
#include "ccv.h"
#include "ccv_internal.h"
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"
#include "nnc/ccv_nnc_internal.h"
#ifdef USE_OPENMP
#include <omp.h>
#endif
#ifdef USE_DISPATCH
#include <dispatch/dispatch.h>
#endif

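// Segmented batched matrix multiply with bias (CPU reference path). Rows of `a`
// are consumed in contiguous segments: for bin n, `counts[n]` consecutive rows are
// multiplied by the weight matrix selected by `indices[n]` (a negative index skips
// that bin) and the matching bias row is added. The *_inc arguments are the batch /
// row / column strides computed by the caller; presumably this serves routing-style
// workloads where different groups of rows go through different weights.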
static inline void _ccv_nnc_segmented_bmm_and_bias(const float* const a, const float* const w, const float* const bias, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc, const int* const indices, const int* const counts, const int bincount)
{
	assert(b_batch_size == 1);
	int n, i;
	int off = 0;
	for (n = 0; n < bincount; n++)
	{
		if (indices[n] < 0)
			continue;
		const float* const ap = a + off * a_rows_inc;
		const float* const wp = w + indices[n] * w_batch_inc;
		const float* const biasp = bias + indices[n] * bias_batch_inc;
		float* const bp = b + off * b_rows_inc;
		const int rowcount = counts[n];
		off += rowcount;
		for (i = 0; i < rowcount; i++)
		{
			const float* const api = ap + i * a_rows_inc;
			const float* const biaspi = biasp;
			float* const bpi = bp + i * b_rows_inc;
			parallel_for(j, b_cols) {
				float v = biaspi[j * bias_cols_inc];
				const float* const wpj = wp + j * w_cols_inc;
				int k;
				for (k = 0; k < a_cols; k++)
					v += wpj[k * w_rows_inc] * api[k * a_cols_inc];
				bpi[j * b_cols_inc] = v;
			} parallel_endfor
		}
	}
}

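// Generic n-d wrapper for the bias variant: peels leading batch dimensions off
// a, w, bias and b (broadcasting size-1 dimensions against bdim[0]) until every
// operand is at most 3-d, then dispatches to the kernel above.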
static inline void _ccv_nnc_segmented_gbmm_and_bias(const float* const a, const int a_nd, const int* const adim, const int* const astride, const int* const indices, const int* const counts, const int bincount, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, const float* const bias, const int bias_nd, const int* const biasdim, const int* const biasstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc)
{
	if (b_nd <= 3)
	{
		_ccv_nnc_segmented_bmm_and_bias(a, w, bias, b, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc, indices, counts, bincount);
		return;
	}
	const int dim = bdim[0];
	if (a_nd > 3)
		{ assert(adim[0] == 1 || dim == adim[0]); }
	if (w_nd > 3)
		{ assert(wdim[0] == 1 || dim == wdim[0]); }
	if (bias_nd > 3)
		{ assert(biasdim[0] == 1 || dim == biasdim[0]); }
	int i;
	for (i = 0; i < dim; i++)
	{
		_ccv_nnc_segmented_gbmm_and_bias(
			a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
			indices, counts, bincount,
			w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
			bias_nd > 3 ? bias + i * biasstride[0] : bias, bias_nd > 3 ? bias_nd - 1 : bias_nd, bias_nd > 3 ? biasdim + 1 : biasdim, bias_nd > 3 ? biasstride + 1 : biasstride,
			b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc);
	}
}

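// Same segmented multiply as _ccv_nnc_segmented_bmm_and_bias, but without the bias term.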
static inline void _ccv_nnc_segmented_bmm(const float* const a, const float* const w, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc, const int* const indices, const int* const counts, const int bincount)
{
	assert(b_batch_size == 1);
	int n, i;
	int off = 0;
	for (n = 0; n < bincount; n++)
	{
		if (indices[n] < 0)
			continue;
		const float* const ap = a + off * a_rows_inc;
		const float* const wp = w + indices[n] * w_batch_inc;
		float* const bp = b + off * b_rows_inc;
		const int rowcount = counts[n];
		off += rowcount;
		for (i = 0; i < rowcount; i++)
		{
			const float* const api = ap + i * a_rows_inc;
			float* const bpi = bp + i * b_rows_inc;
			parallel_for(j, b_cols) {
				float v = 0;
				const float* const wpj = wp + j * w_cols_inc;
				int k;
				for (k = 0; k < a_cols; k++)
					v += wpj[k * w_rows_inc] * api[k * a_cols_inc];
				bpi[j * b_cols_inc] = v;
			} parallel_endfor
		}
	}
}

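// N-d wrapper for the bias-free kernel, mirroring _ccv_nnc_segmented_gbmm_and_bias.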
static inline void _ccv_nnc_segmented_gbmm(const float* const a, const int a_nd, const int* const adim, const int* const astride, const int* const indices, const int* const counts, const int bincount, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc)
{
	if (b_nd <= 3)
	{
		_ccv_nnc_segmented_bmm(a, w, b, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc, indices, counts, bincount);
		return;
	}
	const int dim = bdim[0];
	if (a_nd > 3)
		{ assert(adim[0] == 1 || dim == adim[0]); }
	if (w_nd > 3)
		{ assert(wdim[0] == 1 || dim == wdim[0]); }
	int i;
	for (i = 0; i < dim; i++)
	{
		_ccv_nnc_segmented_gbmm(
			a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
			indices, counts, bincount,
			w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
			b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc);
	}
}

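// Forward execution entry point. Inputs: a (the packed rows), indices (int32, one
// weight index per segment), counts (int32, rows per segment), w (the per-index
// weight matrices), and an optional bias; the single output is b. As an
// illustrative example (shapes not taken from the source): a of [16, 64],
// w of [4, 64, 128], indices = {2, 0, 3} and counts = {5, 8, 3} would send rows
// 0..4 through w[2], rows 5..12 through w[0] and rows 13..15 through w[3],
// producing b of [16, 128].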
static int _ccv_nnc_segmented_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size >= 4);
	const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[0];
	const ccv_nnc_tensor_view_t* indices = (const ccv_nnc_tensor_view_t*)inputs[1];
	const ccv_nnc_tensor_view_t* counts = (const ccv_nnc_tensor_view_t*)inputs[2];
	const ccv_nnc_tensor_view_t* w = (const ccv_nnc_tensor_view_t*)inputs[3];
	const ccv_nnc_tensor_view_t* bias = input_size > 4 ? (const ccv_nnc_tensor_view_t*)inputs[4] : 0;
	// Copy most of the parameters, but reshape the dimension of a to a vector.
	assert(output_size == 1);
	ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[0];
	assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array
	int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
	int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
	int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc;
	const static int no_transpose[2] = {};
	ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride : 0, a->info.dim, cmd.info.blas.transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
	ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, cmd.info.blas.transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
	ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->stride : 0, b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc);
	assert(a_batch_size == 1); // Currently, a cannot be batched (and broadcasting is not supported either).
	assert(a_batch_size == b_batch_size);
	assert(a_rows == b_rows);
	assert(a_cols == w_rows);
	assert(w_cols == b_cols);
	int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
	int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
	int bstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
	const int* astride;
	if (CCV_IS_TENSOR_VIEW(a))
		astride = a->stride;
	else {
		ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim);
		astride = astride_from_dim;
	}
	const int* wstride;
	if (CCV_IS_TENSOR_VIEW(w))
		wstride = w->stride;
	else {
		ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim);
		wstride = wstride_from_dim;
	}
	const int* bstride;
	if (CCV_IS_TENSOR_VIEW(b))
		bstride = b->stride;
	else {
		ccv_nnc_tensor_get_stride(b->info.dim, bstride_from_dim);
		bstride = bstride_from_dim;
	}
	const int bincount = ccv_nnc_tensor_count(indices->info);
	assert(ccv_nnc_tensor_count(counts->info) == bincount);
	assert(CCV_IS_TENSOR_CONTIGUOUS(indices));
	assert(CCV_IS_TENSOR_CONTIGUOUS(counts));
	if (bias)
	{
		int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
		const int bias_nd = ccv_nnc_tensor_nd(bias->info.dim);
		ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
		if (bias_nd == 2) // For nd == 2, we expand rows to 1 and assign that to batch.
		{
			bias_batch_size = bias_rows;
			bias_rows = 1;
			bias_batch_inc = bias_rows_inc;
		}
		assert(bias_batch_size == w_batch_size);
		assert(bias_cols == b_cols);
		const int* biasstride;
		int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
		if (CCV_IS_TENSOR_VIEW(bias))
			biasstride = bias->stride;
		else {
			ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim);
			biasstride = biasstride_from_dim;
		}
		_ccv_nnc_segmented_gbmm_and_bias(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, indices->data.i32, counts->data.i32, bincount, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc);
	} else {
		_ccv_nnc_segmented_gbmm(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, indices->data.i32, counts->data.i32, bincount, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc);
	}
	return CCV_NNC_EXEC_SUCCESS;
}

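// The backward pass is not implemented by this reference backend; it
// unconditionally reports CCV_NNC_EXEC_INVALID.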
static int _ccv_nnc_segmented_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	return CCV_NNC_EXEC_INVALID;
}

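// Register the forward kernel as the CPU reference backend: NHWC / NCHW layouts,
// 32-bit float data (with 32-bit int indices and counts), CPU memory only.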
REGISTER_COMMAND_BACKEND(CCV_NNC_SEGMENTED_GEMM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
	registry->tensor_datatypes = CCV_32F | CCV_32S;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_segmented_gemm_forw;
}

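// Register the (currently unimplemented) backward stub under the same formats and datatypes.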
REGISTER_COMMAND_BACKEND(CCV_NNC_SEGMENTED_GEMM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
	registry->tensor_datatypes = CCV_32F | CCV_32S;
	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_segmented_gemm_back;
}