Coverage Report

Created: 2024-06-21 10:32

/home/liu/actions-runner/_work/ccv/ccv/lib/nnc/cmd/upsample/ccv_nnc_upsample.c
Line
Count
Source (jump to first uncovered line)
1
#include "ccv.h"
2
#include "nnc/ccv_nnc.h"
3
#include "nnc/ccv_nnc_internal.h"
4
#include "nnc/ccv_nnc_easy.h"
5
6
/*
 * Shape inference for the upsample forward command.
 *
 * The single output starts as a copy of the single input's parameters. Then
 * the height and width axes are scaled by cmd.upsample.height_scale and
 * cmd.upsample.width_scale and rounded to nearest (add 0.5, then truncate).
 * The height axis depends on rank and format, and width is always the next axis:
 *   - rank 2: height = axis 0;
 *   - rank 3: height = axis 1 for NCHW / CHWN, axis 0 otherwise;
 *   - rank 4: height = axis 2 for NCHW, axis 1 otherwise.
 * Any other rank leaves the copied shape unchanged.
 */
static void _ccv_nnc_upsample_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size)
{
  assert(input_size == 1);
  assert(output_size == 1);
  const ccv_nnc_tensor_param_t* const src = &inputs[0];
  ccv_nnc_tensor_param_t* const dst = &outputs[0];
  *dst = *src;
  const int rank = ccv_nnc_tensor_nd(src->dim);
  const int is_nchw = (src->format == CCV_TENSOR_FORMAT_NCHW);
  int h_axis;
  switch (rank)
  {
    case 2:
      h_axis = 0;
      break;
    case 3:
      h_axis = (is_nchw || src->format == CCV_TENSOR_FORMAT_CHWN) ? 1 : 0;
      break;
    case 4:
      h_axis = is_nchw ? 2 : 1;
      break;
    default:
      // Unsupported rank: keep the shape copied from the input.
      return;
  }
  // In every supported layout the width axis directly follows the height axis.
  const int w_axis = h_axis + 1;
  dst->dim[h_axis] = (int)(src->dim[h_axis] * cmd.upsample.height_scale + 0.5);
  dst->dim[w_axis] = (int)(src->dim[w_axis] * cmd.upsample.width_scale + 0.5);
}
36
37
/*
 * Accepted input/output combination for upsample forward.
 *
 * Returns 1 only when the input bitmask is exactly 1 (input 0 present, nothing
 * else) and the output bitmask is exactly 1 (output 0 present, nothing else).
 * Returns 0 otherwise. The output mask is only read when the input mask matches.
 */
static int _ccv_nnc_upsample_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
{
  if (input_bitmasks[0] != 1u)
    return 0;
  return output_bitmasks[0] == 1u;
}
43
44
/*
 * Accepted input/output combination for upsample backward.
 *
 * Requires bit 0 of the input mask to be set; input 0 is presumably the incoming
 * gradient, and any other input bits are ignored. The output mask must be exactly
 * 1, i.e. only output 0 (the propagated error) is produced.
 * Returns 1 when accepted, 0 otherwise.
 */
static int _ccv_nnc_upsample_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
{
  const int has_input0 = (input_bitmasks[0] & 1u) != 0;
  if (!has_input0)
    return 0;
  // Output the propagated error.
  return output_bitmasks[0] == 1u;
}
51
52
REGISTER_COMMAND(CCV_NNC_UPSAMPLE_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
  FIND_BACKEND(ccv_nnc_upsample_cpu_ref.c, gpu/ccv_nnc_upsample_gpu_ref.cu, mps/ccv_nnc_upsample_mps.m)
{
  // Output shape: the input with its height/width axes scaled.
  registry->tensor_auto = _ccv_nnc_upsample_tensor_auto_forw;
  // Accepted input/output presence combinations.
  registry->bitmask = _ccv_nnc_upsample_forw_bitmask;
}
58
59
REGISTER_COMMAND(CCV_NNC_UPSAMPLE_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
  FIND_BACKEND(ccv_nnc_upsample_cpu_ref.c, gpu/ccv_nnc_upsample_gpu_ref.cu, mps/ccv_nnc_upsample_mps.m)
{
  // NOTE(review): the generic helper presumably copies the incoming gradient's
  // shape to the output. For upsample, the gradient w.r.t. the input is smaller
  // by the scale factors; confirm callers never rely on this inferred shape.
  registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_gradient;
  // Accepted input/output presence combinations.
  registry->bitmask = _ccv_nnc_upsample_back_bitmask;
}
65
66
// Convenience constructors for the upsample commands. Both build a command
// whose cmd.upsample fields are set from the arguments:
//   _type          -> .type
//   _width_scale   -> .width_scale
//   _height_scale  -> .height_scale
//   _align_corners -> .align_corners
// .size.dim is fixed to {1,1,1}.
// The //@REGISTER_EASY_COMMAND_MACRO markers are presumably read by a code
// generator; keep each marker directly above its single-line #define.
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_UPSAMPLE_FORWARD)
#define CMD_UPSAMPLE_FORWARD(_type, _width_scale, _height_scale, _align_corners) ccv_nnc_cmd(CCV_NNC_UPSAMPLE_FORWARD, 0, ((ccv_nnc_cmd_param_t){ .size = { .dim = { 1, 1, 1 } }, .upsample = { .type = _type, .width_scale = _width_scale, .height_scale = _height_scale, .align_corners = _align_corners } }), 0)
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_UPSAMPLE_BACKWARD)
#define CMD_UPSAMPLE_BACKWARD(_type, _width_scale, _height_scale, _align_corners) ccv_nnc_cmd(CCV_NNC_UPSAMPLE_BACKWARD, 0, ((ccv_nnc_cmd_param_t){ .size = { .dim = { 1, 1, 1 } }, .upsample = { .type = _type, .width_scale = _width_scale, .height_scale = _height_scale, .align_corners = _align_corners } }), 0)