File: | nnc/ccv_cnnp_model_gradient_checkpointing.c |
Warning: | line 75, column 1 Assigned value is garbage or undefined |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include "ccv_nnc.h" |
2 | #include "ccv_nnc_easy.h" |
3 | #include "ccv_nnc_internal.h" |
4 | #include "ccv_internal.h" |
5 | #include "_ccv_cnnp_model.h" |
6 | // This can be removed once we organized ccv_cnnp_apply_gradient_checkpoints better. |
7 | #include "_ccv_nnc_symbolic_graph.h" |
8 | |
9 | void ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph) |
10 | { |
11 | ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints; |
12 | if (!gradient_checkpoints || gradient_checkpoints->rnum == 0) // No saved gradient checkpoints, this is an easy way out. |
13 | return; |
14 | int i, j; |
15 | const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (const ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0)((void*)(((char*)((graph->tensor_symbol_info)->data)) + (size_t)(graph->tensor_symbol_info)->rsize * (size_t)( 0))); |
16 | // Go through to check if any tensors that supposes in this map is removed. |
17 | for (i = 0; i < gradient_checkpoints->rnum; i++) |
18 | { |
19 | ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i)((void*)(((char*)((gradient_checkpoints)->data)) + (size_t )(gradient_checkpoints)->rsize * (size_t)(i))); |
20 | for (j = 0; j < checkpoint->tensor_symbols->rnum; j++) |
21 | { |
22 | ccv_nnc_tensor_symbol_t* const symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j)((void*)(((char*)((checkpoint->tensor_symbols)->data)) + (size_t)(checkpoint->tensor_symbols)->rsize * (size_t) (j)))); |
23 | if (symbol->d >= 0 && symbol->d < graph->tensor_symbol_info->rnum) |
24 | // If it is dead, we need to remove this symbol. |
25 | if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(tensor_symbol_info[symbol->d].flags)((tensor_symbol_info[symbol->d].flags) & CCV_NNC_TENSOR_SYMBOL_DEAD )) |
26 | { |
27 | symbol->d = -1; |
28 | symbol->graph = 0; |
29 | } |
30 | } |
31 | } |
32 | } |
33 | |
// Per graph-exec-symbol record; judging by its name, it is used when traversing
// the exec graph in reverse.
// NOTE(review): `outgoings` presumably lists the exec symbols reached from this
// one. Its element type is not visible in this section; confirm at the use sites.
typedef struct {
	ccv_array_t* outgoings;
} ccv_nnc_graph_exec_symbol_reverse_t;
37 | |
// State passed as the `context` pointer to the gradient-checkpoint symbol-creation
// hooks (_ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook,
// _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook and
// _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook).
typedef struct {
	// Tensor-hook state. The hooks read its `record` flag, push into its
	// `tensor_symbols` array while `record` is set, and chain to the
	// `old_tensor_symbol_new_hook` / `old_tensor_symbol_alias_new_hook`
	// callbacks (with their contexts) stored here.
	ccv_cnnp_model_gradient_checkpoint_build_context_t tensor_context;
	// Every graph exec symbol reported to the exec-symbol hook.
	ccv_array_t* graph_exec_symbols;
	// Exec-symbol hook to forward to after recording (skipped when NULL), plus its context.
	ccv_nnc_graph_exec_symbol_new_hook_f old_graph_exec_symbol_new_hook;
	void* old_graph_exec_symbol_new_hook_context;
	// Every tensor symbol (plain or alias) reported to the tensor hooks,
	// regardless of tensor_context.record.
	ccv_array_t* all_tensor_symbols;
} ccv_cnnp_gradient_checkpoint_build_t;
45 | |
46 | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_param_t info, const char* const name) |
47 | { |
48 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
49 | if (build_context->tensor_context.record) |
50 | ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); |
51 | ccv_array_push(build_context->all_tensor_symbols, &symbol); |
52 | if (build_context->tensor_context.old_tensor_symbol_new_hook) |
53 | build_context->tensor_context.old_tensor_symbol_new_hook(build_context->tensor_context.old_tensor_symbol_new_hook_context, symbol, info, name); |
54 | } |
55 | |
56 | static void _ccv_cnnp_gradient_checkpoint_tensor_symbol_alias_new_hook(void* context, const ccv_nnc_tensor_symbol_t symbol, const ccv_nnc_tensor_symbol_t from_symbol, const int ofs[CCV_NNC_MAX_DIM_ALLOC(12)], const int inc[CCV_NNC_MAX_DIM_ALLOC(12)], const ccv_nnc_tensor_param_t info, const char* const name) |
57 | { |
58 | ccv_cnnp_gradient_checkpoint_build_t* const build_context = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
59 | if (build_context->tensor_context.record) |
60 | ccv_array_push(build_context->tensor_context.tensor_symbols, &symbol); |
61 | ccv_array_push(build_context->all_tensor_symbols, &symbol); |
62 | if (build_context->tensor_context.old_tensor_symbol_alias_new_hook) |
63 | build_context->tensor_context.old_tensor_symbol_alias_new_hook(build_context->tensor_context.old_tensor_symbol_alias_new_hook_context, symbol, from_symbol, ofs, inc, info, name); |
64 | } |
65 | |
66 | static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name) |
67 | { |
68 | ccv_cnnp_gradient_checkpoint_build_t* const build = (ccv_cnnp_gradient_checkpoint_build_t*)context; |
69 | ccv_array_push(build->graph_exec_symbols, &symbol); |
70 | if (build->old_graph_exec_symbol_new_hook) |
71 | build->old_graph_exec_symbol_new_hook(build->old_graph_exec_symbol_new_hook_context, symbol, cmd, inputs, input_size, outputs, output_size, name); |
72 | } |
73 | |
74 | KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)typedef struct kh_ccv_cnnp_tensor_symbol_map_s { khint_t n_buckets , size, n_occupied, upper_bound; khint32_t *flags; khint32_t * keys; int *vals; } kh_ccv_cnnp_tensor_symbol_map_t; static inline __attribute__ ((__unused__)) kh_ccv_cnnp_tensor_symbol_map_t *kh_init_ccv_cnnp_tensor_symbol_map(void) { return (kh_ccv_cnnp_tensor_symbol_map_t *)calloc(1,sizeof(kh_ccv_cnnp_tensor_symbol_map_t)); } static inline __attribute__ ((__unused__)) void kh_destroy_ccv_cnnp_tensor_symbol_map (kh_ccv_cnnp_tensor_symbol_map_t *h) { if (h) { free((void *) h->keys); free(h->flags); free((void *)h->vals); free (h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_cnnp_tensor_symbol_map (kh_ccv_cnnp_tensor_symbol_map_t *h) { if (h && h-> flags) { memset(h->flags, 0xaa, ((h->n_buckets) < 16 ? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h-> size = h->n_occupied = 0; } } static inline __attribute__ ( (__unused__)) khint_t kh_get_ccv_cnnp_tensor_symbol_map(const kh_ccv_cnnp_tensor_symbol_map_t *h, khint32_t key) { if (h-> n_buckets) { khint_t k, i, last, mask, step = 0; mask = h-> n_buckets - 1; k = (khint32_t)(key); i = k & mask; last = i; while (!((h->flags[i>>4]>>((i&0xfU)<< 1))&2) && (((h->flags[i>>4]>>((i& 0xfU)<<1))&1) || !((h->keys[i]) == (key)))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets ; } return ((h->flags[i>>4]>>((i&0xfU)<< 1))&3)? 
h->n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__)) int kh_resize_ccv_cnnp_tensor_symbol_map (kh_ccv_cnnp_tensor_symbol_map_t *h, khint_t new_n_buckets) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets ), (new_n_buckets)|=(new_n_buckets)>>1, (new_n_buckets) |=(new_n_buckets)>>2, (new_n_buckets)|=(new_n_buckets)>> 4, (new_n_buckets)|=(new_n_buckets)>>8, (new_n_buckets) |=(new_n_buckets)>>16, ++(new_n_buckets)); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; else { new_flags = (khint32_t *)malloc(((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (!new_flags) return -1; memset(new_flags , 0xaa, ((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (h->n_buckets < new_n_buckets ) { khint32_t *new_keys = (khint32_t*)realloc((void *)h->keys ,new_n_buckets * sizeof(khint32_t)); if (!new_keys) { free(new_flags ); return -1; } h->keys = new_keys; if (1) { int *new_vals = (int*)realloc((void *)h->vals,new_n_buckets * sizeof(int )); if (!new_vals) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets ; ++j) { if (((h->flags[j>>4]>>((j&0xfU)<< 1))&3) == 0) { khint32_t key = h->keys[j]; int val; khint_t new_mask; new_mask = new_n_buckets - 1; if (1) val = h->vals [j]; (h->flags[j>>4]|=1ul<<((j&0xfU)<< 1)); while (1) { khint_t k, i, step = 0; k = (khint32_t)(key) ; i = k & new_mask; while (!((new_flags[i>>4]>> ((i&0xfU)<<1))&2)) i = (i + (++step)) & new_mask ; (new_flags[i>>4]&=~(2ul<<((i&0xfU)<< 1))); if (i < h->n_buckets && ((h->flags[i>> 4]>>((i&0xfU)<<1))&3) == 0) { { khint32_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } if (1 ) { int tmp = h->vals[i]; h->vals[i] = val; val = tmp; } (h->flags[i>>4]|=1ul<<((i&0xfU)<<1) ); } else { h->keys[i] = key; if (1) h->vals[i] = val; break ; } } } } if (h->n_buckets > new_n_buckets) { h->keys = (khint32_t*)realloc((void 
*)h->keys,new_n_buckets * sizeof (khint32_t)); if (1) h->vals = (int*)realloc((void *)h-> vals,new_n_buckets * sizeof(int)); } free(h->flags); h-> flags = new_flags; h->n_buckets = new_n_buckets; h->n_occupied = h->size; h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__)) khint_t kh_put_ccv_cnnp_tensor_symbol_map(kh_ccv_cnnp_tensor_symbol_map_t *h, khint32_t key, int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound) { if (h->n_buckets > (h->size <<1)) { if (kh_resize_ccv_cnnp_tensor_symbol_map(h, h-> n_buckets - 1) < 0) { *ret = -1; return h->n_buckets; } } else if (kh_resize_ccv_cnnp_tensor_symbol_map(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; x = site = h->n_buckets; k = (khint32_t)(key); i = k & mask; if (((h->flags[i>>4]>>((i&0xfU)<<1))& 2)) x = i; else { last = i; while (!((h->flags[i>>4] >>((i&0xfU)<<1))&2) && (((h->flags [i>>4]>>((i&0xfU)<<1))&1) || !((h-> keys[i]) == (key)))) { if (((h->flags[i>>4]>>( (i&0xfU)<<1))&1)) site = i; i = (i + (++step)) & mask; if (i == last) { x = site; break; } } if (x == h->n_buckets ) { if (((h->flags[i>>4]>>((i&0xfU)<< 1))&2) && site != h->n_buckets) x = site; else x = i; } } } if (((h->flags[x>>4]>>((x&0xfU )<<1))&2)) { h->keys[x] = key; (h->flags[x>> 4]&=~(3ul<<((x&0xfU)<<1))); ++h->size; ++h->n_occupied; *ret = 1; } else if (((h->flags[x>> 4]>>((x&0xfU)<<1))&1)) { h->keys[x] = key ; (h->flags[x>>4]&=~(3ul<<((x&0xfU)<< 1))); ++h->size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_cnnp_tensor_symbol_map (kh_ccv_cnnp_tensor_symbol_map_t *h, khint_t x) { if (x != h-> n_buckets && !((h->flags[x>>4]>>((x& 0xfU)<<1))&3)) { (h->flags[x>>4]|=1ul<< ((x&0xfU)<<1)); --h->size; } } |
75 |