Coverage Report

Created: 2022-08-03 23:52
/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/cmd/blas/cpu_sys/_ccv_nnc_gemm_cpu_sys.c

 Line   Count  Source
    1          #include "ccv.h"
    2          #include "ccv_internal.h"
    3          #include "nnc/ccv_nnc.h"
    4          #include "nnc/ccv_nnc_easy.h"
    5          #include "nnc/ccv_nnc_internal.h"
    6          #include "../_ccv_nnc_gemm_cpu_opt.h"
    7          #if HAVE_ACCELERATE_FRAMEWORK
    8          #include <Accelerate/Accelerate.h>
    9          #elif HAVE_CBLAS
   10          #include <cblas.h>
   11          #endif
   12
   13          int _ccv_nnc_gemm_forw_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b)
   14   9.47k  {
   15   9.47k  #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK)
   16   9.47k    assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array
   17   9.47k    int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
   18   9.47k    int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
   19   9.47k    int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc;
   20   9.47k    const static int no_transpose[2] = {};
   21   9.47k    ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->inc : a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
   22   9.47k    ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->inc : w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
   23   9.47k    ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->inc : b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc);
   24   9.47k    assert(a_batch_size == b_batch_size);
   25   9.47k    assert(a_batch_size == b_batch_size || a_batch_size == 1);
   26   9.47k    if (a_batch_size == 1 && b_batch_size > 1)
   27       0      a_batch_inc = 0;
   28   9.47k    assert(w_batch_size == a_batch_size || w_batch_size == 1);
   29   9.47k    if (w_batch_size == 1 && b_batch_size > 1)
   30       2      w_batch_inc = 0;
   31   9.47k    assert(a_rows == b_rows);
   32   9.47k    assert(a_cols == w_rows);
   33   9.47k    assert(w_cols == b_cols);
   34   9.47k    const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a);
   35   9.47k    const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b);
   36   9.47k    if (bias)
   37   6.42k    {
   38   6.42k      float* const ones = (float*)ccmalloc(sizeof(float) * b_rows);
   39   6.42k      int i;
   40   13.0k      for (i = 0; i < b_rows; i++)
   41   6.59k        ones[i] = 1;
   42   6.42k      int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
   43   6.42k      ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->inc : bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
   44   6.42k      assert(bias_batch_size == b_batch_size || bias_batch_size == 1);
   45   6.42k      if (bias_batch_size == 1 && b_batch_size > 1)
   46       1        bias_batch_inc = 0;
   47   6.42k      assert(bias_cols == b_cols);
   48   12.8k      for (i = 0; i < b_batch_size; i++)
   49   6.42k        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, b_cols, b_rows, 1, 1.0, bias->data.f32 + i * bias_batch_inc, bias_rows_inc, ones, 1, 0.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
   50   6.42k      ccfree(ones);
   51   6.42k      const int transa = is_transpose_w ? CblasTrans : CblasNoTrans;
   52   6.42k      const int transb = is_transpose_a ? CblasTrans : CblasNoTrans;
   53   6.42k      const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
   54   6.42k      const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
   55   12.8k      for (i = 0; i < b_batch_size; i++)
   56   6.42k        cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 1.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
   57   6.42k    } else {
   58   3.04k      const int transa = is_transpose_w ? CblasTrans : CblasNoTrans;
   59   3.04k      const int transb = is_transpose_a ? CblasTrans : CblasNoTrans;
   60   3.04k      const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
   61   3.04k      const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
   62   3.04k      int i;
   63   6.09k      for (i = 0; i < b_batch_size; i++)
   64   3.05k        cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 0.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
   65   3.04k    }
   66   9.47k    return CCV_NNC_EXEC_SUCCESS;
   67          #else
   68            return CCV_NNC_EXEC_INVALID;
   69          #endif
   70   9.47k  }
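
The forward kernel above leans on one trick: a row-major matrix reinterpreted as column-major (with the same leading dimension) is its transpose, so the row-major product B = A * W is obtained by asking BLAS for B^T = W^T * A^T, and the bias is broadcast first by a rank-1 sgemm against a vector of ones (the main product then accumulates with beta = 1.0). A minimal standalone sketch of those two calls, with hypothetical 2x3/3x2 shapes and assuming any CBLAS implementation (e.g. OpenBLAS) is linked:

    /* Sketch only (not the library's API): B = A * W + bias on row-major data
     * via column-major BLAS, mirroring the two sgemm calls in the forward pass. */
    #include <stdio.h>
    #include <cblas.h>

    int main(void)
    {
        enum { ROWS = 2, K = 3, COLS = 2 }; /* hypothetical shapes: A is 2x3, W is 3x2 */
        const float a[ROWS * K] = { 1, 2, 3, 4, 5, 6 }; /* row-major A */
        const float w[K * COLS] = { 1, 0, 0, 1, 1, 1 }; /* row-major W */
        const float bias[COLS] = { 10, 20 };
        const float ones[ROWS] = { 1, 1 };
        float b[ROWS * COLS];
        /* Rank-1 broadcast: B^T (COLS x ROWS) = bias (COLS x 1) * ones^T (1 x ROWS),
         * which writes the bias vector into every row of B. */
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, ROWS, 1, 1.0, bias, COLS, ones, 1, 0.0, b, COLS);
        /* Main product accumulates on top (beta = 1.0): B^T += W^T * A^T, where each
         * leading dimension is the row-major row stride (the column count). */
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, ROWS, K, 1.0, w, COLS, a, K, 1.0, b, COLS);
        int i, j;
        for (i = 0; i < ROWS; i++) {
            for (j = 0; j < COLS; j++)
                printf("%5.1f ", b[i * COLS + j]);
            printf("\n"); /* expect: 14 25 / 20 31 */
        }
        return 0;
    }
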
   71
   72          int _ccv_nnc_gemm_back_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags)
   73   9.46k  {
   74   9.46k  #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK)
   75            // inputs: gradient, forw prop input, [w]
   76            // outputs: [output gradient], weight updates, bias updates
   77   9.46k    assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 2-d or 3-d array.
   78   9.46k    int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc;
   79   9.46k    const static int no_transpose[2] = {};
   80   9.46k    ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->inc : g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc);
   81   9.46k    int i;
   82   9.46k    if (bias)
   83   6.45k    {
   84   6.45k      int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
   85   6.45k      ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->inc : bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
   86   6.45k      assert(bias_cols == g_cols);
   87   6.45k      assert(bias_batch_size == 1 || bias_batch_size == g_batch_size);
   88   6.45k      if (bias_batch_size == 1 && g_batch_size > 1)
   89       3        bias_batch_inc = 0;
   90   6.45k      float* const ones = (float*)ccmalloc(sizeof(float) * g_rows);
   91   13.0k      for (i = 0; i < g_rows; i++)
   92   6.62k        ones[i] = 1;
   93   6.45k      if (g_batch_size > 1 && bias_batch_size == g_batch_size)
   94       2      {
   95       6        for (i = 0; i < g_batch_size; i++)
   96       4          cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 0.0, bias->data.f32 + i * bias_batch_inc, bias_rows_inc);
   97   6.45k      } else {
   98   6.45k        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32, g_rows_inc, ones, g_rows, 0.0, bias->data.f32, bias_rows_inc);
   99                // We cannot use strided batched alternative because on write, the data could race to the same position
  100   6.45k        for (i = 1; i < g_batch_size; i++)
  101       3          cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 1.0, bias->data.f32, bias_rows_inc);
  102   6.45k      }
  103   6.45k      ccfree(ones);
  104   6.45k    }
  105   9.46k    if (dw)
  106   9.46k    {
  107   9.46k      const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a);
  108   9.46k      const int is_transpose_w = ccv_nnc_is_matrix_transpose(dw->info, transpose_b);
  109   9.46k      int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
  110   9.46k      int dw_batch_size, dw_rows, dw_cols, dw_batch_inc, dw_rows_inc, dw_cols_inc;
  111   9.46k      ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->inc : a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
  112   9.46k      ccv_nnc_tensor_get_matrix_params(dw->info, CCV_IS_TENSOR_VIEW(dw) ? dw->inc : dw->info.dim, transpose_b, &dw_batch_size, &dw_rows, &dw_cols, &dw_batch_inc, &dw_rows_inc, &dw_cols_inc);
  113   9.46k      assert(a_rows == g_rows);
  114   9.46k      assert(a_cols == dw_rows);
  115   9.46k      assert(dw_cols == g_cols);
  116   9.46k      assert(a_batch_size == g_batch_size || a_batch_size == 1);
  117   9.46k      if (a_batch_size == 1 && g_batch_size > 1)
  118       0        a_batch_inc = 0;
  119   9.46k      assert(dw_batch_size == g_batch_size || dw_batch_size == 1);
  120   9.46k      if (dw_batch_size == 1 && g_batch_size > 1)
  121       3        dw_batch_inc = 0;
  122   9.46k      if (g_batch_size > 1 && g_batch_size == dw_batch_size)
  123       2      {
  124       2        if (is_transpose_w)
  125       1        {
  126       1          const int transa = is_transpose_a ? CblasTrans : CblasNoTrans;
  127       1          const int lda_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  128       3          for (i = 0; i < g_batch_size; i++)
  129       2            cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_cols_inc);
  130       1        } else {
  131       1          const int transb = is_transpose_a ? CblasNoTrans : CblasTrans;
  132       1          const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  133       3          for (i = 0; i < g_batch_size; i++)
  134       2            cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_rows_inc);
  135       1        }
  136   9.46k      } else {
  137   9.46k        if (is_transpose_w)
  138   8.45k        {
  139   8.45k          const int transa = is_transpose_a ? CblasTrans : CblasNoTrans;
  140   8.45k          const int lda_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  141   8.45k          cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, dw->data.f32, dw_cols_inc);
  142   8.45k          for (i = 1; i < g_batch_size; i++)
  143       1            cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, dw->data.f32, dw_cols_inc);
  144   8.45k        } else {
  145   1.01k          const int transb = is_transpose_a ? CblasNoTrans : CblasTrans;
  146   1.01k          const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  147   1.01k          cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32, g_rows_inc, a->data.f32, ldb_inc, 0.0, dw->data.f32, dw_rows_inc);
  148   1.01k          for (i = 1; i < g_batch_size; i++)
  149       2            cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 1.0, dw->data.f32, dw_rows_inc);
  150   1.01k        }
  151   9.46k      }
  152   9.46k    }
  153   9.46k    if (h)
  154   9.46k    {
  155   9.46k      const int is_transpose_h = ccv_nnc_is_matrix_transpose(h->info, transpose_a);
  156   9.46k      const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b);
  157   9.46k      int h_batch_size, h_rows, h_cols, h_batch_inc, h_rows_inc, h_cols_inc;
  158   9.46k      int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
  159   9.46k      ccv_nnc_tensor_get_matrix_params(h->info, CCV_IS_TENSOR_VIEW(h) ? h->inc : h->info.dim, transpose_a, &h_batch_size, &h_rows, &h_cols, &h_batch_inc, &h_rows_inc, &h_cols_inc);
  160   9.46k      ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->inc : w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
  161   9.46k      assert(h_rows == g_rows);
  162   9.46k      assert(h_cols == w_rows);
  163   9.46k      assert(w_cols == g_cols);
  164   9.46k      assert(h_batch_size == g_batch_size || h_batch_size == 1);
  165   9.46k      if (h_batch_size == 1 && g_batch_size > 1)
  166       0        h_batch_inc = 0;
  167   9.46k      assert(w_batch_size == g_batch_size || w_batch_size == 1);
  168   9.46k      if (w_batch_size == 1 && g_batch_size > 1)
  169       3        w_batch_inc = 0;
  170   9.46k      if (g_batch_size > 1 && g_batch_size == h_batch_size)
  171       5      {
  172       5        if (is_transpose_h)
  173       2        {
  174       2          const int transb = is_transpose_w ? CblasTrans : CblasNoTrans;
  175       2          const int ldb_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  176       6          for (i = 0; i < g_batch_size; i++)
  177       4            cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 0.0, h->data.f32 + i * h_batch_inc, h_cols_inc);
  178       3        } else {
  179       3          const int transa = is_transpose_w ? CblasNoTrans : CblasTrans;
  180       3          const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  181       9          for (i = 0; i < g_batch_size; i++)
  182       6            cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, h->data.f32 + i * h_batch_inc, h_rows_inc);
  183       3        }
  184   9.46k      } else {
  185   9.46k        if (is_transpose_h)
  186       2        {
  187       2          const int transb = is_transpose_w ? CblasTrans : CblasNoTrans;
  188       2          const int ldb_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  189       2          cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32, g_rows_inc, w->data.f32, ldb_inc, 0.0, h->data.f32, h_cols_inc);
  190       2          for (i = 1; i < g_batch_size; i++)
  191       0            cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 1.0, h->data.f32, h_cols_inc);
  192   9.45k        } else {
  193   9.45k          const int transa = is_transpose_w ? CblasNoTrans : CblasTrans;
  194   9.45k          const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  195   9.45k          cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, h->data.f32, h_rows_inc);
  196   9.45k          for (i = 1; i < g_batch_size; i++)
  197       0            cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, h->data.f32, h_rows_inc);
  198   9.45k        }
  199   9.46k      }
  200   9.46k    }
  201   9.46k    return CCV_NNC_EXEC_SUCCESS;
  202          #else
  203            return CCV_NNC_EXEC_INVALID;
  204          #endif
  205   9.46k  }
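
For reference, the backward kernel computes three reductions from the output gradient G under the same row-major-as-column-major reinterpretation: the bias gradient as the column sum ones^T * G (again an sgemm against a vector of ones, which is also why the code cannot use a strided batched call when every batch accumulates into one bias buffer), the weight gradient dW = A^T * G, and the input gradient h = G * W^T. A small sketch checking the first two identities, with hypothetical 2x2 shapes and the same CBLAS assumption as above:

    /* Sketch only: verifies dbias = ones^T * G and dW = A^T * G on tiny
     * hypothetical inputs; not the library's API, just the math it implements. */
    #include <stdio.h>
    #include <cblas.h>

    int main(void)
    {
        enum { ROWS = 2, K = 2, COLS = 2 }; /* A is ROWS x K, G is ROWS x COLS, row-major */
        const float a[ROWS * K] = { 1, 2, 3, 4 };
        const float g[ROWS * COLS] = { 1, 2, 3, 4 };
        const float ones[ROWS] = { 1, 1 };
        float dbias[COLS], dw[K * COLS];
        /* dbias (COLS x 1) = G^T (COLS x ROWS) * ones (ROWS x 1): column sums of G. */
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, 1, ROWS, 1.0, g, COLS, ones, ROWS, 0.0, dbias, COLS);
        /* dW^T (COLS x K) = G^T (COLS x ROWS) * A (ROWS x K), i.e. dW = A^T * G;
         * CblasTrans on a undoes the implicit transpose of its row-major view. */
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, COLS, K, ROWS, 1.0, g, COLS, a, K, 0.0, dw, COLS);
        printf("dbias: %.0f %.0f\n", dbias[0], dbias[1]); /* expect: 4 6 */
        printf("dw: %.0f %.0f / %.0f %.0f\n", dw[0], dw[1], dw[2], dw[3]); /* expect: 10 14 / 14 20 */
        return 0;
    }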