Coverage Report

Created: 2021-04-07 21:56

/home/liu/buildslave/linux-x64-runtests/build/lib/nnc/ccv_nnc_symbolic_graph_while.c
Line
Count
Source
1
#include "ccv_nnc.h"
2
#include "ccv_nnc_easy.h"
3
#include "ccv_nnc_internal.h"
4
#include "ccv_internal.h"
5
#include "_ccv_nnc_symbolic_graph.h"
6
7
// MARK - Level-3.5 API
8
9
// Create an exec symbol in `graph` that stands for running `while_graph` as a loop.
//
// `cmd` must be CCV_NNC_GRAPH_FORWARD or CCV_NNC_GRAPH_BACKWARD, and `while_graph` must not
// already be attached to a parent (both its `p` and `p_idx` must still be zero).
// On return, `while_graph` has been appended to graph->sub_graphs and points back at `graph`
// through `p`, `p_idx` (1-based position in sub_graphs) and `exec_idx` (symbol.d + 1).
// Returns the newly created exec symbol.
ccv_nnc_graph_exec_symbol_t ccv_nnc_symbolic_graph_while(ccv_nnc_symbolic_graph_t* const graph, const uint32_t cmd, ccv_nnc_symbolic_graph_t* const while_graph, const char* const name)
{
  assert(cmd == CCV_NNC_GRAPH_FORWARD || cmd == CCV_NNC_GRAPH_BACKWARD);
  assert(!while_graph->p);
  assert(!while_graph->p_idx);
  // Add one more exec symbol to the parent graph; it carries no inputs / outputs of its own.
  const ccv_nnc_graph_exec_symbol_t while_symbol = ccv_nnc_graph_exec_symbol_new(graph, ccv_nnc_cmd(cmd, 0, CMD_GENERIC(), 0), 0, 0, 0, 0, name);
  // Register the while graph as a sub-graph. The array stores pointers to the graphs rather
  // than the graphs themselves, so the while graph's address remains valid when the array
  // expands later (e.g. when another while graph is added).
  if (graph->sub_graphs == 0)
    graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0);
  ccv_array_push(graph->sub_graphs, &while_graph);
  // The reference is 1-based: after the push, rnum is the position of the new entry plus one.
  const int sub_graph_ref = graph->sub_graphs->rnum;
  // Mark the exec symbol as a while node and make it refer to the sub-graph.
  ccv_nnc_graph_exec_symbol_info_t* const while_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, while_symbol.d);
  while_info->flags |= CCV_NNC_GRAPH_EXEC_P_WHILE;
  while_info->graph_ref_size = 1;
  CCV_NNC_GRAPH_REF(while_info)[0] = sub_graph_ref;
  // Back-links from the while graph to its parent and to the exec symbol that represents it.
  while_graph->p_idx = sub_graph_ref;
  while_graph->exec_idx = while_symbol.d + 1;
  while_graph->p = graph;
  return while_symbol;
}
32
33
// Attach the loop condition to a while graph.
//
// `while_expr` / `while_data` are stored on the exec symbol (in the parent graph) that
// represents `while_graph`. When `input_size` > 0, each entry of `inputs` is mapped with
// ccv_nnc_tensor_symbol_map_raw against `while_graph` and the resulting ints are stored in a
// freshly allocated array on that exec symbol. When `breakpoint_size` > 0, `breakpoints` is
// copied into a freshly allocated array on `while_graph`. Non-positive sizes leave the
// corresponding fields untouched.
void ccv_nnc_symbolic_graph_set_while_expr(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_graph_while_f while_expr, const void* const while_data, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_graph_exec_symbol_t* const breakpoints, const int breakpoint_size)
{
  // exec_idx is stored 1-based on the while graph; convert it back to an array index.
  const int while_exec_d = while_graph->exec_idx - 1;
  assert(while_exec_d >= 0 && while_exec_d < while_graph->p->exec_symbol_info->rnum);
  ccv_nnc_graph_exec_symbol_info_t* const while_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(while_graph->p->exec_symbol_info, while_exec_d);
  while_info->p_while.data = while_data;
  while_info->p_while.expr = while_expr;
  if (input_size > 0)
  {
    assert(inputs);
    int* const mapped_inputs = (int*)ccmalloc(sizeof(int) * input_size);
    while_info->p_while.inputs = mapped_inputs;
    for (int k = 0; k < input_size; k++)
      mapped_inputs[k] = ccv_nnc_tensor_symbol_map_raw(while_graph, inputs[k]);
    while_info->p_while.input_size = input_size;
  }
  if (breakpoint_size <= 0)
    return;
  assert(breakpoints);
  // Keep a private copy of the breakpoints on the while graph itself.
  const size_t breakpoint_bytes = sizeof(ccv_nnc_graph_exec_symbol_t) * breakpoint_size;
  while_graph->breakpoint_size = breakpoint_size;
  while_graph->breakpoints = (ccv_nnc_graph_exec_symbol_t*)ccmalloc(breakpoint_bytes);
  memcpy(while_graph->breakpoints, breakpoints, breakpoint_bytes);
}
57
58
// Record carry-over pairs on a while graph.
//
// For every entry of `symbol_map`, both the source and the destination are resolved against
// `while_graph` (and must belong to it). The destination's `assign_ref` is set to the source
// index + 1 and the source's `r_assign_ref` is set to the destination index + 1.
void ccv_nnc_symbolic_graph_set_carry_overs(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size)
{
  for (int k = 0; k < symbol_map_size; k++)
  {
    const ccv_nnc_tensor_symbol_map_t* const entry = symbol_map + k;
    const ccv_nnc_tensor_symbol_t from = ccv_nnc_tensor_symbol_resolve(while_graph, entry->source);
    const ccv_nnc_tensor_symbol_t to = ccv_nnc_tensor_symbol_resolve(while_graph, entry->destination);
    assert(from.graph == while_graph);
    assert(to.graph == while_graph);
    assert(from.d < while_graph->tensor_symbol_info->rnum);
    assert(to.d < while_graph->tensor_symbol_info->rnum);
    ccv_nnc_tensor_symbol_info_t* const from_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, from.d);
    ccv_nnc_tensor_symbol_info_t* const to_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, to.d);
    // Aliases cannot be parameterized. To support a parameterized loop (for SSA), the source
    // and the destination simply reuse the same piece of memory (both are allocated to the
    // same memory region, which is what enables parameter passing). That does not work for an
    // alias, because aliases can point to tensors of different sizes, and such tensors cannot
    // share one memory region. The best way to parameterize on an alias is to create a new
    // tensor of the same size, transfer the value over, and parameterize on that tensor instead.
    assert(!to_info->alias_ref);
    assert(!from_info->alias_ref);
    // Both references are stored 1-based.
    to_info->assign_ref = from.d + 1;
    from_info->r_assign_ref = to.d + 1;
  }
}
82
83
ccv_nnc_tensor_symbol_t ccv_nnc_tensor_symbol_for_while_count(const ccv_nnc_symbolic_graph_t* const while_graph)
84
19
{
85
19
  return (ccv_nnc_tensor_symbol_t){
86
19
    .d = CCV_NNC_WHILE_COUNT_TENSOR_SYMBOL,
87
19
    .graph = while_graph
88
19
  };
89
19
}
90
91
ccv_nnc_symbolic_graph_t* ccv_nnc_symbolic_graph_from_while_symbol(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t while_symbol)
92
2
{
93
2
  assert(graph->sub_graphs);
94
2
  assert(while_symbol.graph == graph);
95
2
  assert(while_symbol.d < graph->exec_symbol_info->rnum);
96
2
  ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, while_symbol.d);
97
2
  assert(CCV_NNC_GRAPH_REF(symbol_info)[0] <= graph->sub_graphs->rnum);
98
2
  return *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, CCV_NNC_GRAPH_REF(symbol_info)[0] - 1);
99
2
}