File: | nnc/ccv_nnc_micro.c |
Warning: | line 10, column 1 Array access (via field 'flags') results in a null pointer dereference |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include "ccv_nnc.h" | |||
2 | #include "ccv_nnc_easy.h" | |||
3 | #include "ccv_nnc_internal.h" | |||
4 | #include "ccv_internal.h" | |||
5 | #include "_ccv_nnc_micro.h" | |||
6 | #include "3rdparty/khash/khash.h" | |||
7 | ||||
8 | // MARK - Level-1 API | |||
9 | ||||
10 | KHASH_MAP_INIT_STR(ccv_nnc_micro_bind_scalar, uint32_t)typedef struct kh_ccv_nnc_micro_bind_scalar_s { khint_t n_buckets , size, n_occupied, upper_bound; khint32_t *flags; kh_cstr_t * keys; uint32_t *vals; } kh_ccv_nnc_micro_bind_scalar_t; static inline __attribute__ ((__unused__)) kh_ccv_nnc_micro_bind_scalar_t *kh_init_ccv_nnc_micro_bind_scalar(void) { return (kh_ccv_nnc_micro_bind_scalar_t *)calloc(1,sizeof(kh_ccv_nnc_micro_bind_scalar_t)); } static inline __attribute__ ((__unused__)) void kh_destroy_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h) { if (h) { free((void *)h ->keys); free(h->flags); free((void *)h->vals); free (h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h) { if (h && h-> flags) { memset(h->flags, 0xaa, ((h->n_buckets) < 16 ? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h-> size = h->n_occupied = 0; } } static inline __attribute__ ( (__unused__)) khint_t kh_get_ccv_nnc_micro_bind_scalar(const kh_ccv_nnc_micro_bind_scalar_t *h, kh_cstr_t key) { if (h->n_buckets) { khint_t k, i, last , mask, step = 0; mask = h->n_buckets - 1; k = __ac_X31_hash_string (key); i = k & mask; last = i; while (!((h->flags[i>> 4]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! (strcmp(h->keys[i], key) == 0))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets; } return ((h-> flags[i>>4]>>((i&0xfU)<<1))&3)? h-> n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__)) int kh_resize_ccv_nnc_micro_bind_scalar(kh_ccv_nnc_micro_bind_scalar_t *h, khint_t new_n_buckets) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets), (new_n_buckets)|=(new_n_buckets )>>1, (new_n_buckets)|=(new_n_buckets)>>2, (new_n_buckets )|=(new_n_buckets)>>4, (new_n_buckets)|=(new_n_buckets) >>8, (new_n_buckets)|=(new_n_buckets)>>16, ++(new_n_buckets )); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0 ; else { new_flags = (khint32_t*)malloc(((new_n_buckets) < 16? 1 : (new_n_buckets)>>4) * sizeof(khint32_t)); if ( !new_flags) return -1; memset(new_flags, 0xaa, ((new_n_buckets ) < 16? 1 : (new_n_buckets)>>4) * sizeof(khint32_t)) ; if (h->n_buckets < new_n_buckets) { kh_cstr_t *new_keys = (kh_cstr_t*)realloc((void *)h->keys,new_n_buckets * sizeof (kh_cstr_t)); if (!new_keys) { free(new_flags); return -1; } h ->keys = new_keys; if (1) { uint32_t *new_vals = (uint32_t *)realloc((void *)h->vals,new_n_buckets * sizeof(uint32_t) ); if (!new_vals) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets; ++j) { if (((h->flags[j>>4]>>((j&0xfU)<< 1))&3) == 0) { kh_cstr_t key = h->keys[j]; uint32_t val ; khint_t new_mask; new_mask = new_n_buckets - 1; if (1) val = h->vals[j]; (h->flags[j>>4]|=1ul<<((j& 0xfU)<<1)); while (1) { khint_t k, i, step = 0; k = __ac_X31_hash_string (key); i = k & new_mask; while (!((new_flags[i>>4]>> ((i&0xfU)<<1))&2)) i = (i + (++step)) & new_mask ; (new_flags[i>>4]&=~(2ul<<((i&0xfU)<< 1))); if (i < h->n_buckets && ((h->flags[i>> 4]>>((i&0xfU)<<1))&3) == 0) { { kh_cstr_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } if (1 ) { uint32_t tmp = h->vals[i]; h->vals[i] = val; val = tmp ; } (h->flags[i>>4]|=1ul<<((i&0xfU)<< 1)); } else { h->keys[i] = key; if (1) h->vals[i] = val ; break; } } } } if (h->n_buckets > new_n_buckets) { h-> keys = (kh_cstr_t*)realloc((void *)h->keys,new_n_buckets * sizeof(kh_cstr_t)); if (1) h->vals = (uint32_t*)realloc(( void *)h->vals,new_n_buckets * sizeof(uint32_t)); } free(h ->flags); h->flags = new_flags; h->n_buckets = new_n_buckets ; h->n_occupied = h->size; h->upper_bound = (khint_t )(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__)) khint_t kh_put_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h, kh_cstr_t key, int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound) { if (h->n_buckets > (h->size<<1)) { if (kh_resize_ccv_nnc_micro_bind_scalar (h, h->n_buckets - 1) < 0) { *ret = -1; return h->n_buckets ; } } else if (kh_resize_ccv_nnc_micro_bind_scalar(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; x = site = h->n_buckets; k = __ac_X31_hash_string(key); i = k & mask; if (((h->flags[i>>4]>>((i&0xfU)<< 1))&2)) x = i; else { last = i; while (!((h->flags[i>> 4]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! (strcmp(h->keys[i], key) == 0))) { if (((h->flags[i>> 4]>>((i&0xfU)<<1))&1)) site = i; i = (i + (++step)) & mask; if (i == last) { x = site; break; } } if (x == h->n_buckets) { if (((h->flags[i>>4]>> ((i&0xfU)<<1))&2) && site != h->n_buckets ) x = site; else x = i; } } } if (((h->flags[x>>4]>> ((x&0xfU)<<1))&2)) { h->keys[x] = key; (h-> flags[x>>4]&=~(3ul<<((x&0xfU)<<1))) ; ++h->size; ++h->n_occupied; *ret = 1; } else if (((h-> flags[x>>4]>>((x&0xfU)<<1))&1)) { h ->keys[x] = key; (h->flags[x>>4]&=~(3ul<< ((x&0xfU)<<1))); ++h->size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_nnc_micro_bind_scalar(kh_ccv_nnc_micro_bind_scalar_t *h, khint_t x) { if (x != h->n_buckets && !((h-> flags[x>>4]>>((x&0xfU)<<1))&3)) { ( h->flags[x>>4]|=1ul<<((x&0xfU)<<1)); --h->size; } } | |||
| ||||
11 | ||||
12 | static uint32_t _scalars_lookup(const void* const context, const char* const name) | |||
13 | { | |||
14 | const khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t* const bind_scalars = (const khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t*)context; | |||
15 | khiter_t k = kh_get(ccv_nnc_micro_bind_scalar, bind_scalars, name)kh_get_ccv_nnc_micro_bind_scalar(bind_scalars, name); | |||
16 | assert(k != kh_end(bind_scalars))((void) sizeof ((k != ((bind_scalars)->n_buckets)) ? 1 : 0 ), __extension__ ({ if (k != ((bind_scalars)->n_buckets)) ; else __assert_fail ("k != kh_end(bind_scalars)", "ccv_nnc_micro.c" , 16, __extension__ __PRETTY_FUNCTION__); })); | |||
17 | return kh_val(bind_scalars, k)((bind_scalars)->vals[k]); | |||
18 | } | |||
19 | ||||
20 | KHASH_SET_INIT_INT64(ccv_nnc_ids)typedef struct kh_ccv_nnc_ids_s { khint_t n_buckets, size, n_occupied , upper_bound; khint32_t *flags; khint64_t *keys; char *vals; } kh_ccv_nnc_ids_t; static inline __attribute__ ((__unused__ )) kh_ccv_nnc_ids_t *kh_init_ccv_nnc_ids(void) { return (kh_ccv_nnc_ids_t *)calloc(1,sizeof(kh_ccv_nnc_ids_t)); } static inline __attribute__ ((__unused__)) void kh_destroy_ccv_nnc_ids(kh_ccv_nnc_ids_t * h) { if (h) { free((void *)h->keys); free(h->flags); free ((void *)h->vals); free(h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_nnc_ids(kh_ccv_nnc_ids_t *h ) { if (h && h->flags) { memset(h->flags, 0xaa, ((h->n_buckets) < 16? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h->size = h->n_occupied = 0; } } static inline __attribute__ ((__unused__)) khint_t kh_get_ccv_nnc_ids (const kh_ccv_nnc_ids_t *h, khint64_t key) { if (h->n_buckets ) { khint_t k, i, last, mask, step = 0; mask = h->n_buckets - 1; k = (khint32_t)((key)>>33^(key)^(key)<<11); i = k & mask; last = i; while (!((h->flags[i>>4 ]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! ((h->keys[i]) == (key)))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets; } return ((h->flags [i>>4]>>((i&0xfU)<<1))&3)? h->n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__ )) int kh_resize_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint_t new_n_buckets ) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets ), (new_n_buckets)|=(new_n_buckets)>>1, (new_n_buckets) |=(new_n_buckets)>>2, (new_n_buckets)|=(new_n_buckets)>> 4, (new_n_buckets)|=(new_n_buckets)>>8, (new_n_buckets) |=(new_n_buckets)>>16, ++(new_n_buckets)); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; else { new_flags = (khint32_t *)malloc(((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (!new_flags) return -1; memset(new_flags , 0xaa, ((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (h->n_buckets < new_n_buckets ) { khint64_t *new_keys = (khint64_t*)realloc((void *)h->keys ,new_n_buckets * sizeof(khint64_t)); if (!new_keys) { free(new_flags ); return -1; } h->keys = new_keys; if (0) { char *new_vals = (char*)realloc((void *)h->vals,new_n_buckets * sizeof(char )); if (!new_vals) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets ; ++j) { if (((h->flags[j>>4]>>((j&0xfU)<< 1))&3) == 0) { khint64_t key = h->keys[j]; char val; khint_t new_mask; new_mask = new_n_buckets - 1; if (0) val = h->vals [j]; (h->flags[j>>4]|=1ul<<((j&0xfU)<< 1)); while (1) { khint_t k, i, step = 0; k = (khint32_t)((key )>>33^(key)^(key)<<11); i = k & new_mask; while (!((new_flags[i>>4]>>((i&0xfU)<<1))& 2)) i = (i + (++step)) & new_mask; (new_flags[i>>4] &=~(2ul<<((i&0xfU)<<1))); if (i < h-> n_buckets && ((h->flags[i>>4]>>((i& 0xfU)<<1))&3) == 0) { { khint64_t tmp = h->keys[ i]; h->keys[i] = key; key = tmp; } if (0) { char tmp = h-> vals[i]; h->vals[i] = val; val = tmp; } (h->flags[i>> 4]|=1ul<<((i&0xfU)<<1)); } else { h->keys[ i] = key; if (0) h->vals[i] = val; break; } } } } if (h-> n_buckets > new_n_buckets) { h->keys = (khint64_t*)realloc ((void *)h->keys,new_n_buckets * sizeof(khint64_t)); if (0 ) h->vals = (char*)realloc((void *)h->vals,new_n_buckets * sizeof(char)); } free(h->flags); h->flags = new_flags ; h->n_buckets = new_n_buckets; h->n_occupied = h->size ; h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__ )) khint_t kh_put_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint64_t key , int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound ) { if (h->n_buckets > (h->size<<1)) { if (kh_resize_ccv_nnc_ids (h, h->n_buckets - 1) < 0) { *ret = -1; return h->n_buckets ; } } else if (kh_resize_ccv_nnc_ids(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site , last, mask = h->n_buckets - 1, step = 0; x = site = h-> n_buckets; k = (khint32_t)((key)>>33^(key)^(key)<< 11); i = k & mask; if (((h->flags[i>>4]>>( (i&0xfU)<<1))&2)) x = i; else { last = i; while (!((h->flags[i>>4]>>((i&0xfU)<<1))& 2) && (((h->flags[i>>4]>>((i&0xfU) <<1))&1) || !((h->keys[i]) == (key)))) { if (((h ->flags[i>>4]>>((i&0xfU)<<1))&1) ) site = i; i = (i + (++step)) & mask; if (i == last) { x = site; break; } } if (x == h->n_buckets) { if (((h->flags [i>>4]>>((i&0xfU)<<1))&2) && site != h->n_buckets) x = site; else x = i; } } } if (((h ->flags[x>>4]>>((x&0xfU)<<1))&2) ) { h->keys[x] = key; (h->flags[x>>4]&=~(3ul<< ((x&0xfU)<<1))); ++h->size; ++h->n_occupied; * ret = 1; } else if (((h->flags[x>>4]>>((x& 0xfU)<<1))&1)) { h->keys[x] = key; (h->flags[ x>>4]&=~(3ul<<((x&0xfU)<<1))); ++h-> size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint_t x) { if (x != h->n_buckets && !((h->flags[x>> 4]>>((x&0xfU)<<1))&3)) { (h->flags[x>> 4]|=1ul<<((x&0xfU)<<1)); --h->size; } } | |||
21 | ||||
22 | CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*)ccv_nnc_micro_combine_t* __attribute__((warn_unused_result)) ccv_nnc_micro_combine_new(const ccv_nnc_micro_io_t* const inputs, const int input_size, const char* const* const parameters, const int parameter_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_nnc_micro_io_t* const ingrads, const int ingrad_size, const ccv_nnc_micro_io_t* const outgrads, const int outgrad_size) | |||
23 | { | |||
24 | assert(output_size > 0)((void) sizeof ((output_size > 0) ? 1 : 0), __extension__ ( { if (output_size > 0) ; else __assert_fail ("output_size > 0" , "ccv_nnc_micro.c", 24, __extension__ __PRETTY_FUNCTION__); } )); | |||
| ||||
25 | assert(input_size > 0)((void) sizeof ((input_size > 0) ? 1 : 0), __extension__ ( { if (input_size > 0) ; else __assert_fail ("input_size > 0" , "ccv_nnc_micro.c", 25, __extension__ __PRETTY_FUNCTION__); } )); | |||
26 | int i, j, k; | |||
27 | // First, do reverse topological sort (from output and then reverse the order). | |||
28 | // We can do this simple thing because there is no overlaps of the outputs, thus, no cases where | |||
29 | // output[0] is the input for output[1]. Otherwise we need to detect this, see ccv_cnnp_model_new | |||
30 | // for more details on why. | |||
31 | for (i = 0; i < output_size - 1; i++) | |||
32 | for (j = i + 1; j < output_size; j++) | |||
33 | { assert(outputs[i] != outputs[j])((void) sizeof ((outputs[i] != outputs[j]) ? 1 : 0), __extension__ ({ if (outputs[i] != outputs[j]) ; else __assert_fail ("outputs[i] != outputs[j]" , "ccv_nnc_micro.c", 33, __extension__ __PRETTY_FUNCTION__); } )); } | |||
34 | uint64_t input_bitmask[((input_size - 1) >> 6) + 1]; | |||
35 | memset(input_bitmask, 0, sizeof(uint64_t) * (((input_size - 1) >> 6) + 1)); | |||
36 | ccv_array_t* const reverse_top = ccv_array_new(sizeof(ccv_nnc_micro_io_t), output_size + input_size, 0); | |||
37 | ccv_array_resize(reverse_top, output_size); | |||
38 | memcpy(ccv_array_get(reverse_top, 0)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(0))), outputs, sizeof(ccv_nnc_micro_io_t) * output_size); | |||
39 | khash_t(ccv_nnc_ids)kh_ccv_nnc_ids_t* const ids = kh_init(ccv_nnc_ids)kh_init_ccv_nnc_ids(); | |||
40 | for (i = 0; i < reverse_top->rnum; i++) | |||
41 | { | |||
42 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i))); | |||
43 | for (j = 0; j < output->input_size; j++) | |||
44 | if (!CCV_NNC_IS_MICRO_IO_INPUT(output->inputs[j])((output->inputs[j])->isa == &ccv_nnc_micro_io_input_isa )) | |||
45 | { | |||
46 | int ret; | |||
47 | kh_put(ccv_nnc_ids, ids, (int64_t)(intptr_t)output->inputs[j], &ret)kh_put_ccv_nnc_ids(ids, (int64_t)(intptr_t)output->inputs[ j], &ret); | |||
48 | if (ret != 0) | |||
49 | ccv_array_push(reverse_top, &output->inputs[j]); | |||
50 | } else { | |||
51 | // This is an input, it must be represented in inputs, try to find it. | |||
52 | for (k = 0; k
| |||
53 | if (inputs[k] == output->inputs[j]) | |||
54 | break; | |||
55 | assert(k < input_size)((void) sizeof ((k < input_size) ? 1 : 0), __extension__ ( { if (k < input_size) ; else __assert_fail ("k < input_size" , "ccv_nnc_micro.c", 55, __extension__ __PRETTY_FUNCTION__); } )); // Cannot find the inputs, error! | |||
56 | input_bitmask[k >> 6] |= ((uint64_t)1 << (k & 63)); | |||
57 | } | |||
58 | } | |||
59 | kh_destroy(ccv_nnc_ids, ids)kh_destroy_ccv_nnc_ids(ids); | |||
60 | for (i = 0; i
| |||
61 | { assert((input_bitmask[i >> 6] & ((uint64_t)1 << (i & 63))))((void) sizeof (((input_bitmask[i >> 6] & ((uint64_t )1 << (i & 63)))) ? 1 : 0), __extension__ ({ if ((input_bitmask [i >> 6] & ((uint64_t)1 << (i & 63)))) ; else __assert_fail ("(input_bitmask[i >> 6] & ((uint64_t)1 << (i & 63)))" , "ccv_nnc_micro.c", 61, __extension__ __PRETTY_FUNCTION__); } )); } // Assuming they all match. | |||
62 | // Second, binding parameters (bounded scalars). | |||
63 | khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t* const bind_scalars = kh_init(ccv_nnc_micro_bind_scalar)kh_init_ccv_nnc_micro_bind_scalar(); | |||
64 | for (i = 0; i < parameter_size; i++) | |||
65 | { | |||
66 | int ret; | |||
67 | khiter_t k = kh_put(ccv_nnc_micro_bind_scalar, bind_scalars, parameters[i], &ret)kh_put_ccv_nnc_micro_bind_scalar(bind_scalars, parameters[i], &ret); | |||
68 | assert(ret != 0)((void) sizeof ((ret != 0) ? 1 : 0), __extension__ ({ if (ret != 0) ; else __assert_fail ("ret != 0", "ccv_nnc_micro.c", 68 , __extension__ __PRETTY_FUNCTION__); })); | |||
69 | kh_val(bind_scalars, k)((bind_scalars)->vals[k]) = i; | |||
70 | } | |||
71 | for (i = 0; i < reverse_top->rnum; i++) | |||
72 | { | |||
73 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
74 | ccv_nnc_micro_bind_scalars(output, _scalars_lookup, bind_scalars); | |||
75 | } | |||
76 | kh_destroy(ccv_nnc_micro_bind_scalar, bind_scalars)kh_destroy_ccv_nnc_micro_bind_scalar(bind_scalars); | |||
77 | const int var_count = reverse_top->rnum + input_size; | |||
78 | // Applying numbering for the inputs. Note that our variables are numbered in reverse topological order. | |||
79 | for (i = 0; i < input_size; i++) | |||
80 | ccv_nnc_micro_numbering(inputs[i], i, var_count); | |||
81 | ccv_array_t* const equal_assertions = ccv_array_new(sizeof(ccv_nnc_micro_id_equal_assertion_t), 0, 0); | |||
82 | // Applying numbering for the outputs and collect equal assertions. | |||
83 | for (i = reverse_top->rnum - 1; i >= 0; i--) | |||
84 | { | |||
85 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
86 | ccv_nnc_micro_numbering(output, i + input_size, var_count); | |||
87 | ccv_nnc_micro_equal_assertions(output, equal_assertions); | |||
88 | } | |||
89 | for (i = 0; i < ingrad_size; i++) | |||
90 | ccv_nnc_micro_numbering(ingrads[i], -1, var_count); | |||
91 | for (i = 0; i < outgrad_size; i++) | |||
92 | ccv_nnc_micro_numbering(outgrads[i], -1, var_count); | |||
93 | // Fill in shapes for variables. | |||
94 | ccv_nnc_micro_tensor_t* const vars = (ccv_nnc_micro_tensor_t*)cccalloccalloc(var_count * 2, sizeof(ccv_nnc_micro_tensor_t)); | |||
95 | for (i = 0; i < input_size; i++) | |||
96 | { | |||
97 | vars[i].dimensions = inputs[i]->dimensions; | |||
98 | vars[i].input = -1; | |||
99 | } | |||
100 | for (i = 0; i < reverse_top->rnum; i++) | |||
101 | { | |||
102 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
103 | vars[i + input_size] = ccv_nnc_micro_return_shape(output); | |||
104 | } | |||
105 | for (i = var_count; i < 2 * var_count; i++) | |||
106 | { | |||
107 | vars[i].dimensions = vars[2 * var_count - 1 - i].dimensions; | |||
108 | vars[i].input = 2 * var_count - 1 - i; | |||
109 | } | |||
110 | // Lower each ccv_nnc_micro_io_t (except the input) op into nested loops such that we can | |||
111 | // apply optimizations later. | |||
112 | int function_count = reverse_top->rnum; | |||
113 | ccv_nnc_micro_function_t* functions = (ccv_nnc_micro_function_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); | |||
114 | for (i = 0; i < function_count; i++) | |||
115 | { | |||
116 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, function_count - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(function_count - 1 - i))); | |||
117 | functions[i] = ccv_nnc_micro_emit(output); | |||
118 | } | |||
119 | ccv_nnc_micro_combine_t* const combine = (ccv_nnc_micro_combine_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_combine_t)); | |||
120 | combine->parameter_size = parameter_size; | |||
121 | combine->forward.input_size = input_size; | |||
122 | combine->forward.inputs = (int*)ccmallocmalloc(sizeof(int) * (input_size + output_size)); | |||
123 | for (i = 0; i < input_size; i++) | |||
124 | combine->forward.inputs[i] = inputs[i]->id; | |||
125 | combine->forward.output_size = output_size; | |||
126 | combine->forward.outputs = combine->forward.inputs + input_size; | |||
127 | for (i = 0; i < output_size; i++) | |||
128 | combine->forward.outputs[i] = outputs[i]->id; | |||
129 | combine->forward.var_count = var_count; | |||
130 | // We copied forward.vars so backward.vars and forward.vars can maintain separate states. | |||
131 | // However, shape and related allocations are shared because these are not going to be mutated. | |||
132 | combine->forward.vars = (ccv_nnc_micro_tensor_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_tensor_t) * var_count); | |||
133 | memcpy(combine->forward.vars, vars, sizeof(ccv_nnc_micro_tensor_t) * var_count); | |||
134 | combine->forward.function_count = function_count; | |||
135 | combine->forward.functions = functions; | |||
136 | ccv_nnc_micro_program_simplify(&combine->forward, inputs, input_size, outputs, output_size, equal_assertions); | |||
137 | function_count = reverse_top->rnum * 2; | |||
138 | functions = (ccv_nnc_micro_function_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); | |||
139 | for (i = 0; i < reverse_top->rnum; i++) | |||
140 | { | |||
141 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
142 | functions[i] = ccv_nnc_micro_emit(output); | |||
143 | } | |||
144 | for (i = reverse_top->rnum; i < function_count; i++) | |||
145 | { | |||
146 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i - reverse_top->rnum)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i - reverse_top->rnum))); | |||
147 | functions[i] = ccv_nnc_micro_emit_grad(output, var_count); | |||
148 | } | |||
149 | combine->backward.input_size = ingrad_size; | |||
150 | combine->backward.inputs = ingrad_size + outgrad_size > 0 ? (int*)ccmallocmalloc(sizeof(int) * (ingrad_size + outgrad_size)) : 0; | |||
151 | for (i = 0; i < ingrad_size; i++) | |||
152 | combine->backward.inputs[i] = ingrads[i]->id; | |||
153 | combine->backward.output_size = outgrad_size; | |||
154 | combine->backward.outputs = outgrad_size > 0 ? combine->backward.inputs + ingrad_size : 0; | |||
155 | for (i = 0; i < outgrad_size; i++) | |||
156 | combine->backward.outputs[i] = outgrads[i]->id; | |||
157 | combine->backward.var_count = var_count * 2; | |||
158 | combine->backward.vars = vars; | |||
159 | combine->backward.function_count = function_count; | |||
160 | combine->backward.functions = functions; | |||
161 | ccv_nnc_micro_program_simplify(&combine->backward, ingrads, ingrad_size, outgrads, outgrad_size, equal_assertions); | |||
162 | combine->equal_assertions = equal_assertions; | |||
163 | for (i = 0; i < reverse_top->rnum; i++) | |||
164 | { | |||
165 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i))); | |||
166 | ccv_nnc_micro_deinit(output); | |||
167 | ccfreefree(output); | |||
168 | } | |||
169 | ccv_array_free(reverse_top); | |||
170 | // It may overlap with inputs, in that case, skip. | |||
171 | for (i = 0; i < ingrad_size; i++) | |||
172 | { | |||
173 | int flag = 0; | |||
174 | for (j = 0; !flag && j < input_size; j++) | |||
175 | flag = (inputs[j] == ingrads[i]); | |||
176 | if (!flag) | |||
177 | { | |||
178 | ccv_nnc_micro_deinit(ingrads[i]); | |||
179 | ccfreefree(ingrads[i]); | |||
180 | } | |||
181 | } | |||
182 | for (i = 0; i < input_size; i++) | |||
183 | { | |||
184 | ccv_nnc_micro_deinit(inputs[i]); | |||
185 | ccfreefree(inputs[i]); | |||
186 | } | |||
187 | for (i = 0; i < outgrad_size; i++) // Should be no overlap on outgrads. | |||
188 | { | |||
189 | ccv_nnc_micro_deinit(outgrads[i]); | |||
190 | ccfreefree(outgrads[i]); | |||
191 | } | |||
192 | return combine; | |||
193 | } | |||
194 | ||||
195 | void ccv_nnc_micro_loop_index_free(ccv_nnc_micro_loop_index_term_t* const term) | |||
196 | { | |||
197 | if (term->type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY) | |||
198 | { | |||
199 | ccv_nnc_micro_loop_index_free(&term->binary->left); | |||
200 | ccv_nnc_micro_loop_index_free(&term->binary->right); | |||
201 | ccfreefree(term->binary); | |||
202 | } | |||
203 | } | |||
204 | ||||
205 | void ccv_nnc_micro_loop_variable_free(ccv_nnc_micro_loop_variable_t* const var) | |||
206 | { | |||
207 | int i; | |||
208 | for (i = 0; i < var->index_count; i++) | |||
209 | ccv_nnc_micro_loop_index_free(&var->index[i]); | |||
210 | } | |||
211 | ||||
212 | static void _ccv_nnc_micro_loop_expression_free(ccv_nnc_micro_loop_expression_t* const expr) | |||
213 | { | |||
214 | switch (expr->type) | |||
215 | { | |||
216 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: { | |||
217 | ccv_nnc_micro_loop_variable_free(&expr->variable); | |||
218 | break; | |||
219 | } | |||
220 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: { | |||
221 | _ccv_nnc_micro_loop_expression_free(expr->unary.x); | |||
222 | ccfreefree(expr->unary.x); | |||
223 | break; | |||
224 | } | |||
225 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { | |||
226 | _ccv_nnc_micro_loop_expression_free(expr->binary.left); | |||
227 | ccfreefree(expr->binary.left); | |||
228 | _ccv_nnc_micro_loop_expression_free(expr->binary.right); | |||
229 | ccfreefree(expr->binary.right); | |||
230 | break; | |||
231 | } | |||
232 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { | |||
233 | _ccv_nnc_micro_loop_expression_free(expr->ternary.pivot); | |||
234 | ccfreefree(expr->ternary.pivot); | |||
235 | _ccv_nnc_micro_loop_expression_free(expr->ternary.left); | |||
236 | ccfreefree(expr->ternary.left); | |||
237 | _ccv_nnc_micro_loop_expression_free(expr->ternary.right); | |||
238 | ccfreefree(expr->ternary.right); | |||
239 | break; | |||
240 | } | |||
241 | } | |||
242 | } | |||
243 | ||||
244 | void ccv_nnc_micro_loop_statement_lvalue_free(ccv_nnc_micro_loop_statement_t* const statement) | |||
245 | { | |||
246 | switch (statement->type) | |||
247 | { | |||
248 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { | |||
249 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | |||
250 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); | |||
251 | break; | |||
252 | } | |||
253 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { | |||
254 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); | |||
255 | break; | |||
256 | } | |||
257 | } | |||
258 | } | |||
259 | ||||
260 | void ccv_nnc_micro_loop_statement_free(ccv_nnc_micro_loop_statement_t* const statement) | |||
261 | { | |||
262 | switch (statement->type) | |||
263 | { | |||
264 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { | |||
265 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | |||
266 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); | |||
267 | _ccv_nnc_micro_loop_expression_free(&statement->compound_assignment.rvalue); | |||
268 | break; | |||
269 | } | |||
270 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { | |||
271 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); | |||
272 | _ccv_nnc_micro_loop_expression_free(&statement->assignment.rvalue); | |||
273 | break; | |||
274 | } | |||
275 | } | |||
276 | } | |||
277 | ||||
278 | void ccv_nnc_micro_loops_free(ccv_nnc_micro_loop_t* const loops, const int loop_count) | |||
279 | { | |||
280 | int i, j; | |||
281 | for (i = 0; i < loop_count; i++) | |||
282 | { | |||
283 | for (j = 0; j < loops[i].statement_count; j++) | |||
284 | ccv_nnc_micro_loop_statement_free(&loops[i].statements[j]); | |||
285 | if (loops[i].statements) | |||
286 | ccfreefree(loops[i].statements); | |||
287 | if (loops[i].carrieds) | |||
288 | ccfreefree(loops[i].carrieds); | |||
289 | } | |||
290 | } | |||
291 | ||||
292 | void ccv_nnc_micro_combine_free(ccv_nnc_micro_combine_t* const combine) | |||
293 | { | |||
294 | int i, j; | |||
295 | const int var_count = combine->forward.var_count; | |||
296 | for (i = 0; i < var_count; i++) | |||
297 | if (combine->forward.vars[i].shape) | |||
298 | { | |||
299 | for (j = 0; j < combine->forward.vars[i].dimensions; j++) | |||
300 | ccv_nnc_micro_loop_index_free(&combine->forward.vars[i].shape[j]); | |||
301 | ccfreefree(combine->forward.vars[i].shape); | |||
302 | } | |||
303 | ccfreefree(combine->forward.vars); | |||
304 | ccfreefree(combine->backward.vars); | |||
305 | int function_count = combine->forward.function_count; | |||
306 | for (i = 0; i < function_count; i++) | |||
307 | { | |||
308 | const int block_count = combine->forward.functions[i].block_count; | |||
309 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->forward.functions[i].one_block : combine->forward.functions[i].blocks; | |||
310 | for (j = 0; j < block_count; j++) | |||
311 | { | |||
312 | ccv_nnc_micro_loop_block_t block = blocks[j]; | |||
313 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); | |||
314 | ccfreefree(block.loops); | |||
315 | } | |||
316 | if (block_count > 1) | |||
317 | ccfreefree(combine->forward.functions[i].blocks); | |||
318 | } | |||
319 | ccfreefree(combine->forward.functions); | |||
320 | ccfreefree(combine->forward.inputs); | |||
321 | // Backward and forward share the same vars. | |||
322 | function_count = combine->backward.function_count; | |||
323 | for (i = 0; i < function_count; i++) | |||
324 | { | |||
325 | const int block_count = combine->backward.functions[i].block_count; | |||
326 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->backward.functions[i].one_block : combine->backward.functions[i].blocks; | |||
327 | for (j = 0; j < block_count; j++) | |||
328 | { | |||
329 | ccv_nnc_micro_loop_block_t block = blocks[j]; | |||
330 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); | |||
331 | ccfreefree(block.loops); | |||
332 | } | |||
333 | if (block_count > 1) | |||
334 | ccfreefree(combine->backward.functions[i].blocks); | |||
335 | } | |||
336 | ccfreefree(combine->backward.functions); | |||
337 | if (combine->backward.inputs) | |||
338 | ccfreefree(combine->backward.inputs); | |||
339 | ccv_array_free(combine->equal_assertions); | |||
340 | ccfreefree(combine); | |||
341 | } | |||
342 | ||||
343 | char* ccv_nnc_micro_combine_c(ccv_nnc_micro_combine_t* const combine) | |||
344 | { | |||
345 | return 0; | |||
346 | } |