File: | nnc/ccv_cnnp_model_gradient_checkpointing.c |
Warning: | line 48, column 1: Assigned value is garbage or undefined (line 48 is cut off in the listing below, so the flagged statement is not visible here) |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include "ccv_nnc.h" |
2 | #include "ccv_nnc_easy.h" |
3 | #include "ccv_nnc_internal.h" |
4 | #include "ccv_internal.h" |
5 | #include "_ccv_cnnp_model.h" |
6 | // This can be removed once we organized ccv_cnnp_apply_gradient_checkpoints better. |
7 | #include "_ccv_nnc_symbolic_graph.h" |
8 | |
9 | typedef struct { |
10 | ccv_array_t* outgoings; |
11 | } ccv_nnc_graph_exec_symbol_reverse_t; |
12 | |
13 | typedef struct { |
14 | ccv_array_t* tensor_symbols; |
15 | void* old_tensor_symbol_new_hook_context; |
16 | ccv_nnc_tensor_symbol_new_hook_f old_tensor_symbol_new_hook; |
17 | void* old_tensor_symbol_alias_new_hook_context; |
18 | ccv_nnc_tensor_symbol_alias_new_hook_f old_tensor_symbol_alias_new_hook; |
19 | ccv_array_t* graph_exec_symbols; |
20 | ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook; |
21 | void* old_graph_exec_symbol_new_hook_context; |
22 | } ccv_cnnp_gradient_checkpoint_build_t; |
23 | |
24 | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) |
25 | { |
26 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
27 | ccv_array_push(build_context->tensor_symbols, &symbol); |
28 | if (build_context->old_tensor_symbol_new_hook) |
29 | build_context->old_tensor_symbol_new_hook(build_context->old_tensor_symbol_new_hook_context, symbol, info, name); |
30 | } |
31 | |
32 | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC(12)], const int inc[CCV_NNC_MAX_DIM_ALLOC(12)], const ccv_nnc_tensor_param_t info, const char* const name) |
33 | { |
34 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
35 | ccv_array_push(build_context->tensor_symbols, &symbol); |
36 | if (build_context->old_tensor_symbol_alias_new_hook) |
37 | build_context->old_tensor_symbol_alias_new_hook(build_context->old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); |
38 | } |
39 | |
40 | static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) |
41 | { |
42 | ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
43 | ccv_array_push(build->graph_exec_symbols, &symbol); |
44 | if (build->old_graph_exec_symbol_new_hook) |
45 | build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name); |
46 | } |
47 | |
48 |