Coverage Report

Created: 2024-08-18 16:21

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/blas/ccv_nnc_gemm_cpu_ref.c
Line |  Count | Source
   1 |        | #include "ccv.h"
   2 |        | #include "ccv_internal.h"
   3 |        | #include "nnc/ccv_nnc.h"
   4 |        | #include "nnc/ccv_nnc_easy.h"
   5 |        | #include "nnc/ccv_nnc_internal.h"
   6 |        | #ifdef USE_OPENMP
   7 |        | #include <omp.h>
   8 |        | #endif
   9 |        | #ifdef USE_DISPATCH
  10 |        | #include <dispatch/dispatch.h>
  11 |        | #endif
  12 |        | 
  13 |        | static inline void _ccv_nnc_bmm_and_bias(const float* const a, const float* const w, const float* const bias, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc)
  14 |  11.8k | {
  15 |  11.8k |   int n, i;
  16 |  23.7k |   for (n = 0; n < b_batch_size; n++)
  17 |  11.9k |   {
  18 |  11.9k |     const float* const ap = a + n * a_batch_inc;
  19 |  11.9k |     const float* const wp = w + n * w_batch_inc;
  20 |  11.9k |     const float* const biasp = bias + n * bias_batch_inc;
  21 |  11.9k |     float* const bp = b + n * b_batch_inc;
  22 |  34.6k |     for (i = 0; i < b_rows; i++)
  23 |  22.7k |     {
  24 |  22.7k |       const float* const api = ap + i * a_rows_inc;
  25 |  22.7k |       const float* const biaspi = biasp + i * bias_rows_inc;
  26 |  22.7k |       float* const bpi = bp + i * b_rows_inc;
  27 |  3.28M |       parallel_for(j, b_cols) {
  28 |  3.28M |         float v = biaspi[j * bias_cols_inc];
  29 |  3.28M |         const float* const wpj = wp + j * w_cols_inc;
  30 |  3.28M |         int k;
  31 |  3.10G |         for (k = 0; k < a_cols; k++)
  32 |  3.09G |           v += wpj[k * w_rows_inc] * api[k * a_cols_inc];
  33 |  3.28M |         bpi[j * b_cols_inc] = v;
  34 |  3.28M |       } parallel_endfor
  35 |  22.7k |     }
  36 |  11.9k |   }
  37 |  11.8k | }
  38 |        | 
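The helper above is the reference kernel's innermost batched GEMM with bias: for each batch n it computes one matrix product under fully generalized strides, with parallel_for/parallel_endfor being ccv's macro pair that parallelizes the column loop (via OpenMP or libdispatch when built with USE_OPENMP/USE_DISPATCH, per the includes on lines 6-11). Every transpose and broadcast has already been folded into the *_inc increments by the caller, so the kernel itself is three plain loops. A minimal sketch of the same computation (mine, not the library's), assuming row-major storage, batch size 1, no transposes, and a bias vector broadcast across rows:

/* Hedged sketch: the loop above in the plain row-major case, i.e.
 * b = a * w^T + bias with w stored as [out][in]. Under these assumptions the
 * increments collapse to: a_rows_inc = inner, a_cols_inc = 1,
 * w_cols_inc = inner, w_rows_inc = 1, b_rows_inc = cols, b_cols_inc = 1,
 * bias_rows_inc = 0 (row broadcast), bias_cols_inc = 1. */
static void gemm_bias_rowmajor(const float* a, const float* w, const float* bias,
  float* b, int rows, int cols, int inner)
{
  int i, j, k;
  for (i = 0; i < rows; i++)
    for (j = 0; j < cols; j++)
    {
      float v = bias[j];
      for (k = 0; k < inner; k++)
        v += w[j * inner + k] * a[i * inner + k];
      b[i * cols + j] = v;
    }
}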
  39 |        | static inline void _ccv_nnc_gbmm_and_bias(const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, const float* const bias, const int bias_nd, const int* const biasdim, const int* const biasstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int bias_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int bias_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int bias_rows_inc, const int b_rows_inc)
  40 |  11.8k | {
  41 |  11.8k |   if (b_nd <= 3)
  42 |  11.8k |   {
  43 |  11.8k |     _ccv_nnc_bmm_and_bias(a, w, bias, b, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc);
  44 |  11.8k |     return;
  45 |  11.8k |   }
  46 |      6 |   const int dim = bdim[0];
  47 |      6 |   if (a_nd > 3)
  48 |      6 |     { assert(adim[0] == 1 || dim == adim[0]); }
  49 |      6 |   if (w_nd > 3)
  50 |      3 |     { assert(wdim[0] == 1 || dim == wdim[0]); }
  51 |      6 |   if (bias_nd > 3)
  52 |      0 |     { assert(biasdim[0] == 1 || dim == biasdim[0]); }
  53 |      6 |   int i;
  54 |     18 |   for (i = 0; i < dim; i++)
  55 |     12 |   {
  56 |     12 |     _ccv_nnc_gbmm_and_bias(
  57 |     12 |       a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
  58 |     12 |       w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
  59 |     12 |       bias_nd > 3 ? bias + i * biasstride[0] : bias, bias_nd > 3 ? bias_nd - 1 : bias_nd, bias_nd > 3 ? biasdim + 1 : biasdim, bias_nd > 3 ? biasstride + 1 : biasstride,
  60 |     12 |       b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc);
  61 |     12 |   }
  62 |      6 | }
  63 |        | 
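_ccv_nnc_gbmm_and_bias handles operands with more than 3 dimensions (grouped batches beyond the flat batch) by peeling one leading dimension per recursion level until a plain batched matmul remains. An operand that has no extra leading dimension is passed through unchanged, which is what implements broadcasting across that dimension. A minimal single-operand sketch of the same scheme (names are mine):

/* Hedged sketch of the recursion in the _ccv_nnc_gbmm_* helpers: peel y's
 * leading dimension unconditionally; peel x's only if x still has one
 * (> 3 dims), otherwise reuse x as-is (broadcast). */
static void gbmm_sketch(const float* x, int x_nd, const int* xdim, const int* xstride,
  float* y, int y_nd, const int* ydim, const int* ystride)
{
  if (y_nd <= 3)
    return; /* base case: hand off to the flat batched matmul */
  int i;
  for (i = 0; i < ydim[0]; i++)
    gbmm_sketch(
      x_nd > 3 ? x + i * xstride[0] : x, x_nd > 3 ? x_nd - 1 : x_nd,
      x_nd > 3 ? xdim + 1 : xdim, x_nd > 3 ? xstride + 1 : xstride,
      y + i * ystride[0], y_nd - 1, ydim + 1, ystride + 1);
}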
  64 |        | static inline void _ccv_nnc_bmm(const float* const a, const float* const w, float* const b, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc)
  65 |  3.96k | {
  66 |  3.96k |   int n, i;
  67 |  9.32k |   for (n = 0; n < b_batch_size; n++)
  68 |  5.36k |   {
  69 |  5.36k |     const float* const ap = a + n * a_batch_inc;
  70 |  5.36k |     const float* const wp = w + n * w_batch_inc;
  71 |  5.36k |     float* const bp = b + n * b_batch_inc;
  72 |   206k |     for (i = 0; i < b_rows; i++)
  73 |   201k |     {
  74 |   201k |       const float* const api = ap + i * a_rows_inc;
  75 |   201k |       float* const bpi = bp + i * b_rows_inc;
  76 |  22.0M |       parallel_for(j, b_cols) {
  77 |  22.0M |         float v = 0;
  78 |  22.0M |         const float* const wpj = wp + j * w_cols_inc;
  79 |  22.0M |         int k;
  80 |  2.04G |         for (k = 0; k < a_cols; k++)
  81 |  2.01G |           v += wpj[k * w_rows_inc] * api[k * a_cols_inc];
  82 |  22.0M |         bpi[j * b_cols_inc] = v;
  83 |  22.0M |       } parallel_endfor
  84 |   201k |     }
  85 |  5.36k |   }
  86 |  3.96k | }
  87 |        | 
  88 |        | static inline void _ccv_nnc_gbmm(const float* const a, const int a_nd, const int* const adim, const int* const astride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const b, const int b_nd, const int* const bdim, const int* const bstride, const int b_batch_size, const int a_batch_inc, const int w_batch_inc, const int b_batch_inc, const int b_rows, const int b_cols, const int a_cols, const int a_cols_inc, const int w_cols_inc, const int b_cols_inc, const int a_rows_inc, const int w_rows_inc, const int b_rows_inc)
  89 |  3.97k | {
  90 |  3.97k |   if (b_nd <= 3)
  91 |  3.96k |   {
  92 |  3.96k |     _ccv_nnc_bmm(a, w, b, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc);
  93 |  3.96k |     return;
  94 |  3.96k |   }
  95 |     12 |   const int dim = bdim[0];
  96 |     12 |   if (a_nd > 3)
  97 |     12 |     { assert(adim[0] == 1 || dim == adim[0]); }
  98 |     12 |   if (w_nd > 3)
  99 |      9 |     { assert(wdim[0] == 1 || dim == wdim[0]); }
 100 |     12 |   int i;
 101 |    216 |   for (i = 0; i < dim; i++)
 102 |    204 |   {
 103 |    204 |     _ccv_nnc_gbmm(
 104 |    204 |       a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
 105 |    204 |       w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
 106 |    204 |       b + i * bstride[0], b_nd - 1, bdim + 1, bstride + 1, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc);
 107 |    204 |   }
 108 |     12 | }
 109 |        | 
 110 |        | static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
 111 |  15.5k | {
 112 |  15.5k |   assert(input_size >= 2);
 113 |  15.5k |   const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[0];
 114 |  15.5k |   const ccv_nnc_tensor_view_t* w = (const ccv_nnc_tensor_view_t*)inputs[1];
 115 |  15.5k |   const ccv_nnc_tensor_view_t* bias = input_size > 2 ? (const ccv_nnc_tensor_view_t*)inputs[2] : 0;
 116 |        |   // Copy most of the parameters, but reshape the dimension of a to a vector.
 117 |  15.5k |   assert(output_size == 1);
 118 |  15.5k |   ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[0];
 119 |  15.5k |   assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // bias has at most 3 dims
 120 |  15.5k |   int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
 121 |  15.5k |   int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
 122 |  15.5k |   int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc;
 123 |  15.5k |   const static int no_transpose[2] = {};
 124 |  15.5k |   ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride : 0, a->info.dim, cmd.info.blas.transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
 125 |  15.5k |   ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, cmd.info.blas.transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
 126 |  15.5k |   ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->stride : 0, b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc);
 127 |  15.5k |   assert(ccv_max(a_batch_size, w_batch_size) == b_batch_size);
 128 |  15.5k |   assert(a_batch_size == b_batch_size || a_batch_size == 1);
 129 |  15.5k |   if (a_batch_size == 1 && b_batch_size > 1)
 130 |      0 |     a_batch_inc = 0;
 131 |  15.5k |   assert(w_batch_size == b_batch_size || w_batch_size == 1);
 132 |  15.5k |   if (w_batch_size == 1 && b_batch_size > 1)
 133 |     11 |     w_batch_inc = 0;
 134 |  15.5k |   assert(a_rows == b_rows);
 135 |  15.5k |   assert(a_cols == w_rows);
 136 |  15.5k |   assert(w_cols == b_cols);
 137 |  15.5k |   int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 138 |  15.5k |   int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 139 |  15.5k |   int bstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 140 |  15.5k |   const int* astride;
 141 |  15.5k |   if (CCV_IS_TENSOR_VIEW(a))
 142 |     14 |     astride = a->stride;
 143 |  15.5k |   else {
 144 |  15.5k |     ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim);
 145 |  15.5k |     astride = astride_from_dim;
 146 |  15.5k |   }
 147 |  15.5k |   const int* wstride;
 148 |  15.5k |   if (CCV_IS_TENSOR_VIEW(w))
 149 |      4 |     wstride = w->stride;
 150 |  15.5k |   else {
 151 |  15.5k |     ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim);
 152 |  15.5k |     wstride = wstride_from_dim;
 153 |  15.5k |   }
 154 |  15.5k |   const int* bstride;
 155 |  15.5k |   if (CCV_IS_TENSOR_VIEW(b))
 156 |     26 |     bstride = b->stride;
 157 |  15.5k |   else {
 158 |  15.5k |     ccv_nnc_tensor_get_stride(b->info.dim, bstride_from_dim);
 159 |  15.5k |     bstride = bstride_from_dim;
 160 |  15.5k |   }
 161 |  15.5k |   if (bias)
 162 |  11.8k |   {
 163 |  11.8k |     int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
 164 |  11.8k |     ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
 165 |  11.8k |     assert(bias_batch_size == b_batch_size || bias_batch_size == 1);
 166 |  11.8k |     if (bias_batch_size == 1 && b_batch_size > 1)
 167 |     10 |       bias_batch_inc = 0;
 168 |  11.8k |     if (bias_rows == 1 && b_rows > 1)
 169 |  6.12k |       bias_rows_inc = 0;
 170 |  11.8k |     assert(bias_cols == b_cols);
 171 |  11.8k |     const int* biasstride;
 172 |  11.8k |     int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 173 |  11.8k |     if (CCV_IS_TENSOR_VIEW(bias))
 174 |      0 |       biasstride = bias->stride;
 175 |  11.8k |     else {
 176 |  11.8k |       ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim);
 177 |  11.8k |       biasstride = biasstride_from_dim;
 178 |  11.8k |     }
 179 |  11.8k |     _ccv_nnc_gbmm_and_bias(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, bias_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, bias_rows_inc, b_rows_inc);
 180 |  11.8k |   } else {
 181 |  3.76k |     _ccv_nnc_gbmm(a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, b->data.f32, ccv_nnc_tensor_nd(b->info.dim), b->info.dim, bstride, b_batch_size, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, a_cols_inc, w_cols_inc, b_cols_inc, a_rows_inc, w_rows_inc, b_rows_inc);
 182 |  3.76k |   }
 183 |  15.5k |   return CCV_NNC_EXEC_SUCCESS;
 184 |  15.5k | }
 185 |        | 
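Note that _ccv_nnc_gemm_forw never moves data to transpose anything: ccv_nnc_tensor_get_matrix_params folds cmd.info.blas.transpose_a / transpose_b into the row and column increments, and batch broadcasting is handled by forcing a batch increment to 0 so the single matrix is replayed for every batch (lines 129-133). A hedged sketch of that folding for a dense row-major [rows][cols] matrix (my illustration, not the library routine):

/* For a dense row-major matrix, "transposing" is just swapping which
 * increment walks memory contiguously; no elements move. */
static void matrix_params_sketch(int rows, int cols, int transpose,
  int* out_rows, int* out_cols, int* rows_inc, int* cols_inc)
{
  if (!transpose) {
    *out_rows = rows; *out_cols = cols;
    *rows_inc = cols; *cols_inc = 1;
  } else {
    *out_rows = cols; *out_cols = rows; /* logical shape after the transpose */
    *rows_inc = 1; *cols_inc = cols;    /* same memory, axes exchanged */
  }
}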
 186 |        | static inline void _ccv_nnc_bmm_dbias(const float* const g, float* const dbias, const int g_batch_size, const int g_batch_inc, const int dbias_batch_inc, const int g_rows, const int g_cols, const int g_cols_inc, const int dbias_cols_inc, const int g_rows_inc, const int dbias_rows_inc)
 187 |  3.49k | {
 188 |  3.49k |   int n, i, j;
 189 |  7.00k |   for (n = 0; n < g_batch_size; n++)
 190 |  3.51k |   {
 191 |  3.51k |     const float* const gp = g + n * g_batch_inc;
 192 |  3.51k |     float* const bp = dbias + n * dbias_batch_inc;
 193 |  7.34k |     for (i = 0; i < g_rows; i++)
 194 |  3.83k |     {
 195 |  3.83k |       const float* const gpi = gp + i * g_rows_inc;
 196 |  3.83k |       float* const bpi = bp + i * dbias_rows_inc;
 197 |  22.1k |       for (j = 0; j < g_cols; j++)
 198 |  18.3k |         bpi[j * dbias_cols_inc] += gpi[j * g_cols_inc];
 199 |  3.83k |     }
 200 |  3.51k |   }
 201 |  3.49k | }
 202 |        | 
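The bias gradient is the output gradient summed over everything the bias was broadcast across. In the common case where bias is a single row shared by all batches and rows (dbias_batch_inc = dbias_rows_inc = 0, as set by the caller below), the loop reduces to a column-wise sum; a dense row-major sketch under those assumptions (names mine):

/* dbias[j] = sum over batches n and rows i of g[n][i][j]. It accumulates with
 * += because the caller zeroes dbias only when CCV_NNC_ACCUMULATE_OUTPUT is
 * not set. */
static void dbias_rowmajor(const float* g, float* dbias, int batch, int rows, int cols)
{
  int n, i, j;
  for (n = 0; n < batch; n++)
    for (i = 0; i < rows; i++)
      for (j = 0; j < cols; j++)
        dbias[j] += g[(n * rows + i) * cols + j];
}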
 203 |        | static inline void _ccv_nnc_gbmm_dbias(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, float* const dbias, const int dbias_nd, const int* const dbiasdim, const int* const dbiasstride, const int g_batch_size, const int g_batch_inc, const int dbias_batch_inc, const int g_rows, const int g_cols, const int g_cols_inc, const int dbias_cols_inc, const int g_rows_inc, const int dbias_rows_inc)
 204 |  3.49k | {
 205 |  3.49k |   if (g_nd <= 3)
 206 |  3.49k |   {
 207 |  3.49k |     _ccv_nnc_bmm_dbias(g, dbias, g_batch_size, g_batch_inc, dbias_batch_inc, g_rows, g_cols, g_cols_inc, dbias_cols_inc, g_rows_inc, dbias_rows_inc);
 208 |  3.49k |     return;
 209 |  3.49k |   }
 210 |      2 |   const int dim = gdim[0];
 211 |      2 |   if (dbias_nd > 3)
 212 |      0 |     { assert(dbiasdim[0] == 1 || dim == dbiasdim[0]); }
 213 |      2 |   int i;
 214 |      6 |   for (i = 0; i < dim; i++)
 215 |      4 |   {
 216 |      4 |     _ccv_nnc_gbmm_dbias(
 217 |      4 |       g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1,
 218 |      4 |       dbias_nd > 3 ? dbias + i * dbiasstride[0] : dbias, dbias_nd > 3 ? dbias_nd - 1 : dbias_nd, dbias_nd > 3 ? dbiasdim + 1 : dbiasdim, dbias_nd > 3 ? dbiasstride + 1 : dbiasstride,
 219 |      4 |       g_batch_size, g_batch_inc, dbias_batch_inc, g_rows, g_cols, g_cols_inc, dbias_cols_inc, g_rows_inc, dbias_rows_inc);
 220 |      4 |   }
 221 |      2 | }
 222 |        | 
 223 |        | static inline void _ccv_nnc_bmm_dw(const float* const g, const float* const a, float* const dw, const int g_batch_size, const int g_batch_inc, const int a_batch_inc, const int dw_batch_inc, const int a_rows, const int a_cols, const int g_cols, const int g_cols_inc, const int a_cols_inc, const int dw_cols_inc, const int g_rows_inc, const int a_rows_inc, const int dw_rows_inc)
 224 |  10.7k | {
 225 |  10.7k |   int n, i;
 226 |  21.9k |   for (n = 0; n < g_batch_size; n++)
 227 |  11.1k |   {
 228 |  11.1k |     const float* const gp = g + n * g_batch_inc;
 229 |  11.1k |     const float* const ap = a + n * a_batch_inc;
 230 |  11.1k |     float* const dwp = dw + n * dw_batch_inc;
 231 |  93.9k |     for (i = 0; i < a_rows; i++)
 232 |  82.7k |     {
 233 |  82.7k |       const float* const gpi = gp + i * g_rows_inc;
 234 |  82.7k |       const float* const api = ap + i * a_rows_inc;
 235 |  7.39M |       parallel_for(j, g_cols) {
 236 |  7.39M |         const float v = gpi[j * g_cols_inc];
 237 |  7.39M |         float* dwpj = dwp + j * dw_cols_inc;
 238 |  7.39M |         int k;
 239 |   683M |         for (k = 0; k < a_cols; k++)
 240 |   675M |           dwpj[k * dw_rows_inc] += api[k * a_cols_inc] * v;
 241 |  7.39M |       } parallel_endfor
 242 |  82.7k |     }
 243 |  11.1k |   }
 244 |  10.7k | }
 245 |        | 
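With the forward convention b = a * w^T (w stored as [g_cols][a_cols], one row per output feature), the weight gradient is dw = g^T * a. The loop above computes exactly that under generalized strides; a dense row-major sketch (assumptions and names mine):

/* dw[j][k] += g[i][j] * a[i][k] for every row i: the outer product of the
 * output gradient and the input, summed over rows (and over batches when the
 * caller replays the batch with dw_batch_inc = 0). */
static void dw_rowmajor(const float* g, const float* a, float* dw,
  int rows, int a_cols, int g_cols)
{
  int i, j, k;
  for (i = 0; i < rows; i++)
    for (j = 0; j < g_cols; j++)
      for (k = 0; k < a_cols; k++)
        dw[j * a_cols + k] += a[i * a_cols + k] * g[i * g_cols + j];
}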
 246 |        | static inline void _ccv_nnc_gbmm_dw(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, const float* const a, const int a_nd, const int* const adim, const int* const astride, float* const dw, const int dw_nd, const int* const dwdim, const int* const dwstride, const int g_batch_size, const int g_batch_inc, const int a_batch_inc, const int dw_batch_inc, const int a_rows, const int a_cols, const int g_cols, const int g_cols_inc, const int a_cols_inc, const int dw_cols_inc, const int g_rows_inc, const int a_rows_inc, const int dw_rows_inc)
 247 |  10.7k | {
 248 |  10.7k |   if (g_nd <= 3)
 249 |  10.7k |   {
 250 |  10.7k |     _ccv_nnc_bmm_dw(g, a, dw, g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc);
 251 |  10.7k |     return;
 252 |  10.7k |   }
 253 |      6 |   const int dim = gdim[0];
 254 |      6 |   if (a_nd > 3)
 255 |      6 |     { assert(adim[0] == 1 || dim == adim[0]); }
 256 |      6 |   if (dw_nd > 3)
 257 |      4 |     { assert(dwdim[0] == 1 || dim == dwdim[0]); }
 258 |      6 |   int i;
 259 |     78 |   for (i = 0; i < dim; i++)
 260 |     72 |   {
 261 |     72 |     _ccv_nnc_gbmm_dw(
 262 |     72 |       g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1,
 263 |     72 |       a_nd > 3 ? a + i * astride[0] : a, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
 264 |     72 |       dw_nd > 3 ? dw + i * dwstride[0] : dw, dw_nd > 3 ? dw_nd - 1 : dw_nd, dw_nd > 3 ? dwdim + 1 : dwdim, dw_nd > 3 ? dwstride + 1 : dwstride,
 265 |     72 |       g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc);
 266 |     72 |   }
 267 |      6 | }
 268 |        | 
 269 |        | static inline void _ccv_nnc_bmm_h(const float* const g, const float* const w, float* const h, const int zero_h, const int g_batch_size, const int g_batch_inc, const int w_batch_inc, const int h_batch_inc, const int h_rows, const int h_cols, const int g_cols, const int g_cols_inc, const int w_cols_inc, const int h_cols_inc, const int g_rows_inc, const int w_rows_inc, const int h_rows_inc)
 270 |  2.46k | {
 271 |  2.46k |   int n, i;
 272 |  5.41k |   for (n = 0; n < g_batch_size; n++)
 273 |  2.95k |   {
 274 |  2.95k |     const float* const gp = g + n * g_batch_inc;
 275 |  2.95k |     const float* const wp = w + n * w_batch_inc;
 276 |  2.95k |     float* const hp = h + n * h_batch_inc;
 277 |  71.4k |     for (i = 0; i < h_rows; i++)
 278 |  68.4k |     {
 279 |  68.4k |       const float* const gpi = gp + i * g_rows_inc;
 280 |  68.4k |       float* const hpi = hp + i * h_rows_inc;
 281 |  6.35M |       parallel_for(j, h_cols) {
 282 |  6.35M |         const float* const wpj = wp + j * w_rows_inc;
 283 |  6.35M |         float v = zero_h ? 0 : hpi[j * h_cols_inc];
 284 |  6.35M |         int k;
 285 |   681M |         for (k = 0; k < g_cols; k++)
 286 |   675M |           v += wpj[k * w_cols_inc] * gpi[k * g_cols_inc];
 287 |  6.35M |         hpi[j * h_cols_inc] = v;
 288 |  6.35M |       } parallel_endfor
 289 |  68.4k |     }
 290 |  2.95k |   }
 291 |  2.46k | }
 292 |        | 
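The input gradient closes the chain rule: with b = a * w^T, dL/da = dL/db * w. The zero_h flag selects between overwriting h and accumulating into it, which the broadcast-aware recursion below relies on. A dense row-major sketch (assumptions and names mine):

/* h = g * w with w stored as [g_cols][h_cols], the same [out][in] layout as
 * the forward pass: h[i][j] accumulates sum_k w[k][j] * g[i][k]; zero_h picks
 * assign vs. add. */
static void h_rowmajor(const float* g, const float* w, float* h, int zero_h,
  int rows, int h_cols, int g_cols)
{
  int i, j, k;
  for (i = 0; i < rows; i++)
    for (j = 0; j < h_cols; j++)
    {
      float v = zero_h ? 0 : h[i * h_cols + j];
      for (k = 0; k < g_cols; k++)
        v += w[k * h_cols + j] * g[i * g_cols + k];
      h[i * h_cols + j] = v;
    }
}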
 293 |        | static inline void _ccv_nnc_gbmm_h(const float* const g, const int g_nd, const int* const gdim, const int* const gstride, const float* const w, const int w_nd, const int* const wdim, const int* const wstride, float* const h, const int zero_h, const int h_nd, const int* const hdim, const int* const hstride, const int g_batch_size, const int g_batch_inc, const int w_batch_inc, const int h_batch_inc, const int h_rows, const int h_cols, const int g_cols, const int g_cols_inc, const int w_cols_inc, const int h_cols_inc, const int g_rows_inc, const int w_rows_inc, const int h_rows_inc)
 294 |  2.47k | {
 295 |  2.47k |   if (g_nd <= 3)
 296 |  2.46k |   {
 297 |  2.46k |     _ccv_nnc_bmm_h(g, w, h, zero_h, g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc);
 298 |  2.46k |     return;
 299 |  2.46k |   }
 300 |      6 |   const int dim = gdim[0];
 301 |      6 |   if (w_nd > 3)
 302 |      4 |     { assert(wdim[0] == 1 || dim == wdim[0]); }
 303 |      6 |   if (h_nd > 3)
 304 |      6 |     { assert(hdim[0] == 1 || dim == hdim[0]); }
 305 |      6 |   int i;
 306 |     78 |   for (i = 0; i < dim; i++)
 307 |     72 |   {
 308 |        |     // Only zero h the first time a given slice of h is visited; broadcast slices must accumulate afterwards.
 309 |     72 |     const int zero_h_override = (i == 0 || (i * hstride[0] > 0 && h_nd > 3)) ? zero_h : 0;
 310 |     72 |     _ccv_nnc_gbmm_h(
 311 |     72 |       g + i * gstride[0], g_nd - 1, gdim + 1, gstride + 1,
 312 |     72 |       w_nd > 3 ? w + i * wstride[0] : w, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
 313 |     72 |       h_nd > 3 ? h + i * hstride[0] : h, zero_h_override, h_nd > 3 ? h_nd - 1 : h_nd, h_nd > 3 ? hdim + 1 : hdim, h_nd > 3 ? hstride + 1 : hstride,
 314 |     72 |       g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc);
 315 |     72 |   }
 316 |      6 | }
 317 |        | 
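The zero_h_override computation on line 309 deserves a gloss: when h lacks the extra leading dimension that g has, every iteration of the loop writes into the same slice of h, so only the first visit may zero it; zeroing later would discard partial sums already accumulated. Extracted as a standalone predicate (restated from the line above, not new logic):

/* Zero h on iteration i only if this is the first visit (i == 0) or if h
 * actually advances per iteration (h_nd > 3 with a real leading stride);
 * otherwise force accumulation. */
static int zero_h_for_iteration(int i, int h_nd, const int* hstride, int zero_h)
{
  return (i == 0 || (i * hstride[0] > 0 && h_nd > 3)) ? zero_h : 0;
}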
 318 |        | static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
 319 |  10.6k | {
 320 |        |   // inputs: gradient, forw prop input, [w]
 321 |        |   // outputs: [output gradient], weight updates, bias updates
 322 |  10.6k |   assert(input_size >= 2 && output_size >= 1);
 323 |  10.6k |   const ccv_nnc_tensor_view_t* g = (const ccv_nnc_tensor_view_t*)inputs[0];
 324 |  10.6k |   ccv_nnc_tensor_view_t* dw = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
 325 |  10.6k |   ccv_nnc_tensor_view_t* bias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;
 326 |  10.6k |   assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // bias has at most 3 dims
 327 |  10.6k |   int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc;
 328 |  10.6k |   const static int no_transpose[2] = {};
 329 |  10.6k |   ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->stride : 0, g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc);
 330 |  10.6k |   if (bias)
 331 |  3.49k |   {
 332 |  3.49k |     if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0
 333 |  3.49k |       ccv_nnc_tensor_zero(bias);
 334 |  3.49k |     int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
 335 |  3.49k |     ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
 336 |  3.49k |     assert(bias_cols == g_cols);
 337 |  3.49k |     assert(bias_batch_size == 1 || bias_batch_size == g_batch_size);
 338 |  3.49k |     if (bias_batch_size == 1 && g_batch_size > 1)
 339 |      5 |       bias_batch_inc = 0;
 340 |  3.49k |     if (bias_rows == 1 && g_rows > 1)
 341 |    124 |       bias_rows_inc = 0;
 342 |  3.49k |     int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 343 |  3.49k |     const int* gstride;
 344 |  3.49k |     if (CCV_IS_TENSOR_VIEW(g))
 345 |      0 |       gstride = g->stride;
 346 |  3.49k |     else {
 347 |  3.49k |       ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim);
 348 |  3.49k |       gstride = gstride_from_dim;
 349 |  3.49k |     }
 350 |  3.49k |     int biasstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 351 |  3.49k |     const int* biasstride;
 352 |  3.49k |     if (CCV_IS_TENSOR_VIEW(bias))
 353 |      0 |       biasstride = bias->stride;
 354 |  3.49k |     else {
 355 |  3.49k |       ccv_nnc_tensor_get_stride(bias->info.dim, biasstride_from_dim);
 356 |  3.49k |       biasstride = biasstride_from_dim;
 357 |  3.49k |     }
 358 |  3.49k |     _ccv_nnc_gbmm_dbias(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, bias->data.f32, ccv_nnc_tensor_nd(bias->info.dim), bias->info.dim, biasstride, g_batch_size, g_batch_inc, bias_batch_inc, g_rows, g_cols, g_cols_inc, bias_cols_inc, g_rows_inc, bias_rows_inc);
 359 |  3.49k |   }
 360 |  10.6k |   if (dw)
 361 |  10.6k |   {
 362 |  10.6k |     if (!(flags & CCV_NNC_ACCUMULATE_OUTPUT)) // reset the gradients to 0
 363 |  10.6k |       ccv_nnc_tensor_zero(dw);
 364 |  10.6k |     const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[1];
 365 |  10.6k |     int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
 366 |  10.6k |     int dw_batch_size, dw_rows, dw_cols, dw_batch_inc, dw_rows_inc, dw_cols_inc;
 367 |  10.6k |     ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->stride : 0, a->info.dim, cmd.info.blas.transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
 368 |  10.6k |     ccv_nnc_tensor_get_matrix_params(dw->info, CCV_IS_TENSOR_VIEW(dw) ? dw->stride : 0, dw->info.dim, cmd.info.blas.transpose_b, &dw_batch_size, &dw_rows, &dw_cols, &dw_batch_inc, &dw_rows_inc, &dw_cols_inc);
 369 |  10.6k |     assert(a_rows == g_rows);
 370 |  10.6k |     assert(a_cols == dw_rows);
 371 |  10.6k |     assert(dw_cols == g_cols);
 372 |  10.6k |     assert(a_batch_size == g_batch_size || a_batch_size == 1);
 373 |  10.6k |     if (a_batch_size == 1 && g_batch_size > 1)
 374 |      0 |       a_batch_inc = 0;
 375 |  10.6k |     assert(dw_batch_size == g_batch_size || dw_batch_size == 1);
 376 |  10.6k |     if (dw_batch_size == 1 && g_batch_size > 1)
 377 |      5 |       dw_batch_inc = 0;
 378 |  10.6k |     int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 379 |  10.6k |     int astride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 380 |  10.6k |     int dwstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 381 |  10.6k |     const int* gstride;
 382 |  10.6k |     if (CCV_IS_TENSOR_VIEW(g))
 383 |      2 |       gstride = g->stride;
 384 |  10.6k |     else {
 385 |  10.6k |       ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim);
 386 |  10.6k |       gstride = gstride_from_dim;
 387 |  10.6k |     }
 388 |  10.6k |     const int* astride;
 389 |  10.6k |     if (CCV_IS_TENSOR_VIEW(a))
 390 |      2 |       astride = a->stride;
 391 |  10.6k |     else {
 392 |  10.6k |       ccv_nnc_tensor_get_stride(a->info.dim, astride_from_dim);
 393 |  10.6k |       astride = astride_from_dim;
 394 |  10.6k |     }
 395 |  10.6k |     const int* dwstride;
 396 |  10.6k |     if (CCV_IS_TENSOR_VIEW(dw))
 397 |      0 |       dwstride = dw->stride;
 398 |  10.6k |     else {
 399 |  10.6k |       ccv_nnc_tensor_get_stride(dw->info.dim, dwstride_from_dim);
 400 |  10.6k |       dwstride = dwstride_from_dim;
 401 |  10.6k |     }
 402 |  10.6k |     _ccv_nnc_gbmm_dw(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, a->data.f32, ccv_nnc_tensor_nd(a->info.dim), a->info.dim, astride, dw->data.f32, ccv_nnc_tensor_nd(dw->info.dim), dw->info.dim, dwstride, g_batch_size, g_batch_inc, a_batch_inc, dw_batch_inc, a_rows, a_cols, g_cols, g_cols_inc, a_cols_inc, dw_cols_inc, g_rows_inc, a_rows_inc, dw_rows_inc);
 403 |  10.6k |   }
 404 |  10.6k |   ccv_nnc_tensor_view_t* h = (ccv_nnc_tensor_view_t*)outputs[0];
 405 |  10.6k |   if (h)
 406 |  2.40k |   {
 407 |  2.40k |     const int zero_h = !(flags & CCV_NNC_ACCUMULATE_OUTPUT); // reset the gradients to 0
 408 |  2.40k |     const ccv_nnc_tensor_view_t* w = (const ccv_nnc_tensor_view_t*)inputs[2];
 409 |  2.40k |     int h_batch_size, h_rows, h_cols, h_batch_inc, h_rows_inc, h_cols_inc;
 410 |  2.40k |     int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
 411 |  2.40k |     ccv_nnc_tensor_get_matrix_params(h->info, CCV_IS_TENSOR_VIEW(h) ? h->stride : 0, h->info.dim, cmd.info.blas.transpose_a, &h_batch_size, &h_rows, &h_cols, &h_batch_inc, &h_rows_inc, &h_cols_inc);
 412 |  2.40k |     ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, cmd.info.blas.transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
 413 |  2.40k |     assert(h_cols == w_rows);
 414 |  2.40k |     assert(w_cols == g_cols);
 415 |  2.40k |     assert(h_batch_size == g_batch_size || h_batch_size == 1);
 416 |  2.40k |     if (h_batch_size == 1 && g_batch_size > 1)
 417 |      0 |       h_batch_inc = 0;
 418 |  2.40k |     assert(w_batch_size == g_batch_size || w_batch_size == 1);
 419 |  2.40k |     if (w_batch_size == 1 && g_batch_size > 1)
 420 |      5 |       w_batch_inc = 0;
 421 |  2.40k |     int gstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 422 |  2.40k |     int wstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 423 |  2.40k |     int hstride_from_dim[CCV_NNC_MAX_DIM_ALLOC];
 424 |  2.40k |     const int* gstride;
 425 |  2.40k |     if (CCV_IS_TENSOR_VIEW(g))
 426 |      4 |       gstride = g->stride;
 427 |  2.39k |     else {
 428 |  2.39k |       ccv_nnc_tensor_get_stride(g->info.dim, gstride_from_dim);
 429 |  2.39k |       gstride = gstride_from_dim;
 430 |  2.39k |     }
 431 |  2.40k |     const int* wstride;
 432 |  2.40k |     if (CCV_IS_TENSOR_VIEW(w))
 433 |      0 |       wstride = w->stride;
 434 |  2.40k |     else {
 435 |  2.40k |       ccv_nnc_tensor_get_stride(w->info.dim, wstride_from_dim);
 436 |  2.40k |       wstride = wstride_from_dim;
 437 |  2.40k |     }
 438 |  2.40k |     const int* hstride;
 439 |  2.40k |     if (CCV_IS_TENSOR_VIEW(h))
 440 |      4 |       hstride = h->stride;
 441 |  2.39k |     else {
 442 |  2.39k |       ccv_nnc_tensor_get_stride(h->info.dim, hstride_from_dim);
 443 |  2.39k |       hstride = hstride_from_dim;
 444 |  2.39k |     }
 445 |  2.40k |     _ccv_nnc_gbmm_h(g->data.f32, ccv_nnc_tensor_nd(g->info.dim), g->info.dim, gstride, w->data.f32, ccv_nnc_tensor_nd(w->info.dim), w->info.dim, wstride, h->data.f32, zero_h, ccv_nnc_tensor_nd(h->info.dim), h->info.dim, hstride, g_batch_size, g_batch_inc, w_batch_inc, h_batch_inc, h_rows, h_cols, g_cols, g_cols_inc, w_cols_inc, h_cols_inc, g_rows_inc, w_rows_inc, h_rows_inc);
 446 |  2.40k |   }
 447 |  10.6k |   return CCV_NNC_EXEC_SUCCESS;
 448 |  10.6k | }
 449 |        | 
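Several branches in the backward pass set a batch increment to 0 when one operand has batch size 1 while the gradient is batched (coverage shows the dbias, dw, and w broadcast branches each ran 5 times). Because the gradient kernels accumulate with +=, a zero increment makes every batch land on the same matrix and the per-batch gradients sum, which is exactly the gradient of a weight shared across the batch. A self-contained demo of that mechanism (mine, not library code):

#include <stdio.h>

int main(void)
{
  float g[2][3] = {{1, 2, 3}, {4, 5, 6}}; /* two batches of a 1x3 gradient */
  float dw[3] = {0};
  const int dw_batch_inc = 0; /* dw_batch_size == 1 while g_batch_size == 2 */
  int n, j;
  for (n = 0; n < 2; n++)
  {
    float* const dwp = dw + n * dw_batch_inc; /* same row for every batch */
    for (j = 0; j < 3; j++)
      dwp[j] += g[n][j];
  }
  printf("%g %g %g\n", dw[0], dw[1], dw[2]); /* prints: 5 7 9 */
  return 0;
}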
 450 |        | REGISTER_COMMAND_BACKEND(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 451 |      1 | {
 452 |      1 |   registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
 453 |      1 |   registry->tensor_datatypes = CCV_32F;
 454 |      1 |   registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 455 |      1 |   registry->algorithms = 1;
 456 |      1 |   registry->exec = _ccv_nnc_gemm_forw;
 457 |      1 | }
 458 |        | 
 459 |        | REGISTER_COMMAND_BACKEND(CCV_NNC_GEMM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 460 |      1 | {
 461 |      1 |   registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW;
 462 |      1 |   registry->tensor_datatypes = CCV_32F;
 463 |      1 |   registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 464 |      1 |   registry->algorithms = 1;
 465 |      1 |   registry->exec = _ccv_nnc_gemm_back;
 466 |      1 | }
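These two registrations make the reference kernels reachable for CPU tensors in CCV_32F, NHWC or NCHW format. For orientation, an invocation might look like the sketch below; it is hedged, since the exact convenience macros (CPU_TENSOR_NHWC, TENSOR_LIST) and the blas parameter layout should be verified against ccv_nnc.h / ccv_nnc_easy.h rather than taken from here:

#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"

static void gemm_forward_example(void)
{
  /* Assumed usage: y = x * w^T + bias; transpose_b = {0, 1} treats w as the
   * [out][in] layout used throughout the sketches above. */
  ccv_nnc_tensor_t* const x = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 4), 0);
  ccv_nnc_tensor_t* const w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3, 4), 0);
  ccv_nnc_tensor_t* const bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 3), 0);
  ccv_nnc_tensor_t* const y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 2, 3), 0);
  ccv_nnc_cmd_param_t params = {};
  params.blas.transpose_b[0] = 0;
  params.blas.transpose_b[1] = 1;
  ccv_nnc_cmd_exec(ccv_nnc_cmd(CCV_NNC_GEMM_FORWARD, 0, params, 0), ccv_nnc_no_hint, 0,
    TENSOR_LIST(x, w, bias), TENSOR_LIST(y), 0);
  ccv_nnc_tensor_free(x);
  ccv_nnc_tensor_free(w);
  ccv_nnc_tensor_free(bias);
  ccv_nnc_tensor_free(y);
}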