Coverage Report

Created: 2021-04-05 01:08

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/cmd/blas/cpu_sys/_ccv_nnc_gemm_cpu_sys.c
Line  Count  Source
   1         #include "ccv.h"
   2         #include "ccv_internal.h"
   3         #include "nnc/ccv_nnc.h"
   4         #include "nnc/ccv_nnc_easy.h"
   5         #include "nnc/ccv_nnc_internal.h"
   6         #include "../_ccv_nnc_gemm_cpu_opt.h"
   7         #if HAVE_ACCELERATE_FRAMEWORK
   8         #include <Accelerate/Accelerate.h>
   9         #elif HAVE_CBLAS
  10         #include <cblas.h>
  11         #endif
  12
  13         int _ccv_nnc_gemm_forw_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, const ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const b)
  14  7.27k  {
  15  7.27k  #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK)
  16  7.27k    assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array
  17  7.27k    int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
  18  7.27k    int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
  19  7.27k    int b_batch_size, b_rows, b_cols, b_batch_inc, b_rows_inc, b_cols_inc;
  20  7.27k    const static int no_transpose[2] = {};
  21  7.27k    ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->inc : a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
  22  7.27k    ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->inc : w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
  23  7.27k    ccv_nnc_tensor_get_matrix_params(b->info, CCV_IS_TENSOR_VIEW(b) ? b->inc : b->info.dim, no_transpose, &b_batch_size, &b_rows, &b_cols, &b_batch_inc, &b_rows_inc, &b_cols_inc);
  24  7.27k    assert(a_batch_size == b_batch_size);
  25  7.27k    assert(a_batch_size == b_batch_size || a_batch_size == 1);
  26  7.27k    if (a_batch_size == 1 && b_batch_size > 1)
  27      0      a_batch_inc = 0;
  28  7.27k    assert(w_batch_size == a_batch_size || w_batch_size == 1);
  29  7.27k    if (w_batch_size == 1 && b_batch_size > 1)
  30      2      w_batch_inc = 0;
  31  7.27k    assert(a_rows == b_rows);
  32  7.27k    assert(a_cols == w_rows);
  33  7.27k    assert(w_cols == b_cols);
  34  7.27k    const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a);
  35  7.27k    const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b);
  36  7.27k    if (bias)
  37  6.22k    {
  38  6.22k      float* const ones = (float*)ccmalloc(sizeof(float) * b_rows);
  39  6.22k      int i;
  40  12.5k      for (i = 0; i < b_rows; i++)
  41  6.29k        ones[i] = 1;
  42  6.22k      int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
  43  6.22k      ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->inc : bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
  44  6.22k      assert(bias_batch_size == b_batch_size || bias_batch_size == 1);
  45  6.22k      if (bias_batch_size == 1 && b_batch_size > 1)
  46      1        bias_batch_inc = 0;
  47  6.22k      assert(bias_cols == b_cols);
  48  12.4k      for (i = 0; i < b_batch_size; i++)
  49  6.22k        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, b_cols, b_rows, 1, 1.0, bias->data.f32 + i * bias_batch_inc, bias_rows_inc, ones, 1, 0.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
  50  6.22k      ccfree(ones);
  51  6.22k      const int transa = is_transpose_w ? CblasTrans : CblasNoTrans;
  52  6.22k      const int transb = is_transpose_a ? CblasTrans : CblasNoTrans;
  53  6.22k      const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  54  6.22k      const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  55  12.4k      for (i = 0; i < b_batch_size; i++)
  56  6.22k        cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 1.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
  57  6.22k    } else {
  58  1.04k      const int transa = is_transpose_w ? CblasTrans : CblasNoTrans;
  59  1.04k      const int transb = is_transpose_a ? CblasTrans : CblasNoTrans;
  60  1.04k      const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
  61  1.04k      const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
  62  1.04k      int i;
  63  2.09k      for (i = 0; i < b_batch_size; i++)
  64  1.05k        cblas_sgemm(CblasColMajor, transa, transb, b_cols, b_rows, a_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 0.0, b->data.f32 + i * b_batch_inc, b_rows_inc);
  65  1.04k    }
  66  7.27k    return CCV_NNC_EXEC_SUCCESS;
  67         #else
  68           return CCV_NNC_EXEC_INVALID;
  69         #endif
  70  7.27k  }
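
Note on the forward kernel: ccv tensors are row-major, yet every cblas_sgemm call above passes CblasColMajor. An m-by-n row-major matrix occupies the same memory as an n-by-m column-major matrix, so b = a * w is computed as the column-major product b^T = w^T * a^T; that is why the calls pass b_cols before b_rows and hand w in as the left operand. The bias path (line 49) is a rank-1 trick: the 1-by-b_cols bias times a b_rows-long vector of ones writes the bias into every row of b, and the main product at line 56 then accumulates on top with beta = 1.0 (the no-bias branch at line 64 uses beta = 0.0 instead). Forcing a batch increment to 0 (lines 27 and 30) simply makes every batch iteration reuse the same matrix, i.e. a broadcast. The standalone sketch below replays both calls with fixed sizes; the sizes, values, and main() harness are invented for illustration, and it assumes a CBLAS implementation is linked (e.g. OpenBLAS; on macOS the Accelerate framework exports the same interface).

  // gemm_forw_note.c -- illustrative sketch only; build with something like: cc gemm_forw_note.c -lcblas
  #include <cblas.h>
  #include <stdio.h>

  int main(void)
  {
    enum { ROWS = 2, COLS = 3, INNER = 4 }; // b: ROWS x COLS, a: ROWS x INNER, w: INNER x COLS
    const float a[ROWS * INNER] = { 1, 2, 3, 4,  5, 6, 7, 8 }; // row-major
    const float w[INNER * COLS] = { 1, 0, 0,  0, 1, 0,  0, 0, 1,  1, 1, 1 }; // row-major
    const float bias[COLS] = { 0.5f, -1.f, 2.f };
    const float ones[ROWS] = { 1, 1 };
    float b[ROWS * COLS];
    // Broadcast bias into every row of b: the same rank-1 sgemm shape as line 49,
    // a (COLS x 1) matrix times a (1 x ROWS) matrix of ones, written column-major.
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, ROWS, 1, 1.0f, bias, COLS, ones, 1, 0.0f, b, COLS);
    // b += a * w, computed column-major as b^T = w^T * a^T (the shape of line 56);
    // beta = 1.0 keeps the broadcast bias in place.
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, ROWS, INNER, 1.0f, w, COLS, a, INNER, 1.0f, b, COLS);
    int i, j;
    for (i = 0; i < ROWS; i++)
    {
      for (j = 0; j < COLS; j++)
        printf("%6.2f", b[i * COLS + j]); // expect 5.50 5.00 9.00 / 13.50 13.00 17.00
      printf("\n");
    }
    return 0;
  }
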
  71
  72         int _ccv_nnc_gemm_back_cpu_sys(const int transpose_a[2], const int transpose_b[2], const ccv_nnc_tensor_view_t* const g, const ccv_nnc_tensor_view_t* const a, const ccv_nnc_tensor_view_t* const w, ccv_nnc_tensor_view_t* const dw, ccv_nnc_tensor_view_t* const bias, ccv_nnc_tensor_view_t* const h, const int flags)
  73  7.26k  {
  74  7.26k  #if (defined HAVE_CBLAS || defined HAVE_ACCELERATE_FRAMEWORK)
  75  7.26k    // inputs: gradient, forw prop input, [w]
  76  7.26k    // outputs: [output gradient], weight updates, bias updates
  77  7.26k    assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array
  78  7.26k    int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc;
  79  7.26k    const static int no_transpose[2] = {};
  80  7.26k    ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->inc : g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc);
  81  7.26k    int i;
  82  7.26k    if (bias)
  83  6.25k    {
  84  6.25k      int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
  85  6.25k      ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->inc : bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
  86  6.25k      assert(bias_cols == g_cols);
  87  6.25k      assert(bias_batch_size == 1 || bias_batch_size == g_batch_size);
  88  6.25k      if (bias_batch_size == 1 && g_batch_size > 1)
  89      3        bias_batch_inc = 0;
  90  6.25k      float* const ones = (float*)ccmalloc(sizeof(float) * g_rows);
  91  12.5k      for (i = 0; i < g_rows; i++)
  92  6.32k        ones[i] = 1;
  93  6.25k      if (g_batch_size > 1 && bias_batch_size == g_batch_size)
  94      2      {
  95      6        for (i = 0; i < g_batch_size; i++)
  96      4          cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 0.0, bias->data.f32 + i * bias_batch_inc, bias_rows_inc);
  97  6.25k      } else {
  98  6.25k        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32, g_rows_inc, ones, g_rows, 0.0, bias->data.f32, bias_rows_inc);
  99  6.25k        // We cannot use the strided batched alternative because, on write, the data could race to the same position
 100  6.25k        for (i = 1; i < g_batch_size; i++)
 101      3          cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, bias_cols, bias_rows, g_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, ones, g_rows, 1.0, bias->data.f32, bias_rows_inc);
 102  6.25k      }
 103  6.25k      ccfree(ones);
 104  6.25k    }
 105  7.26k    if (dw)
 106  7.26k    {
 107  7.26k      const int is_transpose_a = ccv_nnc_is_matrix_transpose(a->info, transpose_a);
 108  7.26k      const int is_transpose_w = ccv_nnc_is_matrix_transpose(dw->info, transpose_b);
 109  7.26k      int a_batch_size, a_rows, a_cols, a_batch_inc, a_rows_inc, a_cols_inc;
 110  7.26k      int dw_batch_size, dw_rows, dw_cols, dw_batch_inc, dw_rows_inc, dw_cols_inc;
 111  7.26k      ccv_nnc_tensor_get_matrix_params(a->info, CCV_IS_TENSOR_VIEW(a) ? a->inc : a->info.dim, transpose_a, &a_batch_size, &a_rows, &a_cols, &a_batch_inc, &a_rows_inc, &a_cols_inc);
 112  7.26k      ccv_nnc_tensor_get_matrix_params(dw->info, CCV_IS_TENSOR_VIEW(dw) ? dw->inc : dw->info.dim, transpose_b, &dw_batch_size, &dw_rows, &dw_cols, &dw_batch_inc, &dw_rows_inc, &dw_cols_inc);
 113  7.26k      assert(a_rows == g_rows);
 114  7.26k      assert(a_cols == dw_rows);
 115  7.26k      assert(dw_cols == g_cols);
 116  7.26k      assert(a_batch_size == g_batch_size || a_batch_size == 1);
 117  7.26k      if (a_batch_size == 1 && g_batch_size > 1)
 118      0        a_batch_inc = 0;
 119  7.26k      assert(dw_batch_size == g_batch_size || dw_batch_size == 1);
 120  7.26k      if (dw_batch_size == 1 && g_batch_size > 1)
 121      3        dw_batch_inc = 0;
 122  7.26k      if (g_batch_size > 1 && g_batch_size == dw_batch_size)
 123      2      {
 124      2        if (is_transpose_w)
 125      1        {
 126      1          const int transa = is_transpose_a ? CblasTrans : CblasNoTrans;
 127      1          const int lda_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
 128      3          for (i = 0; i < g_batch_size; i++)
 129      2            cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_cols_inc);
 130      1        } else {
 131      1          const int transb = is_transpose_a ? CblasNoTrans : CblasTrans;
 132      1          const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
 133      3          for (i = 0; i < g_batch_size; i++)
 134      2            cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 0.0, dw->data.f32 + i * dw_batch_inc, dw_rows_inc);
 135      1        }
 136  7.26k      } else {
 137  7.26k        if (is_transpose_w)
 138  6.25k        {
 139  6.25k          const int transa = is_transpose_a ? CblasTrans : CblasNoTrans;
 140  6.25k          const int lda_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
 141  6.25k          cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, dw->data.f32, dw_cols_inc);
 142  6.25k          for (i = 1; i < g_batch_size; i++)
 143      1            cblas_sgemm(CblasColMajor, transa, CblasTrans, dw_rows, dw_cols, a_rows, 1.0, a->data.f32 + i * a_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, dw->data.f32, dw_cols_inc);
 144  6.25k        } else {
 145  1.01k          const int transb = is_transpose_a ? CblasNoTrans : CblasTrans;
 146  1.01k          const int ldb_inc = is_transpose_a ? a_cols_inc : a_rows_inc;
 147  1.01k          cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32, g_rows_inc, a->data.f32, ldb_inc, 0.0, dw->data.f32, dw_rows_inc);
 148  1.01k          for (i = 1; i < g_batch_size; i++)
 149      2            cblas_sgemm(CblasColMajor, CblasNoTrans, transb, dw_cols, dw_rows, a_rows, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, a->data.f32 + i * a_batch_inc, ldb_inc, 1.0, dw->data.f32, dw_rows_inc);
 150  1.01k        }
 151  7.26k      }
 152  7.26k    }
 153  7.26k    if (h)
 154  7.26k    {
 155  7.26k      const int is_transpose_h = ccv_nnc_is_matrix_transpose(h->info, transpose_a);
 156  7.26k      const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, transpose_b);
 157  7.26k      int h_batch_size, h_rows, h_cols, h_batch_inc, h_rows_inc, h_cols_inc;
 158  7.26k      int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
 159  7.26k      ccv_nnc_tensor_get_matrix_params(h->info, CCV_IS_TENSOR_VIEW(h) ? h->inc : h->info.dim, transpose_a, &h_batch_size, &h_rows, &h_cols, &h_batch_inc, &h_rows_inc, &h_cols_inc);
 160  7.26k      ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->inc : w->info.dim, transpose_b, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
 161  7.26k      assert(h_rows == g_rows);
 162  7.26k      assert(h_cols == w_rows);
 163  7.26k      assert(w_cols == g_cols);
 164  7.26k      assert(h_batch_size == g_batch_size || h_batch_size == 1);
 165  7.26k      if (h_batch_size == 1 && g_batch_size > 1)
 166      0        h_batch_inc = 0;
 167  7.26k      assert(w_batch_size == g_batch_size || w_batch_size == 1);
 168  7.26k      if (w_batch_size == 1 && g_batch_size > 1)
 169      3        w_batch_inc = 0;
 170  7.26k      if (g_batch_size > 1 && g_batch_size == h_batch_size)
 171      5      {
 172      5        if (is_transpose_h)
 173      2        {
 174      2          const int transb = is_transpose_w ? CblasTrans : CblasNoTrans;
 175      2          const int ldb_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
 176      6          for (i = 0; i < g_batch_size; i++)
 177      4            cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 0.0, h->data.f32 + i * h_batch_inc, h_cols_inc);
 178      3        } else {
 179      3          const int transa = is_transpose_w ? CblasNoTrans : CblasTrans;
 180      3          const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
 181      9          for (i = 0; i < g_batch_size; i++)
 182      6            cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 0.0, h->data.f32 + i * h_batch_inc, h_rows_inc);
 183      3        }
 184  7.26k      } else {
 185  7.26k        if (is_transpose_h)
 186      2        {
 187      2          const int transb = is_transpose_w ? CblasTrans : CblasNoTrans;
 188      2          const int ldb_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
 189      2          cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32, g_rows_inc, w->data.f32, ldb_inc, 0.0, h->data.f32, h_cols_inc);
 190      2          for (i = 1; i < g_batch_size; i++)
 191      0            cblas_sgemm(CblasColMajor, CblasTrans, transb, h_rows, h_cols, g_cols, 1.0, g->data.f32 + i * g_batch_inc, g_rows_inc, w->data.f32 + i * w_batch_inc, ldb_inc, 1.0, h->data.f32, h_cols_inc);
 192  7.26k        } else {
 193  7.26k          const int transa = is_transpose_w ? CblasNoTrans : CblasTrans;
 194  7.26k          const int lda_inc = is_transpose_w ? w_cols_inc : w_rows_inc;
 195  7.26k          cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32, lda_inc, g->data.f32, g_rows_inc, 0.0, h->data.f32, h_rows_inc);
 196  7.26k          for (i = 1; i < g_batch_size; i++)
 197      0            cblas_sgemm(CblasColMajor, transa, CblasNoTrans, h_cols, h_rows, g_cols, 1.0, w->data.f32 + i * w_batch_inc, lda_inc, g->data.f32 + i * g_batch_inc, g_rows_inc, 1.0, h->data.f32, h_rows_inc);
 198  7.26k        }
 199  7.26k      }
 200  7.26k    }
 201  7.26k    return CCV_NNC_EXEC_SUCCESS;
 202         #else
 203           return CCV_NNC_EXEC_INVALID;
 204         #endif
 205  7.26k  }
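
Note on the backward kernel: the bias gradient is the sum of g over its rows (and over the batch when a single bias is shared), computed with the same ones-vector trick as the forward pass: g^T times a ones vector of length g_rows yields the column sums. When several batches reduce into one bias, the first sgemm (line 98) overwrites with beta = 0.0 and the remaining batches (line 101) accumulate serially with beta = 1.0; that serialization is what the comment at line 99 is about, since a strided batched call would have every batch writing the same output positions concurrently. The dw and h branches are the standard GEMM gradients, dw = a^T * g and h = g * w^T, with the transpose flags and leading dimensions permuted per layout combination. The sketch below replays the bias reduction with fixed sizes; the shapes, values, and main() harness are invented for illustration, and a linked CBLAS implementation is assumed.

  // gemm_back_note.c -- illustrative sketch only; build with something like: cc gemm_back_note.c -lcblas
  #include <cblas.h>
  #include <stdio.h>

  int main(void)
  {
    enum { BATCH = 2, ROWS = 3, COLS = 2 }; // g: BATCH x ROWS x COLS, row-major
    const float g[BATCH * ROWS * COLS] = {
      1, 2,  3, 4,  5, 6,    // batch 0
      7, 8,  9, 10,  11, 12, // batch 1
    };
    const float ones[ROWS] = { 1, 1, 1 };
    float bias[COLS];
    int i;
    // Column sums of batch 0: g^T (a COLS x ROWS column-major view of row-major g)
    // times ones (ROWS x 1) -- the sgemm shape of line 98; beta = 0.0 overwrites bias.
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, 1, ROWS, 1.0f, g, COLS, ones, ROWS, 0.0f, bias, COLS);
    // Remaining batches accumulate serially with beta = 1.0 (line 101); a strided
    // batched call would race, because every batch writes the same bias positions.
    for (i = 1; i < BATCH; i++)
      cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, COLS, 1, ROWS, 1.0f, g + i * ROWS * COLS, COLS, ones, ROWS, 1.0f, bias, COLS);
    printf("bias gradient: %.0f %.0f\n", bias[0], bias[1]); // expect 36 42
    return 0;
  }

It should print the same totals as a plain double loop over g, which is an easy way to sanity-check the reduction.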