| File: | nnc/ccv_nnc_micro.c |
| Warning: | line 152, column 31 Array access (via field 'inputs') results in a null pointer dereference |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
| 1 | #include "ccv_nnc.h" | |||
| 2 | #include "ccv_nnc_easy.h" | |||
| 3 | #include "ccv_nnc_internal.h" | |||
| 4 | #include "ccv_internal.h" | |||
| 5 | #include "_ccv_nnc_micro.h" | |||
| 6 | #include "3rdparty/khash/khash.h" | |||
| 7 | ||||
| 8 | // MARK - Level-1 API | |||
| 9 | ||||
| 10 | KHASH_MAP_INIT_STR(ccv_nnc_micro_bind_scalar, uint32_t)typedef struct kh_ccv_nnc_micro_bind_scalar_s { khint_t n_buckets , size, n_occupied, upper_bound; khint32_t *flags; kh_cstr_t * keys; uint32_t *vals; } kh_ccv_nnc_micro_bind_scalar_t; static inline __attribute__ ((__unused__)) kh_ccv_nnc_micro_bind_scalar_t *kh_init_ccv_nnc_micro_bind_scalar(void) { return (kh_ccv_nnc_micro_bind_scalar_t *)calloc(1,sizeof(kh_ccv_nnc_micro_bind_scalar_t)); } static inline __attribute__ ((__unused__)) void kh_destroy_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h) { if (h) { free((void *)h ->keys); free(h->flags); free((void *)h->vals); free (h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h) { if (h && h-> flags) { memset(h->flags, 0xaa, ((h->n_buckets) < 16 ? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h-> size = h->n_occupied = 0; } } static inline __attribute__ ( (__unused__)) khint_t kh_get_ccv_nnc_micro_bind_scalar(const kh_ccv_nnc_micro_bind_scalar_t *h, kh_cstr_t key) { if (h->n_buckets) { khint_t k, i, last , mask, step = 0; mask = h->n_buckets - 1; k = __ac_X31_hash_string (key); i = k & mask; last = i; while (!((h->flags[i>> 4]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! (strcmp(h->keys[i], key) == 0))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets; } return ((h-> flags[i>>4]>>((i&0xfU)<<1))&3)? h-> n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__)) int kh_resize_ccv_nnc_micro_bind_scalar(kh_ccv_nnc_micro_bind_scalar_t *h, khint_t new_n_buckets) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets), (new_n_buckets)|=(new_n_buckets )>>1, (new_n_buckets)|=(new_n_buckets)>>2, (new_n_buckets )|=(new_n_buckets)>>4, (new_n_buckets)|=(new_n_buckets) >>8, (new_n_buckets)|=(new_n_buckets)>>16, ++(new_n_buckets )); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0 ; else { new_flags = (khint32_t*)malloc(((new_n_buckets) < 16? 1 : (new_n_buckets)>>4) * sizeof(khint32_t)); if ( !new_flags) return -1; memset(new_flags, 0xaa, ((new_n_buckets ) < 16? 1 : (new_n_buckets)>>4) * sizeof(khint32_t)) ; if (h->n_buckets < new_n_buckets) { kh_cstr_t *new_keys = (kh_cstr_t*)realloc((void *)h->keys,new_n_buckets * sizeof (kh_cstr_t)); if (!new_keys) { free(new_flags); return -1; } h ->keys = new_keys; if (1) { uint32_t *new_vals = (uint32_t *)realloc((void *)h->vals,new_n_buckets * sizeof(uint32_t) ); if (!new_vals) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets; ++j) { if (((h->flags[j>>4]>>((j&0xfU)<< 1))&3) == 0) { kh_cstr_t key = h->keys[j]; uint32_t val ; khint_t new_mask; new_mask = new_n_buckets - 1; if (1) val = h->vals[j]; (h->flags[j>>4]|=1ul<<((j& 0xfU)<<1)); while (1) { khint_t k, i, step = 0; k = __ac_X31_hash_string (key); i = k & new_mask; while (!((new_flags[i>>4]>> ((i&0xfU)<<1))&2)) i = (i + (++step)) & new_mask ; (new_flags[i>>4]&=~(2ul<<((i&0xfU)<< 1))); if (i < h->n_buckets && ((h->flags[i>> 4]>>((i&0xfU)<<1))&3) == 0) { { kh_cstr_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } if (1 ) { uint32_t tmp = h->vals[i]; h->vals[i] = val; val = tmp ; } (h->flags[i>>4]|=1ul<<((i&0xfU)<< 1)); } else { h->keys[i] = key; if (1) h->vals[i] = val ; break; } } } } if (h->n_buckets > new_n_buckets) { h-> keys = (kh_cstr_t*)realloc((void *)h->keys,new_n_buckets * sizeof(kh_cstr_t)); if (1) h->vals = (uint32_t*)realloc(( void *)h->vals,new_n_buckets * sizeof(uint32_t)); } free(h ->flags); h->flags = new_flags; h->n_buckets = new_n_buckets ; h->n_occupied = h->size; h->upper_bound = (khint_t )(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__)) khint_t kh_put_ccv_nnc_micro_bind_scalar (kh_ccv_nnc_micro_bind_scalar_t *h, kh_cstr_t key, int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound) { if (h->n_buckets > (h->size<<1)) { if (kh_resize_ccv_nnc_micro_bind_scalar (h, h->n_buckets - 1) < 0) { *ret = -1; return h->n_buckets ; } } else if (kh_resize_ccv_nnc_micro_bind_scalar(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; x = site = h->n_buckets; k = __ac_X31_hash_string(key); i = k & mask; if (((h->flags[i>>4]>>((i&0xfU)<< 1))&2)) x = i; else { last = i; while (!((h->flags[i>> 4]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! (strcmp(h->keys[i], key) == 0))) { if (((h->flags[i>> 4]>>((i&0xfU)<<1))&1)) site = i; i = (i + (++step)) & mask; if (i == last) { x = site; break; } } if (x == h->n_buckets) { if (((h->flags[i>>4]>> ((i&0xfU)<<1))&2) && site != h->n_buckets ) x = site; else x = i; } } } if (((h->flags[x>>4]>> ((x&0xfU)<<1))&2)) { h->keys[x] = key; (h-> flags[x>>4]&=~(3ul<<((x&0xfU)<<1))) ; ++h->size; ++h->n_occupied; *ret = 1; } else if (((h-> flags[x>>4]>>((x&0xfU)<<1))&1)) { h ->keys[x] = key; (h->flags[x>>4]&=~(3ul<< ((x&0xfU)<<1))); ++h->size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_nnc_micro_bind_scalar(kh_ccv_nnc_micro_bind_scalar_t *h, khint_t x) { if (x != h->n_buckets && !((h-> flags[x>>4]>>((x&0xfU)<<1))&3)) { ( h->flags[x>>4]|=1ul<<((x&0xfU)<<1)); --h->size; } } | |||
| 11 | ||||
| 12 | static uint32_t _scalars_lookup(const void* const context, const char* const name) | |||
| 13 | { | |||
| 14 | const khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t* const bind_scalars = (const khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t*)context; | |||
| 15 | khiter_t k = kh_get(ccv_nnc_micro_bind_scalar, bind_scalars, name)kh_get_ccv_nnc_micro_bind_scalar(bind_scalars, name); | |||
| 16 | assert(k != kh_end(bind_scalars))((void) sizeof ((k != ((bind_scalars)->n_buckets)) ? 1 : 0 ), __extension__ ({ if (k != ((bind_scalars)->n_buckets)) ; else __assert_fail ("k != kh_end(bind_scalars)", "ccv_nnc_micro.c" , 16, __extension__ __PRETTY_FUNCTION__); })); | |||
| 17 | return kh_val(bind_scalars, k)((bind_scalars)->vals[k]); | |||
| 18 | } | |||
| 19 | ||||
| 20 | KHASH_SET_INIT_INT64(ccv_nnc_ids)typedef struct kh_ccv_nnc_ids_s { khint_t n_buckets, size, n_occupied , upper_bound; khint32_t *flags; khint64_t *keys; char *vals; } kh_ccv_nnc_ids_t; static inline __attribute__ ((__unused__ )) kh_ccv_nnc_ids_t *kh_init_ccv_nnc_ids(void) { return (kh_ccv_nnc_ids_t *)calloc(1,sizeof(kh_ccv_nnc_ids_t)); } static inline __attribute__ ((__unused__)) void kh_destroy_ccv_nnc_ids(kh_ccv_nnc_ids_t * h) { if (h) { free((void *)h->keys); free(h->flags); free ((void *)h->vals); free(h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_nnc_ids(kh_ccv_nnc_ids_t *h ) { if (h && h->flags) { memset(h->flags, 0xaa, ((h->n_buckets) < 16? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h->size = h->n_occupied = 0; } } static inline __attribute__ ((__unused__)) khint_t kh_get_ccv_nnc_ids (const kh_ccv_nnc_ids_t *h, khint64_t key) { if (h->n_buckets ) { khint_t k, i, last, mask, step = 0; mask = h->n_buckets - 1; k = (khint32_t)((key)>>33^(key)^(key)<<11); i = k & mask; last = i; while (!((h->flags[i>>4 ]>>((i&0xfU)<<1))&2) && (((h-> flags[i>>4]>>((i&0xfU)<<1))&1) || ! ((h->keys[i]) == (key)))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets; } return ((h->flags [i>>4]>>((i&0xfU)<<1))&3)? h->n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__ )) int kh_resize_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint_t new_n_buckets ) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets ), (new_n_buckets)|=(new_n_buckets)>>1, (new_n_buckets) |=(new_n_buckets)>>2, (new_n_buckets)|=(new_n_buckets)>> 4, (new_n_buckets)|=(new_n_buckets)>>8, (new_n_buckets) |=(new_n_buckets)>>16, ++(new_n_buckets)); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; else { new_flags = (khint32_t *)malloc(((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (!new_flags) return -1; memset(new_flags , 0xaa, ((new_n_buckets) < 16? 1 : (new_n_buckets)>> 4) * sizeof(khint32_t)); if (h->n_buckets < new_n_buckets ) { khint64_t *new_keys = (khint64_t*)realloc((void *)h->keys ,new_n_buckets * sizeof(khint64_t)); if (!new_keys) { free(new_flags ); return -1; } h->keys = new_keys; if (0) { char *new_vals = (char*)realloc((void *)h->vals,new_n_buckets * sizeof(char )); if (!new_vals) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets ; ++j) { if (((h->flags[j>>4]>>((j&0xfU)<< 1))&3) == 0) { khint64_t key = h->keys[j]; char val; khint_t new_mask; new_mask = new_n_buckets - 1; if (0) val = h->vals [j]; (h->flags[j>>4]|=1ul<<((j&0xfU)<< 1)); while (1) { khint_t k, i, step = 0; k = (khint32_t)((key )>>33^(key)^(key)<<11); i = k & new_mask; while (!((new_flags[i>>4]>>((i&0xfU)<<1))& 2)) i = (i + (++step)) & new_mask; (new_flags[i>>4] &=~(2ul<<((i&0xfU)<<1))); if (i < h-> n_buckets && ((h->flags[i>>4]>>((i& 0xfU)<<1))&3) == 0) { { khint64_t tmp = h->keys[ i]; h->keys[i] = key; key = tmp; } if (0) { char tmp = h-> vals[i]; h->vals[i] = val; val = tmp; } (h->flags[i>> 4]|=1ul<<((i&0xfU)<<1)); } else { h->keys[ i] = key; if (0) h->vals[i] = val; break; } } } } if (h-> n_buckets > new_n_buckets) { h->keys = (khint64_t*)realloc ((void *)h->keys,new_n_buckets * sizeof(khint64_t)); if (0 ) h->vals = (char*)realloc((void *)h->vals,new_n_buckets * sizeof(char)); } free(h->flags); h->flags = new_flags ; h->n_buckets = new_n_buckets; h->n_occupied = h->size ; h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__ )) khint_t kh_put_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint64_t key , int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound ) { if (h->n_buckets > (h->size<<1)) { if (kh_resize_ccv_nnc_ids (h, h->n_buckets - 1) < 0) { *ret = -1; return h->n_buckets ; } } else if (kh_resize_ccv_nnc_ids(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site , last, mask = h->n_buckets - 1, step = 0; x = site = h-> n_buckets; k = (khint32_t)((key)>>33^(key)^(key)<< 11); i = k & mask; if (((h->flags[i>>4]>>( (i&0xfU)<<1))&2)) x = i; else { last = i; while (!((h->flags[i>>4]>>((i&0xfU)<<1))& 2) && (((h->flags[i>>4]>>((i&0xfU) <<1))&1) || !((h->keys[i]) == (key)))) { if (((h ->flags[i>>4]>>((i&0xfU)<<1))&1) ) site = i; i = (i + (++step)) & mask; if (i == last) { x = site; break; } } if (x == h->n_buckets) { if (((h->flags [i>>4]>>((i&0xfU)<<1))&2) && site != h->n_buckets) x = site; else x = i; } } } if (((h ->flags[x>>4]>>((x&0xfU)<<1))&2) ) { h->keys[x] = key; (h->flags[x>>4]&=~(3ul<< ((x&0xfU)<<1))); ++h->size; ++h->n_occupied; * ret = 1; } else if (((h->flags[x>>4]>>((x& 0xfU)<<1))&1)) { h->keys[x] = key; (h->flags[ x>>4]&=~(3ul<<((x&0xfU)<<1))); ++h-> size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_nnc_ids(kh_ccv_nnc_ids_t *h, khint_t x) { if (x != h->n_buckets && !((h->flags[x>> 4]>>((x&0xfU)<<1))&3)) { (h->flags[x>> 4]|=1ul<<((x&0xfU)<<1)); --h->size; } } | |||
| 21 | ||||
| 22 | CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*)ccv_nnc_micro_combine_t* __attribute__((warn_unused_result)) ccv_nnc_micro_combine_new(const ccv_nnc_micro_io_t* const inputs, const int input_size, const char* const* const parameters, const int parameter_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_nnc_micro_io_t* const ingrads, const int ingrad_size, const ccv_nnc_micro_io_t* const outgrads, const int outgrad_size) | |||
| 23 | { | |||
| 24 | assert(output_size > 0)((void) sizeof ((output_size > 0) ? 1 : 0), __extension__ ( { if (output_size > 0) ; else __assert_fail ("output_size > 0" , "ccv_nnc_micro.c", 24, __extension__ __PRETTY_FUNCTION__); } )); | |||
| ||||
| 25 | assert(input_size > 0)((void) sizeof ((input_size > 0) ? 1 : 0), __extension__ ( { if (input_size > 0) ; else __assert_fail ("input_size > 0" , "ccv_nnc_micro.c", 25, __extension__ __PRETTY_FUNCTION__); } )); | |||
| 26 | int i, j, k; | |||
| 27 | // First, do reverse topological sort (from output and then reverse the order). | |||
| 28 | // We can do this simple thing because there is no overlaps of the outputs, thus, no cases where | |||
| 29 | // output[0] is the input for output[1]. Otherwise we need to detect this, see ccv_cnnp_model_new | |||
| 30 | // for more details on why. | |||
| 31 | for (i = 0; i < output_size - 1; i++) | |||
| 32 | for (j = i + 1; j < output_size; j++) | |||
| 33 | { assert(outputs[i] != outputs[j])((void) sizeof ((outputs[i] != outputs[j]) ? 1 : 0), __extension__ ({ if (outputs[i] != outputs[j]) ; else __assert_fail ("outputs[i] != outputs[j]" , "ccv_nnc_micro.c", 33, __extension__ __PRETTY_FUNCTION__); } )); } | |||
| 34 | uint64_t input_bitmask[((input_size - 1) >> 6) + 1]; | |||
| 35 | memset(input_bitmask, 0, sizeof(uint64_t) * (((input_size - 1) >> 6) + 1)); | |||
| 36 | ccv_array_t* const reverse_top = ccv_array_new(sizeof(ccv_nnc_micro_io_t), output_size + input_size, 0); | |||
| 37 | ccv_array_resize(reverse_top, output_size); | |||
| 38 | memcpy(ccv_array_get(reverse_top, 0)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(0))), outputs, sizeof(ccv_nnc_micro_io_t) * output_size); | |||
| 39 | khash_t(ccv_nnc_ids)kh_ccv_nnc_ids_t* const ids = kh_init(ccv_nnc_ids)kh_init_ccv_nnc_ids(); | |||
| 40 | for (i = 0; i < reverse_top->rnum; i++) | |||
| 41 | { | |||
| 42 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i))); | |||
| 43 | for (j = 0; j < output->input_size; j++) | |||
| 44 | if (!CCV_NNC_IS_MICRO_IO_INPUT(output->inputs[j])((output->inputs[j])->isa == &ccv_nnc_micro_io_input_isa )) | |||
| 45 | { | |||
| 46 | int ret; | |||
| 47 | kh_put(ccv_nnc_ids, ids, (int64_t)(intptr_t)output->inputs[j], &ret)kh_put_ccv_nnc_ids(ids, (int64_t)(intptr_t)output->inputs[ j], &ret); | |||
| 48 | if (ret != 0) | |||
| 49 | ccv_array_push(reverse_top, &output->inputs[j]); | |||
| 50 | } else { | |||
| 51 | // This is an input, it must be represented in inputs, try to find it. | |||
| 52 | for (k = 0; k
| |||
| 53 | if (inputs[k] == output->inputs[j]) | |||
| 54 | break; | |||
| 55 | assert(k < input_size)((void) sizeof ((k < input_size) ? 1 : 0), __extension__ ( { if (k < input_size) ; else __assert_fail ("k < input_size" , "ccv_nnc_micro.c", 55, __extension__ __PRETTY_FUNCTION__); } )); // Cannot find the inputs, error! | |||
| 56 | input_bitmask[k >> 6] |= ((uint64_t)1 << (k & 63)); | |||
| 57 | } | |||
| 58 | } | |||
| 59 | kh_destroy(ccv_nnc_ids, ids)kh_destroy_ccv_nnc_ids(ids); | |||
| 60 | for (i = 0; i
| |||
| 61 | { assert((input_bitmask[i >> 6] & ((uint64_t)1 << (i & 63))))((void) sizeof (((input_bitmask[i >> 6] & ((uint64_t )1 << (i & 63)))) ? 1 : 0), __extension__ ({ if ((input_bitmask [i >> 6] & ((uint64_t)1 << (i & 63)))) ; else __assert_fail ("(input_bitmask[i >> 6] & ((uint64_t)1 << (i & 63)))" , "ccv_nnc_micro.c", 61, __extension__ __PRETTY_FUNCTION__); } )); } // Assuming they all match. | |||
| 62 | // Second, binding parameters (bounded scalars). | |||
| 63 | khash_t(ccv_nnc_micro_bind_scalar)kh_ccv_nnc_micro_bind_scalar_t* const bind_scalars = kh_init(ccv_nnc_micro_bind_scalar)kh_init_ccv_nnc_micro_bind_scalar(); | |||
| 64 | for (i = 0; i < parameter_size; i++) | |||
| 65 | { | |||
| 66 | int ret; | |||
| 67 | khiter_t k = kh_put(ccv_nnc_micro_bind_scalar, bind_scalars, parameters[i], &ret)kh_put_ccv_nnc_micro_bind_scalar(bind_scalars, parameters[i], &ret); | |||
| 68 | assert(ret != 0)((void) sizeof ((ret != 0) ? 1 : 0), __extension__ ({ if (ret != 0) ; else __assert_fail ("ret != 0", "ccv_nnc_micro.c", 68 , __extension__ __PRETTY_FUNCTION__); })); | |||
| 69 | kh_val(bind_scalars, k)((bind_scalars)->vals[k]) = i; | |||
| 70 | } | |||
| 71 | for (i = 0; i < reverse_top->rnum; i++) | |||
| 72 | { | |||
| 73 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
| 74 | ccv_nnc_micro_bind_scalars(output, _scalars_lookup, bind_scalars); | |||
| 75 | } | |||
| 76 | kh_destroy(ccv_nnc_micro_bind_scalar, bind_scalars)kh_destroy_ccv_nnc_micro_bind_scalar(bind_scalars); | |||
| 77 | const int var_count = reverse_top->rnum + input_size; | |||
| 78 | // Applying numbering for the inputs. Note that our variables are numbered in reverse topological order. | |||
| 79 | for (i = 0; i < input_size; i++) | |||
| 80 | ccv_nnc_micro_numbering(inputs[i], i, var_count); | |||
| 81 | ccv_array_t* const equal_assertions = ccv_array_new(sizeof(ccv_nnc_micro_id_equal_assertion_t), 0, 0); | |||
| 82 | // Applying numbering for the outputs and collect equal assertions. | |||
| 83 | for (i = reverse_top->rnum - 1; i >= 0; i--) | |||
| 84 | { | |||
| 85 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
| 86 | ccv_nnc_micro_numbering(output, i + input_size, var_count); | |||
| 87 | ccv_nnc_micro_equal_assertions(output, equal_assertions); | |||
| 88 | } | |||
| 89 | for (i = 0; i < ingrad_size; i++) | |||
| 90 | ccv_nnc_micro_numbering(ingrads[i], -1, var_count); | |||
| 91 | for (i = 0; i < outgrad_size; i++) | |||
| 92 | ccv_nnc_micro_numbering(outgrads[i], -1, var_count); | |||
| 93 | // Fill in shapes for variables. | |||
| 94 | ccv_nnc_micro_tensor_t* const vars = (ccv_nnc_micro_tensor_t*)cccalloccalloc(var_count * 2, sizeof(ccv_nnc_micro_tensor_t)); | |||
| 95 | for (i = 0; i < input_size; i++) | |||
| 96 | { | |||
| 97 | vars[i].dimensions = inputs[i]->dimensions; | |||
| 98 | vars[i].input = -1; | |||
| 99 | } | |||
| 100 | for (i = 0; i < reverse_top->rnum; i++) | |||
| 101 | { | |||
| 102 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
| 103 | vars[i + input_size] = ccv_nnc_micro_return_shape(output); | |||
| 104 | } | |||
| 105 | for (i = var_count; i < 2 * var_count; i++) | |||
| 106 | { | |||
| 107 | vars[i].dimensions = vars[2 * var_count - 1 - i].dimensions; | |||
| 108 | vars[i].input = 2 * var_count - 1 - i; | |||
| 109 | } | |||
| 110 | // Lower each ccv_nnc_micro_io_t (except the input) op into nested loops such that we can | |||
| 111 | // apply optimizations later. | |||
| 112 | int function_count = reverse_top->rnum; | |||
| 113 | ccv_nnc_micro_function_t* functions = (ccv_nnc_micro_function_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); | |||
| 114 | for (i = 0; i < function_count; i++) | |||
| 115 | { | |||
| 116 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, function_count - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(function_count - 1 - i))); | |||
| 117 | functions[i] = ccv_nnc_micro_emit(output); | |||
| 118 | } | |||
| 119 | ccv_nnc_micro_combine_t* const combine = (ccv_nnc_micro_combine_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_combine_t)); | |||
| 120 | combine->parameter_size = parameter_size; | |||
| 121 | combine->forward.input_size = input_size; | |||
| 122 | combine->forward.inputs = (int*)ccmallocmalloc(sizeof(int) * (input_size + output_size)); | |||
| 123 | for (i = 0; i < input_size; i++) | |||
| 124 | combine->forward.inputs[i] = inputs[i]->id; | |||
| 125 | combine->forward.output_size = output_size; | |||
| 126 | combine->forward.outputs = combine->forward.inputs + input_size; | |||
| 127 | for (i = 0; i < output_size; i++) | |||
| 128 | combine->forward.outputs[i] = outputs[i]->id; | |||
| 129 | combine->forward.var_count = var_count; | |||
| 130 | // We copied forward.vars so backward.vars and forward.vars can maintain separate states. | |||
| 131 | // However, shape and related allocations are shared because these are not going to be mutated. | |||
| 132 | combine->forward.vars = (ccv_nnc_micro_tensor_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_tensor_t) * var_count); | |||
| 133 | memcpy(combine->forward.vars, vars, sizeof(ccv_nnc_micro_tensor_t) * var_count); | |||
| 134 | combine->forward.function_count = function_count; | |||
| 135 | combine->forward.functions = functions; | |||
| 136 | ccv_nnc_micro_program_simplify(&combine->forward, inputs, input_size, outputs, output_size, equal_assertions); | |||
| 137 | function_count = reverse_top->rnum * 2; | |||
| 138 | functions = (ccv_nnc_micro_function_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_function_t) * function_count); | |||
| 139 | for (i = 0; i < reverse_top->rnum; i++) | |||
| 140 | { | |||
| 141 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(reverse_top->rnum - 1 - i))); | |||
| 142 | functions[i] = ccv_nnc_micro_emit(output); | |||
| 143 | } | |||
| 144 | for (i = reverse_top->rnum; i < function_count; i++) | |||
| 145 | { | |||
| 146 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i - reverse_top->rnum)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i - reverse_top->rnum))); | |||
| 147 | functions[i] = ccv_nnc_micro_emit_grad(output, var_count); | |||
| 148 | } | |||
| 149 | combine->backward.input_size = ingrad_size; | |||
| 150 | combine->backward.inputs = ingrad_size + outgrad_size > 0 ? (int*)ccmallocmalloc(sizeof(int) * (ingrad_size + outgrad_size)) : 0; | |||
| 151 | for (i = 0; i < ingrad_size; i++) | |||
| 152 | combine->backward.inputs[i] = ingrads[i]->id; | |||
| ||||
| 153 | combine->backward.output_size = outgrad_size; | |||
| 154 | combine->backward.outputs = outgrad_size > 0 ? combine->backward.inputs + ingrad_size : 0; | |||
| 155 | for (i = 0; i < outgrad_size; i++) | |||
| 156 | combine->backward.outputs[i] = outgrads[i]->id; | |||
| 157 | combine->backward.var_count = var_count * 2; | |||
| 158 | combine->backward.vars = vars; | |||
| 159 | combine->backward.function_count = function_count; | |||
| 160 | combine->backward.functions = functions; | |||
| 161 | ccv_nnc_micro_program_simplify(&combine->backward, ingrads, ingrad_size, outgrads, outgrad_size, equal_assertions); | |||
| 162 | combine->equal_assertions = equal_assertions; | |||
| 163 | for (i = 0; i < reverse_top->rnum; i++) | |||
| 164 | { | |||
| 165 | const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i)((void*)(((char*)((reverse_top)->data)) + (size_t)(reverse_top )->rsize * (size_t)(i))); | |||
| 166 | ccv_nnc_micro_deinit(output); | |||
| 167 | ccfreefree(output); | |||
| 168 | } | |||
| 169 | ccv_array_free(reverse_top); | |||
| 170 | // It may overlap with inputs, in that case, skip. | |||
| 171 | for (i = 0; i < ingrad_size; i++) | |||
| 172 | { | |||
| 173 | int flag = 0; | |||
| 174 | for (j = 0; !flag && j < input_size; j++) | |||
| 175 | flag = (inputs[j] == ingrads[i]); | |||
| 176 | if (!flag) | |||
| 177 | { | |||
| 178 | ccv_nnc_micro_deinit(ingrads[i]); | |||
| 179 | ccfreefree(ingrads[i]); | |||
| 180 | } | |||
| 181 | } | |||
| 182 | for (i = 0; i < input_size; i++) | |||
| 183 | { | |||
| 184 | ccv_nnc_micro_deinit(inputs[i]); | |||
| 185 | ccfreefree(inputs[i]); | |||
| 186 | } | |||
| 187 | for (i = 0; i < outgrad_size; i++) // Should be no overlap on outgrads. | |||
| 188 | { | |||
| 189 | ccv_nnc_micro_deinit(outgrads[i]); | |||
| 190 | ccfreefree(outgrads[i]); | |||
| 191 | } | |||
| 192 | return combine; | |||
| 193 | } | |||
| 194 | ||||
| 195 | void ccv_nnc_micro_loop_index_free(ccv_nnc_micro_loop_index_term_t* const term) | |||
| 196 | { | |||
| 197 | if (term->type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY) | |||
| 198 | { | |||
| 199 | ccv_nnc_micro_loop_index_free(&term->binary->left); | |||
| 200 | ccv_nnc_micro_loop_index_free(&term->binary->right); | |||
| 201 | ccfreefree(term->binary); | |||
| 202 | } | |||
| 203 | } | |||
| 204 | ||||
| 205 | void ccv_nnc_micro_loop_variable_free(ccv_nnc_micro_loop_variable_t* const var) | |||
| 206 | { | |||
| 207 | int i; | |||
| 208 | for (i = 0; i < var->index_count; i++) | |||
| 209 | ccv_nnc_micro_loop_index_free(&var->index[i]); | |||
| 210 | } | |||
| 211 | ||||
| 212 | static void _ccv_nnc_micro_loop_expression_free(ccv_nnc_micro_loop_expression_t* const expr) | |||
| 213 | { | |||
| 214 | switch (expr->type) | |||
| 215 | { | |||
| 216 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: { | |||
| 217 | ccv_nnc_micro_loop_variable_free(&expr->variable); | |||
| 218 | break; | |||
| 219 | } | |||
| 220 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: { | |||
| 221 | _ccv_nnc_micro_loop_expression_free(expr->unary.x); | |||
| 222 | ccfreefree(expr->unary.x); | |||
| 223 | break; | |||
| 224 | } | |||
| 225 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { | |||
| 226 | _ccv_nnc_micro_loop_expression_free(expr->binary.left); | |||
| 227 | ccfreefree(expr->binary.left); | |||
| 228 | _ccv_nnc_micro_loop_expression_free(expr->binary.right); | |||
| 229 | ccfreefree(expr->binary.right); | |||
| 230 | break; | |||
| 231 | } | |||
| 232 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { | |||
| 233 | _ccv_nnc_micro_loop_expression_free(expr->ternary.pivot); | |||
| 234 | ccfreefree(expr->ternary.pivot); | |||
| 235 | _ccv_nnc_micro_loop_expression_free(expr->ternary.left); | |||
| 236 | ccfreefree(expr->ternary.left); | |||
| 237 | _ccv_nnc_micro_loop_expression_free(expr->ternary.right); | |||
| 238 | ccfreefree(expr->ternary.right); | |||
| 239 | break; | |||
| 240 | } | |||
| 241 | } | |||
| 242 | } | |||
| 243 | ||||
| 244 | void ccv_nnc_micro_loop_statement_lvalue_free(ccv_nnc_micro_loop_statement_t* const statement) | |||
| 245 | { | |||
| 246 | switch (statement->type) | |||
| 247 | { | |||
| 248 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { | |||
| 249 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | |||
| 250 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); | |||
| 251 | break; | |||
| 252 | } | |||
| 253 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { | |||
| 254 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); | |||
| 255 | break; | |||
| 256 | } | |||
| 257 | } | |||
| 258 | } | |||
| 259 | ||||
| 260 | void ccv_nnc_micro_loop_statement_free(ccv_nnc_micro_loop_statement_t* const statement) | |||
| 261 | { | |||
| 262 | switch (statement->type) | |||
| 263 | { | |||
| 264 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { | |||
| 265 | if (statement->compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | |||
| 266 | ccv_nnc_micro_loop_variable_free(&statement->compound_assignment.lvalue.variable); | |||
| 267 | _ccv_nnc_micro_loop_expression_free(&statement->compound_assignment.rvalue); | |||
| 268 | break; | |||
| 269 | } | |||
| 270 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { | |||
| 271 | ccv_nnc_micro_loop_variable_free(&statement->assignment.lvalue); | |||
| 272 | _ccv_nnc_micro_loop_expression_free(&statement->assignment.rvalue); | |||
| 273 | break; | |||
| 274 | } | |||
| 275 | } | |||
| 276 | } | |||
| 277 | ||||
| 278 | void ccv_nnc_micro_loops_free(ccv_nnc_micro_loop_t* const loops, const int loop_count) | |||
| 279 | { | |||
| 280 | int i, j; | |||
| 281 | for (i = 0; i < loop_count; i++) | |||
| 282 | { | |||
| 283 | for (j = 0; j < loops[i].statement_count; j++) | |||
| 284 | ccv_nnc_micro_loop_statement_free(&loops[i].statements[j]); | |||
| 285 | if (loops[i].statements) | |||
| 286 | ccfreefree(loops[i].statements); | |||
| 287 | if (loops[i].carrieds) | |||
| 288 | ccfreefree(loops[i].carrieds); | |||
| 289 | } | |||
| 290 | } | |||
| 291 | ||||
| 292 | void ccv_nnc_micro_combine_free(ccv_nnc_micro_combine_t* const combine) | |||
| 293 | { | |||
| 294 | int i, j; | |||
| 295 | const int var_count = combine->forward.var_count; | |||
| 296 | for (i = 0; i < var_count; i++) | |||
| 297 | if (combine->forward.vars[i].shape) | |||
| 298 | { | |||
| 299 | for (j = 0; j < combine->forward.vars[i].dimensions; j++) | |||
| 300 | ccv_nnc_micro_loop_index_free(&combine->forward.vars[i].shape[j]); | |||
| 301 | ccfreefree(combine->forward.vars[i].shape); | |||
| 302 | } | |||
| 303 | ccfreefree(combine->forward.vars); | |||
| 304 | ccfreefree(combine->backward.vars); | |||
| 305 | int function_count = combine->forward.function_count; | |||
| 306 | for (i = 0; i < function_count; i++) | |||
| 307 | { | |||
| 308 | const int block_count = combine->forward.functions[i].block_count; | |||
| 309 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->forward.functions[i].one_block : combine->forward.functions[i].blocks; | |||
| 310 | for (j = 0; j < block_count; j++) | |||
| 311 | { | |||
| 312 | ccv_nnc_micro_loop_block_t block = blocks[j]; | |||
| 313 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); | |||
| 314 | ccfreefree(block.loops); | |||
| 315 | } | |||
| 316 | if (block_count > 1) | |||
| 317 | ccfreefree(combine->forward.functions[i].blocks); | |||
| 318 | } | |||
| 319 | ccfreefree(combine->forward.functions); | |||
| 320 | ccfreefree(combine->forward.inputs); | |||
| 321 | // Backward and forward share the same vars. | |||
| 322 | function_count = combine->backward.function_count; | |||
| 323 | for (i = 0; i < function_count; i++) | |||
| 324 | { | |||
| 325 | const int block_count = combine->backward.functions[i].block_count; | |||
| 326 | ccv_nnc_micro_loop_block_t* const blocks = (block_count == 1) ? &combine->backward.functions[i].one_block : combine->backward.functions[i].blocks; | |||
| 327 | for (j = 0; j < block_count; j++) | |||
| 328 | { | |||
| 329 | ccv_nnc_micro_loop_block_t block = blocks[j]; | |||
| 330 | ccv_nnc_micro_loops_free(block.loops, block.loop_count); | |||
| 331 | ccfreefree(block.loops); | |||
| 332 | } | |||
| 333 | if (block_count > 1) | |||
| 334 | ccfreefree(combine->backward.functions[i].blocks); | |||
| 335 | } | |||
| 336 | ccfreefree(combine->backward.functions); | |||
| 337 | if (combine->backward.inputs) | |||
| 338 | ccfreefree(combine->backward.inputs); | |||
| 339 | ccv_array_free(combine->equal_assertions); | |||
| 340 | ccfreefree(combine); | |||
| 341 | } | |||
| 342 | ||||
| 343 | char* ccv_nnc_micro_combine_c(ccv_nnc_micro_combine_t* const combine) | |||
| 344 | { | |||
| 345 | return 0; | |||
| 346 | } |