| File: | nnc/ccv_nnc_micro_simplify.c |
| Warning: | line 1366, column 17 Array access (via field 'vals') results in a null pointer dereference |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
| 1 | #include "ccv_nnc.h" | ||||
| 2 | #include "ccv_nnc_easy.h" | ||||
| 3 | #include "ccv_nnc_internal.h" | ||||
| 4 | #include "ccv_internal.h" | ||||
| 5 | #include "_ccv_nnc_micro.h" | ||||
| 6 | #include "3rdparty/khash/khash.h" | ||||
| 7 | |||||
| 8 | #define MICRO_ID_TO_INT(x)(((x).id << 8) | ((x).d)) (((x).id << 8) | ((x).d)) | ||||
| 9 | KHASH_MAP_INIT_INT(ccv_nnc_axis_id_group, int)typedef struct kh_ccv_nnc_axis_id_group_s { khint_t n_buckets , size, n_occupied, upper_bound; khint32_t *flags; khint32_t * keys; int *vals; } kh_ccv_nnc_axis_id_group_t; static inline __attribute__ ((__unused__)) kh_ccv_nnc_axis_id_group_t *kh_init_ccv_nnc_axis_id_group (void) { return (kh_ccv_nnc_axis_id_group_t*)calloc(1,sizeof( kh_ccv_nnc_axis_id_group_t)); } static inline __attribute__ ( (__unused__)) void kh_destroy_ccv_nnc_axis_id_group(kh_ccv_nnc_axis_id_group_t *h) { if (h) { free((void *)h->keys); free(h->flags); free ((void *)h->vals); free(h); } } static inline __attribute__ ((__unused__)) void kh_clear_ccv_nnc_axis_id_group(kh_ccv_nnc_axis_id_group_t *h) { if (h && h->flags) { memset(h->flags, 0xaa , ((h->n_buckets) < 16? 1 : (h->n_buckets)>>4) * sizeof(khint32_t)); h->size = h->n_occupied = 0; } } static inline __attribute__ ((__unused__)) khint_t kh_get_ccv_nnc_axis_id_group (const kh_ccv_nnc_axis_id_group_t *h, khint32_t key) { if (h-> n_buckets) { khint_t k, i, last, mask, step = 0; mask = h-> n_buckets - 1; k = (khint32_t)(key); i = k & mask; last = i; while (!((h->flags[i>>4]>>((i&0xfU)<< 1))&2) && (((h->flags[i>>4]>>((i& 0xfU)<<1))&1) || !((h->keys[i]) == (key)))) { i = (i + (++step)) & mask; if (i == last) return h->n_buckets ; } return ((h->flags[i>>4]>>((i&0xfU)<< 1))&3)? h->n_buckets : i; } else return 0; } static inline __attribute__ ((__unused__)) int kh_resize_ccv_nnc_axis_id_group (kh_ccv_nnc_axis_id_group_t *h, khint_t new_n_buckets) { khint32_t *new_flags = 0; khint_t j = 1; { (--(new_n_buckets), (new_n_buckets )|=(new_n_buckets)>>1, (new_n_buckets)|=(new_n_buckets) >>2, (new_n_buckets)|=(new_n_buckets)>>4, (new_n_buckets )|=(new_n_buckets)>>8, (new_n_buckets)|=(new_n_buckets) >>16, ++(new_n_buckets)); if (new_n_buckets < 4) new_n_buckets = 4; if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; else { new_flags = (khint32_t*)malloc(((new_n_buckets ) < 16? 
1 : (new_n_buckets)>>4) * sizeof(khint32_t)) ; if (!new_flags) return -1; memset(new_flags, 0xaa, ((new_n_buckets ) < 16? 1 : (new_n_buckets)>>4) * sizeof(khint32_t)) ; if (h->n_buckets < new_n_buckets) { khint32_t *new_keys = (khint32_t*)realloc((void *)h->keys,new_n_buckets * sizeof (khint32_t)); if (!new_keys) { free(new_flags); return -1; } h ->keys = new_keys; if (1) { int *new_vals = (int*)realloc( (void *)h->vals,new_n_buckets * sizeof(int)); if (!new_vals ) { free(new_flags); return -1; } h->vals = new_vals; } } } } if (j) { for (j = 0; j != h->n_buckets; ++j) { if (((h-> flags[j>>4]>>((j&0xfU)<<1))&3) == 0 ) { khint32_t key = h->keys[j]; int val; khint_t new_mask; new_mask = new_n_buckets - 1; if (1) val = h->vals[j]; (h ->flags[j>>4]|=1ul<<((j&0xfU)<<1)); while (1) { khint_t k, i, step = 0; k = (khint32_t)(key); i = k & new_mask; while (!((new_flags[i>>4]>>((i&0xfU )<<1))&2)) i = (i + (++step)) & new_mask; (new_flags [i>>4]&=~(2ul<<((i&0xfU)<<1))); if ( i < h->n_buckets && ((h->flags[i>>4]>> ((i&0xfU)<<1))&3) == 0) { { khint32_t tmp = h-> keys[i]; h->keys[i] = key; key = tmp; } if (1) { int tmp = h->vals[i]; h->vals[i] = val; val = tmp; } (h->flags [i>>4]|=1ul<<((i&0xfU)<<1)); } else { h ->keys[i] = key; if (1) h->vals[i] = val; break; } } } } if (h->n_buckets > new_n_buckets) { h->keys = (khint32_t *)realloc((void *)h->keys,new_n_buckets * sizeof(khint32_t )); if (1) h->vals = (int*)realloc((void *)h->vals,new_n_buckets * sizeof(int)); } free(h->flags); h->flags = new_flags ; h->n_buckets = new_n_buckets; h->n_occupied = h->size ; h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); } return 0; } static inline __attribute__ ((__unused__ )) khint_t kh_put_ccv_nnc_axis_id_group(kh_ccv_nnc_axis_id_group_t *h, khint32_t key, int *ret) { khint_t x; if (h->n_occupied >= h->upper_bound) { if (h->n_buckets > (h->size <<1)) { if (kh_resize_ccv_nnc_axis_id_group(h, h->n_buckets - 1) < 0) { *ret = -1; return h->n_buckets; } } else if 
(kh_resize_ccv_nnc_axis_id_group(h, h->n_buckets + 1) < 0) { *ret = -1; return h->n_buckets; } } { khint_t k, i, site , last, mask = h->n_buckets - 1, step = 0; x = site = h-> n_buckets; k = (khint32_t)(key); i = k & mask; if (((h-> flags[i>>4]>>((i&0xfU)<<1))&2)) x = i; else { last = i; while (!((h->flags[i>>4]>> ((i&0xfU)<<1))&2) && (((h->flags[i>> 4]>>((i&0xfU)<<1))&1) || !((h->keys[i] ) == (key)))) { if (((h->flags[i>>4]>>((i& 0xfU)<<1))&1)) site = i; i = (i + (++step)) & mask ; if (i == last) { x = site; break; } } if (x == h->n_buckets ) { if (((h->flags[i>>4]>>((i&0xfU)<< 1))&2) && site != h->n_buckets) x = site; else x = i; } } } if (((h->flags[x>>4]>>((x&0xfU )<<1))&2)) { h->keys[x] = key; (h->flags[x>> 4]&=~(3ul<<((x&0xfU)<<1))); ++h->size; ++h->n_occupied; *ret = 1; } else if (((h->flags[x>> 4]>>((x&0xfU)<<1))&1)) { h->keys[x] = key ; (h->flags[x>>4]&=~(3ul<<((x&0xfU)<< 1))); ++h->size; *ret = 2; } else *ret = 0; return x; } static inline __attribute__ ((__unused__)) void kh_del_ccv_nnc_axis_id_group (kh_ccv_nnc_axis_id_group_t *h, khint_t x) { if (x != h->n_buckets && !((h->flags[x>>4]>>((x&0xfU)<< 1))&3)) { (h->flags[x>>4]|=1ul<<((x&0xfU )<<1)); --h->size; } } | ||||
| 10 | |||||
| 11 | static int _ccv_nnc_same_index_term(const ccv_nnc_micro_loop_index_term_t a_index, const ccv_nnc_micro_loop_index_term_t b_index, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 12 | { | ||||
| 13 | if (a_index.type != b_index.type) | ||||
| 14 | return 0; | ||||
| 15 | const int type = a_index.type; | ||||
| 16 | switch (type) | ||||
| 17 | { | ||||
| 18 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: | ||||
| 19 | return a_index.immediate_value == b_index.immediate_value; | ||||
| 20 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: | ||||
| 21 | if (a_index.id.type != b_index.id.type) | ||||
| 22 | return 0; | ||||
| 23 | // Check within the axis_id_groups to see if there is a match, if there is no match, we can proceed (to use the group table again to check). | ||||
| 24 | if (axis_id_groups && a_index.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID) | ||||
| 25 | { | ||||
| 26 | ccv_nnc_micro_id_t a_id = a_index.id; | ||||
| 27 | while (groups && groups[a_id.id] != a_id.id) | ||||
| 28 | a_id.id = groups[a_id.id]; | ||||
| 29 | int a_root = MICRO_ID_TO_INT(a_id)(((a_id).id << 8) | ((a_id).d)); | ||||
| 30 | khiter_t k; | ||||
| 31 | for (;;) { | ||||
| 32 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, a_root)kh_get_ccv_nnc_axis_id_group(axis_id_groups, a_root); | ||||
| 33 | if (k == kh_end(axis_id_groups)((axis_id_groups)->n_buckets)) | ||||
| 34 | break; | ||||
| 35 | a_root = kh_val(axis_id_groups, k)((axis_id_groups)->vals[k]); | ||||
| 36 | } | ||||
| 37 | ccv_nnc_micro_id_t b_id = b_index.id; | ||||
| 38 | while (groups && groups[b_id.id] != b_id.id) | ||||
| 39 | b_id.id = groups[b_id.id]; | ||||
| 40 | int b_root = MICRO_ID_TO_INT(b_id)(((b_id).id << 8) | ((b_id).d)); | ||||
| 41 | for (;;) { | ||||
| 42 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, b_root)kh_get_ccv_nnc_axis_id_group(axis_id_groups, b_root); | ||||
| 43 | if (k == kh_end(axis_id_groups)((axis_id_groups)->n_buckets)) | ||||
| 44 | break; | ||||
| 45 | b_root = kh_val(axis_id_groups, k)((axis_id_groups)->vals[k]); | ||||
| 46 | } | ||||
| 47 | if (a_root == b_root) | ||||
| 48 | return 1; | ||||
| 49 | } | ||||
| 50 | if (groups && (a_index.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID || a_index.id.type == CCV_NNC_MICRO_TENSOR_ID)) | ||||
| 51 | { | ||||
| 52 | if (a_index.id.d != b_index.id.d) | ||||
| 53 | return 0; | ||||
| 54 | switch (a_index.id.type) | ||||
| 55 | { | ||||
| 56 | case CCV_NNC_MICRO_TENSOR_ID: | ||||
| 57 | case CCV_NNC_MICRO_AXIS_SIZE_ID: { | ||||
| 58 | // Find their group identifier and then compare. | ||||
| 59 | int a_root = groups[a_index.id.id]; | ||||
| 60 | while (groups[a_root] != a_root) | ||||
| 61 | a_root = groups[a_root]; | ||||
| 62 | int b_root = groups[b_index.id.id]; | ||||
| 63 | while (groups[b_root] != b_root) | ||||
| 64 | b_root = groups[b_root]; | ||||
| 65 | return a_root == b_root; | ||||
| 66 | } | ||||
| 67 | } | ||||
| 68 | return a_index.id.id == b_index.id.id; | ||||
| 69 | } else | ||||
| 70 | return (a_index.id.d == b_index.id.d && a_index.id.id == b_index.id.id); | ||||
| 71 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { | ||||
| 72 | return a_index.binary->op == b_index.binary->op && _ccv_nnc_same_index_term(a_index.binary->left, b_index.binary->left, groups, axis_id_groups) && _ccv_nnc_same_index_term(a_index.binary->right, b_index.binary->right, groups, axis_id_groups); | ||||
| 73 | } | ||||
| 74 | } | ||||
| 75 | return 0; | ||||
| 76 | } | ||||
| 77 | |||||
| 78 | static int _ccv_nnc_same_shape(const ccv_nnc_micro_loop_index_term_t* const a_shape, const ccv_nnc_micro_loop_index_term_t* const b_shape, const int dimensions) | ||||
| 79 | { | ||||
| 80 | int i; | ||||
| 81 | for (i = 0; i < dimensions; i++) | ||||
| 82 | if (!_ccv_nnc_same_index_term(a_shape[i], b_shape[i], 0, 0)) | ||||
| 83 | return 0; | ||||
| 84 | return 1; | ||||
| 85 | } | ||||
| 86 | |||||
| 87 | static int _ccv_nnc_same_loop(const ccv_nnc_micro_loop_block_t* const left_block, const ccv_nnc_micro_loop_block_t* const right_block, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups, int* const left_loop_idx, int* const right_loop_idx) | ||||
| 88 | { | ||||
| 89 | assert(left_block->loop_count > 0)((void) sizeof ((left_block->loop_count > 0) ? 1 : 0), __extension__ ({ if (left_block->loop_count > 0) ; else __assert_fail ("left_block->loop_count > 0", "ccv_nnc_micro_simplify.c" , 89, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 90 | assert(right_block->loop_count > 0)((void) sizeof ((right_block->loop_count > 0) ? 1 : 0), __extension__ ({ if (right_block->loop_count > 0) ; else __assert_fail ("right_block->loop_count > 0", "ccv_nnc_micro_simplify.c" , 90, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 91 | int i, j; | ||||
| 92 | int left_right_link[left_block->loop_count]; | ||||
| 93 | int right_left_link[right_block->loop_count]; | ||||
| 94 | enum { | ||||
| 95 | ONE = -2, | ||||
| 96 | UNASSIGNED = -1, | ||||
| 97 | }; | ||||
| 98 | for (i = 0; i < left_block->loop_count; i++) | ||||
| 99 | if (left_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].start_index.immediate_value == 0 && | ||||
| 100 | left_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left_block->loops[i].end_index.immediate_value == 1) | ||||
| 101 | left_right_link[i] = ONE; | ||||
| 102 | else | ||||
| 103 | left_right_link[i] = UNASSIGNED; | ||||
| 104 | for (i = 0; i < right_block->loop_count; i++) | ||||
| 105 | if (right_block->loops[i].start_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].start_index.immediate_value == 0 && | ||||
| 106 | right_block->loops[i].end_index.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right_block->loops[i].end_index.immediate_value == 1) | ||||
| 107 | right_left_link[i] = ONE; | ||||
| 108 | else | ||||
| 109 | right_left_link[i] = UNASSIGNED; | ||||
| 110 | for (i = 0; i < left_block->loop_count; i++) // Find corresponding loop on the right. | ||||
| 111 | { | ||||
| 112 | if (left_right_link[i] != UNASSIGNED) | ||||
| 113 | continue; | ||||
| 114 | int flag = UNASSIGNED; | ||||
| 115 | for (j = 0; j < right_block->loop_count && flag == UNASSIGNED; j++) | ||||
| 116 | { | ||||
| 117 | if (right_left_link[j] != UNASSIGNED) | ||||
| 118 | continue; | ||||
| 119 | if (_ccv_nnc_same_index_term(left_block->loops[i].start_index, right_block->loops[j].start_index, groups, axis_id_groups) && | ||||
| 120 | _ccv_nnc_same_index_term(left_block->loops[i].end_index, right_block->loops[j].end_index, groups, axis_id_groups)) | ||||
| 121 | flag = j; | ||||
| 122 | } | ||||
| 123 | if (flag != UNASSIGNED) | ||||
| 124 | { | ||||
| 125 | left_right_link[i] = flag; | ||||
| 126 | right_left_link[flag] = i; | ||||
| 127 | } | ||||
| 128 | } | ||||
| 129 | // If still have unmatched, they don't share the same loop. | ||||
| 130 | for (i = 0; i < left_block->loop_count; i++) | ||||
| 131 | if (left_right_link[i] == UNASSIGNED) | ||||
| 132 | return 0; | ||||
| 133 | for (i = 0; i < right_block->loop_count; i++) | ||||
| 134 | if (right_left_link[i] == UNASSIGNED) | ||||
| 135 | return 0; | ||||
| 136 | // I don't want to deal with constant loop, hence, if other than the outer-most is a constant loop (0..<1), | ||||
| 137 | // we cannot merge. | ||||
| 138 | for (i = 1; i < left_block->loop_count; i++) | ||||
| 139 | if (left_right_link[i] == ONE) | ||||
| 140 | return 0; | ||||
| 141 | for (i = 1; i < right_block->loop_count; i++) | ||||
| 142 | if (right_left_link[i] == ONE) | ||||
| 143 | return 0; | ||||
| 144 | assert((left_block->loop_count == right_block->loop_count) ||((void) sizeof (((left_block->loop_count == right_block-> loop_count) || (left_block->loop_count == right_block-> loop_count + 1) || (left_block->loop_count + 1 == right_block ->loop_count)) ? 1 : 0), __extension__ ({ if ((left_block-> loop_count == right_block->loop_count) || (left_block-> loop_count == right_block->loop_count + 1) || (left_block-> loop_count + 1 == right_block->loop_count)) ; else __assert_fail ("(left_block->loop_count == right_block->loop_count) || (left_block->loop_count == right_block->loop_count + 1) || (left_block->loop_count + 1 == right_block->loop_count)" , "ccv_nnc_micro_simplify.c", 146, __extension__ __PRETTY_FUNCTION__ ); })) | ||||
| 145 | (left_block->loop_count == right_block->loop_count + 1) ||((void) sizeof (((left_block->loop_count == right_block-> loop_count) || (left_block->loop_count == right_block-> loop_count + 1) || (left_block->loop_count + 1 == right_block ->loop_count)) ? 1 : 0), __extension__ ({ if ((left_block-> loop_count == right_block->loop_count) || (left_block-> loop_count == right_block->loop_count + 1) || (left_block-> loop_count + 1 == right_block->loop_count)) ; else __assert_fail ("(left_block->loop_count == right_block->loop_count) || (left_block->loop_count == right_block->loop_count + 1) || (left_block->loop_count + 1 == right_block->loop_count)" , "ccv_nnc_micro_simplify.c", 146, __extension__ __PRETTY_FUNCTION__ ); })) | ||||
| 146 | (left_block->loop_count + 1 == right_block->loop_count))((void) sizeof (((left_block->loop_count == right_block-> loop_count) || (left_block->loop_count == right_block-> loop_count + 1) || (left_block->loop_count + 1 == right_block ->loop_count)) ? 1 : 0), __extension__ ({ if ((left_block-> loop_count == right_block->loop_count) || (left_block-> loop_count == right_block->loop_count + 1) || (left_block-> loop_count + 1 == right_block->loop_count)) ; else __assert_fail ("(left_block->loop_count == right_block->loop_count) || (left_block->loop_count == right_block->loop_count + 1) || (left_block->loop_count + 1 == right_block->loop_count)" , "ccv_nnc_micro_simplify.c", 146, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 147 | // The loop matches, but the ordering probably doesn't. We reorder loop based on statements. | ||||
| 148 | // Hence, two loops can only merge if using the statements as a pivot point and they can still | ||||
| 149 | // match things before / after the statement. | ||||
| 150 | // If both have statements, check if order preserving within the statement loop (we can be fancier | ||||
| 151 | // and recursively call this while using statement as pivoting point, but that is too much to my taste). | ||||
| 152 | const int left_start_idx = left_right_link[0] == ONE ? 1 : 0; | ||||
| 153 | const int right_start_idx = right_left_link[0] == ONE ? 1 : 0; | ||||
| 154 | for (i = 0; i < left_block->loop_count; i++) | ||||
| 155 | left_loop_idx[i] = UNASSIGNED; | ||||
| 156 | for (i = 0; i < right_block->loop_count; i++) | ||||
| 157 | right_loop_idx[i] = UNASSIGNED; | ||||
| 158 | if (left_start_idx == 1) | ||||
| 159 | left_loop_idx[0] = 0; // Assign their index. | ||||
| 160 | if (right_start_idx == 0) | ||||
| 161 | right_loop_idx[0] = 0; // Assign their index. | ||||
| 162 | const int end_idx = left_right_link[0] == ONE && right_left_link[0] == ONE ? left_block->loop_count - 1 : ccv_min(left_block->loop_count, right_block->loop_count)({ typeof (left_block->loop_count) _a = (left_block->loop_count ); typeof (right_block->loop_count) _b = (right_block-> loop_count); (_a < _b) ? _a : _b; }); | ||||
| 163 | int pivot_idx = end_idx; | ||||
| 164 | int k; | ||||
| 165 | for (i = end_idx - 1; i >= 0; i--) | ||||
| 166 | { | ||||
| 167 | if (left_block->loops[i + left_start_idx].statement_count > 0) | ||||
| 168 | { | ||||
| 169 | for (j = i + 1, k = i + 1; j < end_idx; j++) | ||||
| 170 | if (left_loop_idx[j + left_start_idx] == UNASSIGNED) | ||||
| 171 | { | ||||
| 172 | left_loop_idx[j + left_start_idx] = k + left_start_idx; | ||||
| 173 | // If the right one can be referenced pass previous pivot_idx, it is not right. | ||||
| 174 | if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx) | ||||
| 175 | return 0; | ||||
| 176 | right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx; | ||||
| 177 | ++k; | ||||
| 178 | if (k > pivot_idx) | ||||
| 179 | return 0; | ||||
| 180 | } | ||||
| 181 | assert(k == pivot_idx)((void) sizeof ((k == pivot_idx) ? 1 : 0), __extension__ ({ if (k == pivot_idx) ; else __assert_fail ("k == pivot_idx", "ccv_nnc_micro_simplify.c" , 181, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 182 | pivot_idx = i + 1; | ||||
| 183 | } | ||||
| 184 | if (right_block->loops[i + right_start_idx].statement_count > 0) | ||||
| 185 | { | ||||
| 186 | for (j = i + 1, k = i + 1; j < end_idx; j++) | ||||
| 187 | if (right_loop_idx[j + left_start_idx] == UNASSIGNED) | ||||
| 188 | { | ||||
| 189 | right_loop_idx[j + right_start_idx] = k + right_start_idx; | ||||
| 190 | // If the left one can be referenced pass previous pivot_idx, it is not right. | ||||
| 191 | if (right_left_link[j + right_start_idx] >= pivot_idx + left_start_idx) | ||||
| 192 | return 0; | ||||
| 193 | left_loop_idx[right_left_link[j + right_start_idx]] = k + left_start_idx; | ||||
| 194 | ++k; | ||||
| 195 | if (k > pivot_idx) | ||||
| 196 | return 0; | ||||
| 197 | } | ||||
| 198 | assert(k == pivot_idx)((void) sizeof ((k == pivot_idx) ? 1 : 0), __extension__ ({ if (k == pivot_idx) ; else __assert_fail ("k == pivot_idx", "ccv_nnc_micro_simplify.c" , 198, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 199 | pivot_idx = i + 1; | ||||
| 200 | } | ||||
| 201 | } | ||||
| 202 | if (end_idx == 0) | ||||
| 203 | return 1; | ||||
| 204 | // Finally, to distribute the rest. | ||||
| 205 | for (j = 0, k = 0; j < end_idx; j++) | ||||
| 206 | { | ||||
| 207 | if (left_loop_idx[j + left_start_idx] == UNASSIGNED) | ||||
| 208 | { | ||||
| 209 | left_loop_idx[j + left_start_idx] = k + left_start_idx; | ||||
| 210 | // If the right one can be referenced pass previous pivot_idx, it is not right. | ||||
| 211 | if (left_right_link[j + left_start_idx] >= pivot_idx + right_start_idx) | ||||
| 212 | return 0; | ||||
| 213 | right_loop_idx[left_right_link[j + left_start_idx]] = k + right_start_idx; | ||||
| 214 | ++k; | ||||
| 215 | if (k > pivot_idx) | ||||
| 216 | return 0; | ||||
| 217 | } | ||||
| 218 | } | ||||
| 219 | assert(k == pivot_idx)((void) sizeof ((k == pivot_idx) ? 1 : 0), __extension__ ({ if (k == pivot_idx) ; else __assert_fail ("k == pivot_idx", "ccv_nnc_micro_simplify.c" , 219, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 220 | return 1; | ||||
| 221 | } | ||||
| 222 | |||||
| 223 | static void _ccv_nnc_loop_order_by(ccv_nnc_micro_loop_block_t* const block, int* const loop_idx, ccv_nnc_micro_loop_t* const loops) | ||||
| 224 | { | ||||
| 225 | int i; | ||||
| 226 | for (i = 0; i < block->loop_count; i++) | ||||
| 227 | if (loop_idx[i] >= 0) | ||||
| 228 | loops[loop_idx[i]] = block->loops[i]; | ||||
| 229 | else | ||||
| 230 | loops[i] = block->loops[i]; | ||||
| 231 | for (i = 0; i < block->loop_count; i++) | ||||
| 232 | { | ||||
| 233 | // Essentially, we don't need to move statements, loop-carried variables, just the loop id and the start / end index. | ||||
| 234 | block->loops[i].id = loops[i].id; | ||||
| 235 | block->loops[i].start_index = loops[i].start_index; | ||||
| 236 | block->loops[i].end_index = loops[i].end_index; | ||||
| 237 | } | ||||
| 238 | } | ||||
| 239 | |||||
| 240 | static void _ccv_nnc_expression_rename_carrieds(ccv_nnc_micro_loop_expression_t* const expression, const int start_idx) | ||||
| 241 | { | ||||
| 242 | switch (expression->type) | ||||
| 243 | { | ||||
| 244 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: | ||||
| 245 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID)((void) sizeof ((expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ? 1 : 0), __extension__ ({ if (expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ; else __assert_fail ("expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID" , "ccv_nnc_micro_simplify.c", 245, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 246 | expression->id.id += start_idx; | ||||
| 247 | break; | ||||
| 248 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: | ||||
| 249 | _ccv_nnc_expression_rename_carrieds(expression->ternary.pivot, start_idx); | ||||
| 250 | _ccv_nnc_expression_rename_carrieds(expression->ternary.left, start_idx); | ||||
| 251 | _ccv_nnc_expression_rename_carrieds(expression->ternary.right, start_idx); | ||||
| 252 | break; | ||||
| 253 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: | ||||
| 254 | _ccv_nnc_expression_rename_carrieds(expression->binary.left, start_idx); | ||||
| 255 | _ccv_nnc_expression_rename_carrieds(expression->binary.right, start_idx); | ||||
| 256 | break; | ||||
| 257 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: | ||||
| 258 | _ccv_nnc_expression_rename_carrieds(expression->unary.x, start_idx); | ||||
| 259 | break; | ||||
| 260 | // We don't need to care about other expressions because loop-carried variable cannot participate these operations. | ||||
| 261 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: | ||||
| 262 | break; | ||||
| 263 | } | ||||
| 264 | } | ||||
| 265 | |||||
| 266 | static void _ccv_nnc_loop_rename_carrieds(ccv_nnc_micro_loop_block_t* const block, const int start_idx) | ||||
| 267 | { | ||||
| 268 | int i, j; | ||||
| 269 | const int loop_count = block->loop_count; | ||||
| 270 | ccv_nnc_micro_loop_t* const loops = block->loops; | ||||
| 271 | for (i = 0; i < loop_count; i++) | ||||
| 272 | { | ||||
| 273 | for (j = 0; j < loops[i].carried_count; j++) | ||||
| 274 | loops[i].carrieds[j].id.id += start_idx; | ||||
| 275 | for (j = 0; j < loops[i].statement_count; j++) | ||||
| 276 | switch (loops[i].statements[j].type) | ||||
| 277 | { | ||||
| 278 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: | ||||
| 279 | _ccv_nnc_expression_rename_carrieds(&loops[i].statements[j].compound_assignment.rvalue, start_idx); | ||||
| 280 | break; | ||||
| 281 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: | ||||
| 282 | if (loops[i].statements[j].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID) | ||||
| 283 | { | ||||
| 284 | assert(loops[i].statements[j].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID)((void) sizeof ((loops[i].statements[j].compound_assignment.lvalue .id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID) ? 1 : 0), __extension__ ({ if (loops[i].statements[j].compound_assignment.lvalue.id. type == CCV_NNC_MICRO_LOOP_CARRIED_ID) ; else __assert_fail ( "loops[i].statements[j].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID" , "ccv_nnc_micro_simplify.c", 284, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 285 | loops[i].statements[j].compound_assignment.lvalue.id.id += start_idx; | ||||
| 286 | } | ||||
| 287 | _ccv_nnc_expression_rename_carrieds(&loops[i].statements[j].compound_assignment.rvalue, start_idx); | ||||
| 288 | break; | ||||
| 289 | } | ||||
| 290 | } | ||||
| 291 | } | ||||
| 292 | |||||
| 293 | static int _ccv_nnc_only_var_in_expression(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_expression_t* const expression, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 294 | { | ||||
| 295 | switch (expression->type) | ||||
| 296 | { | ||||
| 297 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: | ||||
| 298 | if (expression->variable.id.type == CCV_NNC_MICRO_TENSOR_ID && expression->variable.id.id == id) | ||||
| 299 | { | ||||
| 300 | if (var.index_count != expression->variable.index_count) | ||||
| 301 | return 2; | ||||
| 302 | int i; | ||||
| 303 | for (i = 0; i < var.index_count; i++) | ||||
| 304 | if (!_ccv_nnc_same_index_term(var.index[i], expression->variable.index[i], groups, axis_id_groups)) | ||||
| 305 | return 2; | ||||
| 306 | return 1; | ||||
| 307 | } else | ||||
| 308 | return 0; | ||||
| 309 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { | ||||
| 310 | const int pivot = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.pivot, groups, axis_id_groups); | ||||
| 311 | const int left = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.left, groups, axis_id_groups); | ||||
| 312 | const int right = _ccv_nnc_only_var_in_expression(id, var, expression->ternary.right, groups, axis_id_groups); | ||||
| 313 | if (pivot == 2 || left == 2 || right == 2) | ||||
| 314 | return 2; | ||||
| 315 | return (pivot || left || right); | ||||
| 316 | } | ||||
| 317 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { | ||||
| 318 | const int left = _ccv_nnc_only_var_in_expression(id, var, expression->binary.left, groups, axis_id_groups); | ||||
| 319 | const int right = _ccv_nnc_only_var_in_expression(id, var, expression->binary.right, groups, axis_id_groups); | ||||
| 320 | if (left == 2 || right == 2) | ||||
| 321 | return 2; | ||||
| 322 | return (left || right); | ||||
| 323 | } | ||||
| 324 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: | ||||
| 325 | return _ccv_nnc_only_var_in_expression(id, var, expression->unary.x, groups, axis_id_groups); | ||||
| 326 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: | ||||
| 327 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID)((void) sizeof ((expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ? 1 : 0), __extension__ ({ if (expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ; else __assert_fail ("expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID" , "ccv_nnc_micro_simplify.c", 327, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 328 | return 0; | ||||
| 329 | } | ||||
| 330 | return 0; | ||||
| 331 | } | ||||
| 332 | |||||
| 333 | static int _ccv_nnc_only_var_in_rvalue(const int id, const ccv_nnc_micro_loop_variable_t var, const ccv_nnc_micro_loop_statement_t statement, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 334 | { | ||||
| 335 | switch (statement.type) | ||||
| 336 | { | ||||
| 337 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: | ||||
| 338 | return _ccv_nnc_only_var_in_expression(id, var, &statement.assignment.rvalue, groups, axis_id_groups); | ||||
| 339 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: | ||||
| 340 | return _ccv_nnc_only_var_in_expression(id, var, &statement.compound_assignment.rvalue, groups, axis_id_groups); | ||||
| 341 | } | ||||
| 342 | return 0; | ||||
| 343 | } | ||||
| 344 | |||||
| 345 | static ccv_nnc_micro_loop_expression_t _ccv_nnc_expression_deep_copy(const ccv_nnc_micro_loop_expression_t* const expression) | ||||
| 346 | { | ||||
| 347 | switch (expression->type) | ||||
| 348 | { | ||||
| 349 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: { | ||||
| 350 | ccv_nnc_micro_loop_expression_t copy = *expression; | ||||
| 351 | copy.ternary.pivot = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 352 | *copy.ternary.pivot = _ccv_nnc_expression_deep_copy(expression->ternary.pivot); | ||||
| 353 | copy.ternary.left = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 354 | *copy.ternary.left = _ccv_nnc_expression_deep_copy(expression->ternary.left); | ||||
| 355 | copy.ternary.right = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 356 | *copy.ternary.right = _ccv_nnc_expression_deep_copy(expression->ternary.right); | ||||
| 357 | return copy; | ||||
| 358 | } | ||||
| 359 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: { | ||||
| 360 | ccv_nnc_micro_loop_expression_t copy = *expression; | ||||
| 361 | copy.binary.left = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 362 | *copy.binary.left = _ccv_nnc_expression_deep_copy(expression->binary.left); | ||||
| 363 | copy.binary.right = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 364 | *copy.binary.right = _ccv_nnc_expression_deep_copy(expression->binary.right); | ||||
| 365 | return copy; | ||||
| 366 | } | ||||
| 367 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: { | ||||
| 368 | ccv_nnc_micro_loop_expression_t copy = *expression; | ||||
| 369 | copy.unary.x = (ccv_nnc_micro_loop_expression_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_expression_t)); | ||||
| 370 | *copy.unary.x = _ccv_nnc_expression_deep_copy(expression->unary.x); | ||||
| 371 | return copy; | ||||
| 372 | } | ||||
| 373 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: { | ||||
| 374 | ccv_nnc_micro_loop_expression_t copy = *expression; | ||||
| 375 | int i; | ||||
| 376 | for (i = 0; i < copy.variable.index_count; i++) | ||||
| 377 | copy.variable.index[i] = ccv_nnc_micro_loop_index_deep_copy(©.variable.index[i]); | ||||
| 378 | return copy; | ||||
| 379 | } | ||||
| 380 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: | ||||
| 381 | return *expression; | ||||
| 382 | } | ||||
| 383 | return *expression; | ||||
| 384 | } | ||||
| 385 | |||||
| 386 | static void _ccv_nnc_replacing_id_in_expression(ccv_nnc_micro_loop_expression_t* const expression, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count) | ||||
| 387 | { | ||||
| 388 | switch (expression->type) | ||||
| 389 | { | ||||
| 390 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: | ||||
| 391 | if (expression->variable.id.type == CCV_NNC_MICRO_TENSOR_ID && expression->variable.id.id == id) | ||||
| 392 | { | ||||
| 393 | ccv_nnc_micro_loop_variable_free(&expression->variable); | ||||
| 394 | if (*count == 0) // First time, just assign to expression. | ||||
| 395 | *expression = rvalue; | ||||
| 396 | else // Otherwise, need to make deep copy of it. | ||||
| 397 | *expression = _ccv_nnc_expression_deep_copy(&rvalue); | ||||
| 398 | ++(*count); | ||||
| 399 | } | ||||
| 400 | break; | ||||
| 401 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: | ||||
| 402 | _ccv_nnc_replacing_id_in_expression(expression->ternary.pivot, id, rvalue, count); | ||||
| 403 | _ccv_nnc_replacing_id_in_expression(expression->ternary.left, id, rvalue, count); | ||||
| 404 | _ccv_nnc_replacing_id_in_expression(expression->ternary.right, id, rvalue, count); | ||||
| 405 | break; | ||||
| 406 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: | ||||
| 407 | _ccv_nnc_replacing_id_in_expression(expression->binary.left, id, rvalue, count); | ||||
| 408 | _ccv_nnc_replacing_id_in_expression(expression->binary.right, id, rvalue, count); | ||||
| 409 | break; | ||||
| 410 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: | ||||
| 411 | _ccv_nnc_replacing_id_in_expression(expression->unary.x, id, rvalue, count); | ||||
| 412 | break; | ||||
| 413 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: | ||||
| 414 | assert(expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID)((void) sizeof ((expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ? 1 : 0), __extension__ ({ if (expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ; else __assert_fail ("expression->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID" , "ccv_nnc_micro_simplify.c", 414, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 415 | break; | ||||
| 416 | } | ||||
| 417 | } | ||||
| 418 | |||||
| 419 | static void _ccv_nnc_replacing_id_in_rvalue(ccv_nnc_micro_loop_statement_t* const statement, const int id, ccv_nnc_micro_loop_expression_t rvalue, int* const count) | ||||
| 420 | { | ||||
| 421 | switch (statement->type) | ||||
| 422 | { | ||||
| 423 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: | ||||
| 424 | _ccv_nnc_replacing_id_in_expression(&statement->assignment.rvalue, id, rvalue, count); | ||||
| 425 | break; | ||||
| 426 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: | ||||
| 427 | // Not going to be in lvalue (which is the carried variable only). | ||||
| 428 | _ccv_nnc_replacing_id_in_expression(&statement->compound_assignment.rvalue, id, rvalue, count); | ||||
| 429 | break; | ||||
| 430 | } | ||||
| 431 | } | ||||
| 432 | |||||
| 433 | typedef struct { /* Per-block dependency record: which tensors a loop block writes and reads. */ | ||||
| 434 | int flag; /* Non-zero when the block is in use; the merge checks skip blocks whose flag is 0. Set outside this view. */ | ||||
| 435 | int merge_to; /* Index of the block this one was merged into; equals the block's own index while unmerged. */ | ||||
| 436 | ccv_array_t* writes; /* int tensor ids assigned by this block; NULL until the first write is recorded. */ | ||||
| 437 | ccv_array_t* reads; /* int tensor ids read by this block; NULL until the first read is recorded. */ | ||||
| 438 | } ccv_nnc_micro_loop_block_dependency_t; | ||||
| 439 | |||||
| 440 | typedef struct { /* Per-tensor dependency record: which loop blocks write and read a tensor. */ | ||||
| 441 | int flag; /* NOTE(review): presumably marks the tensor as in use; it is not set or read in this part of the file -- confirm. */ | ||||
| 442 | ccv_array_t* writes; /* int block indices that assign this tensor; NULL until the first write is recorded. */ | ||||
| 443 | ccv_array_t* reads; /* int block indices that read this tensor; NULL until the first read is recorded. */ | ||||
| 444 | } ccv_nnc_micro_tensor_dependency_t; | ||||
| 445 | |||||
| 446 | static void _ccv_nnc_micro_block_dependencies_from_rvalue(const ccv_nnc_micro_loop_expression_t* const rvalue, const int i, ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies) | ||||
| 447 | { | ||||
| 448 | switch (rvalue->type) | ||||
| 449 | { | ||||
| 450 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: | ||||
| 451 | if (rvalue->variable.id.type == CCV_NNC_MICRO_TENSOR_ID) | ||||
| 452 | { | ||||
| 453 | if (!block_dependencies[i].reads) | ||||
| 454 | block_dependencies[i].reads = ccv_array_new(sizeof(int), 1, 0); | ||||
| 455 | const int id = rvalue->variable.id.id; | ||||
| 456 | ccv_array_add_unique_int(block_dependencies[i].reads, id); | ||||
| 457 | if (!tensor_dependencies[id].reads) | ||||
| 458 | tensor_dependencies[id].reads = ccv_array_new(sizeof(int), 1, 0); | ||||
| 459 | ccv_array_add_unique_int(tensor_dependencies[id].reads, i); | ||||
| 460 | } | ||||
| 461 | break; | ||||
| 462 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: | ||||
| 463 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.pivot, i, block_dependencies, tensor_dependencies); | ||||
| 464 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.left, i, block_dependencies, tensor_dependencies); | ||||
| 465 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->ternary.right, i, block_dependencies, tensor_dependencies); | ||||
| 466 | break; | ||||
| 467 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: | ||||
| 468 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.left, i, block_dependencies, tensor_dependencies); | ||||
| 469 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->binary.right, i, block_dependencies, tensor_dependencies); | ||||
| 470 | break; | ||||
| 471 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: | ||||
| 472 | _ccv_nnc_micro_block_dependencies_from_rvalue(rvalue->unary.x, i, block_dependencies, tensor_dependencies); | ||||
| 473 | break; | ||||
| 474 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_ID: | ||||
| 475 | assert(rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID)((void) sizeof ((rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ? 1 : 0), __extension__ ({ if (rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID ) ; else __assert_fail ("rvalue->id.type == CCV_NNC_MICRO_LOOP_CARRIED_ID" , "ccv_nnc_micro_simplify.c", 475, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 476 | break; | ||||
| 477 | } | ||||
| 478 | } | ||||
| 479 | |||||
| 480 | static void _ccv_nnc_micro_block_dependencies(const ccv_nnc_micro_loop_block_t* const blocks, const int block_size, const int var_count, ccv_nnc_micro_loop_block_dependency_t** const block_dependencies_ref, ccv_nnc_micro_tensor_dependency_t** const tensor_dependencies_ref) | ||||
| 481 | { | ||||
| 482 | ccv_nnc_micro_loop_block_dependency_t* const block_dependencies = (ccv_nnc_micro_loop_block_dependency_t*)cccalloccalloc(block_size, sizeof(ccv_nnc_micro_loop_block_dependency_t)); | ||||
| 483 | ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies = (ccv_nnc_micro_tensor_dependency_t*)cccalloccalloc(var_count, sizeof(ccv_nnc_micro_tensor_dependency_t)); | ||||
| 484 | int i, j, k; | ||||
| 485 | for (i = 0; i < block_size; i++) | ||||
| 486 | { | ||||
| 487 | block_dependencies[i].merge_to = i; | ||||
| 488 | const ccv_nnc_micro_loop_t* const loops = blocks[i].loops; | ||||
| 489 | const int loop_count = blocks[i].loop_count; | ||||
| 490 | for (j = 0; j < loop_count; j++) | ||||
| 491 | { | ||||
| 492 | const ccv_nnc_micro_loop_statement_t* const statements = loops[j].statements; | ||||
| 493 | const int statement_count = loops[j].statement_count; | ||||
| 494 | for (k = 0; k < statement_count; k++) | ||||
| 495 | switch (statements[k].type) | ||||
| 496 | { | ||||
| 497 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: { | ||||
| 498 | assert(statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID)((void) sizeof ((statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID ) ? 1 : 0), __extension__ ({ if (statements[k].assignment.lvalue .id.type == CCV_NNC_MICRO_TENSOR_ID) ; else __assert_fail ("statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID" , "ccv_nnc_micro_simplify.c", 498, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 499 | const int id = statements[k].assignment.lvalue.id.id; | ||||
| 500 | if (!block_dependencies[i].writes) | ||||
| 501 | block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0); | ||||
| 502 | ccv_array_add_unique_int(block_dependencies[i].writes, id); | ||||
| 503 | if (!tensor_dependencies[id].writes) | ||||
| 504 | tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0); | ||||
| 505 | ccv_array_add_unique_int(tensor_dependencies[id].writes, i); | ||||
| 506 | _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].assignment.rvalue, i, block_dependencies, tensor_dependencies); | ||||
| 507 | break; | ||||
| 508 | } | ||||
| 509 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: { | ||||
| 510 | if (statements[k].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | ||||
| 511 | { | ||||
| 512 | assert(statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID)((void) sizeof ((statements[k].compound_assignment.lvalue.id. type == CCV_NNC_MICRO_TENSOR_ID) ? 1 : 0), __extension__ ({ if (statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID ) ; else __assert_fail ("statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID" , "ccv_nnc_micro_simplify.c", 512, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 513 | const int id = statements[k].compound_assignment.lvalue.id.id; | ||||
| 514 | if (!block_dependencies[i].writes) | ||||
| 515 | block_dependencies[i].writes = ccv_array_new(sizeof(int), 1, 0); | ||||
| 516 | ccv_array_add_unique_int(block_dependencies[i].writes, id); | ||||
| 517 | if (!tensor_dependencies[id].writes) | ||||
| 518 | tensor_dependencies[id].writes = ccv_array_new(sizeof(int), 1, 0); | ||||
| 519 | ccv_array_add_unique_int(tensor_dependencies[id].writes, i); | ||||
| 520 | if (!block_dependencies[i].reads) | ||||
| 521 | block_dependencies[i].reads = ccv_array_new(sizeof(int), 1, 0); | ||||
| 522 | ccv_array_add_unique_int(block_dependencies[i].reads, id); | ||||
| 523 | if (!tensor_dependencies[id].reads) | ||||
| 524 | tensor_dependencies[id].reads = ccv_array_new(sizeof(int), 1, 0); | ||||
| 525 | ccv_array_add_unique_int(tensor_dependencies[id].reads, i); | ||||
| 526 | } | ||||
| 527 | _ccv_nnc_micro_block_dependencies_from_rvalue(&statements[k].compound_assignment.rvalue, i, block_dependencies, tensor_dependencies); | ||||
| 528 | break; | ||||
| 529 | } | ||||
| 530 | } | ||||
| 531 | } | ||||
| 532 | } | ||||
| 533 | *block_dependencies_ref = block_dependencies; | ||||
| 534 | *tensor_dependencies_ref = tensor_dependencies; | ||||
| 535 | } | ||||
| 536 | |||||
| 537 | static void _ccv_nnc_micro_dependencies_free(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const int block_size, ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int var_count) | ||||
| 538 | { | ||||
| 539 | int i; | ||||
| 540 | for (i = 0; i < block_size; i++) | ||||
| 541 | { | ||||
| 542 | if (block_dependencies[i].writes) | ||||
| 543 | ccv_array_free(block_dependencies[i].writes); | ||||
| 544 | if (block_dependencies[i].reads) | ||||
| 545 | ccv_array_free(block_dependencies[i].reads); | ||||
| 546 | } | ||||
| 547 | ccfreefree(block_dependencies); | ||||
| 548 | for (i = 0; i < var_count; i++) | ||||
| 549 | { | ||||
| 550 | if (tensor_dependencies[i].writes) | ||||
| 551 | ccv_array_free(tensor_dependencies[i].writes); | ||||
| 552 | if (tensor_dependencies[i].reads) | ||||
| 553 | ccv_array_free(tensor_dependencies[i].reads); | ||||
| 554 | } | ||||
| 555 | ccfreefree(tensor_dependencies); | ||||
| 556 | } | ||||
| 557 | |||||
| 558 | static int _ccv_nnc_tensor_reads_in_y_from_writes_after_x(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y) | ||||
| 559 | { | ||||
| 560 | int i, j; | ||||
| 561 | int flag = 0; | ||||
| 562 | for (i = 0; !flag && i < block_dependencies[y].reads->rnum; i++) | ||||
| 563 | { | ||||
| 564 | const int read_idx = *(int*)ccv_array_get(block_dependencies[y].reads, i)((void*)(((char*)((block_dependencies[y].reads)->data)) + ( size_t)(block_dependencies[y].reads)->rsize * (size_t)(i)) ); | ||||
| 565 | if (tensor_dependencies[read_idx].writes) | ||||
| 566 | for (j = 0; !flag && j < tensor_dependencies[read_idx].writes->rnum; j++) | ||||
| 567 | { | ||||
| 568 | int block_idx = *(int*)ccv_array_get(tensor_dependencies[read_idx].writes, j)((void*)(((char*)((tensor_dependencies[read_idx].writes)-> data)) + (size_t)(tensor_dependencies[read_idx].writes)->rsize * (size_t)(j))); | ||||
| 569 | while (block_idx != block_dependencies[block_idx].merge_to) | ||||
| 570 | block_idx = block_dependencies[block_idx].merge_to; | ||||
| 571 | if (!block_dependencies[block_idx].flag) // Not in use, continue. | ||||
| 572 | continue; | ||||
| 573 | assert(block_idx <= y)((void) sizeof ((block_idx <= y) ? 1 : 0), __extension__ ( { if (block_idx <= y) ; else __assert_fail ("block_idx <= y" , "ccv_nnc_micro_simplify.c", 573, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 574 | // If the block_idx is between i and j (and not neither). We cannot merge. | ||||
| 575 | if (block_idx > x && block_idx != y) | ||||
| 576 | flag = block_idx; | ||||
| 577 | } | ||||
| 578 | } | ||||
| 579 | return flag; | ||||
| 580 | } | ||||
| 581 | |||||
| 582 | static int _ccv_nnc_tensor_writes_in_x_reads_before_y(const ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, const int x, const int y) | ||||
| 583 | { | ||||
| 584 | int i, j; | ||||
| 585 | int flag = 0; | ||||
| 586 | for (i = 0; !flag && i < block_dependencies[x].writes->rnum; i++) | ||||
| 587 | { | ||||
| 588 | const int write_idx = *(int*)ccv_array_get(block_dependencies[x].writes, i)((void*)(((char*)((block_dependencies[x].writes)->data)) + (size_t)(block_dependencies[x].writes)->rsize * (size_t)( i))); | ||||
| 589 | if (tensor_dependencies[write_idx].reads) | ||||
| 590 | for (j = 0; !flag && j < tensor_dependencies[write_idx].reads->rnum; j++) | ||||
| 591 | { | ||||
| 592 | int block_idx = *(int*)ccv_array_get(tensor_dependencies[write_idx].reads, j)((void*)(((char*)((tensor_dependencies[write_idx].reads)-> data)) + (size_t)(tensor_dependencies[write_idx].reads)->rsize * (size_t)(j))); | ||||
| 593 | while (block_idx != block_dependencies[block_idx].merge_to) | ||||
| 594 | block_idx = block_dependencies[block_idx].merge_to; | ||||
| 595 | if (!block_dependencies[block_idx].flag) // Not in use, continue. | ||||
| 596 | continue; | ||||
| 597 | assert(block_idx >= x)((void) sizeof ((block_idx >= x) ? 1 : 0), __extension__ ( { if (block_idx >= x) ; else __assert_fail ("block_idx >= x" , "ccv_nnc_micro_simplify.c", 597, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 598 | // If the block_idx is between i and j (and not neither). We cannot merge. | ||||
| 599 | if (block_idx < y && block_idx != x) | ||||
| 600 | flag = block_idx; | ||||
| 601 | } | ||||
| 602 | } | ||||
| 603 | return flag; | ||||
| 604 | } | ||||
| 605 | |||||
| 606 | static void _ccv_nnc_tensor_remove_dead_store(const ccv_nnc_micro_tensor_dependency_t* const tensor_dependency, const int tensor_idx, ccv_array_t* const blocks) | ||||
| 607 | { | ||||
| 608 | int i, j, k, l;; | ||||
| 609 | if (tensor_dependency->writes) | ||||
| 610 | for (i = 0; i < tensor_dependency->writes->rnum; i++) | ||||
| 611 | { | ||||
| 612 | const int write_idx = *(int*)ccv_array_get(tensor_dependency->writes, i)((void*)(((char*)((tensor_dependency->writes)->data)) + (size_t)(tensor_dependency->writes)->rsize * (size_t)( i))); | ||||
| 613 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, write_idx)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(write_idx))); | ||||
| 614 | int flag = 0; | ||||
| 615 | for (j = 0; j < block->loop_count; j++) | ||||
| 616 | { | ||||
| 617 | ccv_nnc_micro_loop_statement_t* const statements = block->loops[j].statements; | ||||
| 618 | for (k = 0, l = 0; k < block->loops[j].statement_count; k++) | ||||
| 619 | { | ||||
| 620 | // It cannot be compound assignment, in this case, this tensor will be in read, and | ||||
| 621 | // it will be in active use (anything "read" in an active block will be marked as in use). | ||||
| 622 | assert(!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT &&((void) sizeof ((!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ? 1 : 0), __extension__ ({ if ( !(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ; else __assert_fail ("!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment.lvalue.id.id == tensor_idx)" , "ccv_nnc_micro_simplify.c", 624, __extension__ __PRETTY_FUNCTION__ ); })) | ||||
| 623 | statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID &&((void) sizeof ((!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ? 1 : 0), __extension__ ({ if ( !(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ; else __assert_fail ("!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment.lvalue.id.id == tensor_idx)" , "ccv_nnc_micro_simplify.c", 624, __extension__ __PRETTY_FUNCTION__ ); })) | ||||
| 624 | statements[k].compound_assignment.lvalue.id.id == tensor_idx))((void) sizeof ((!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ? 1 : 0), __extension__ ({ if ( !(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment .lvalue.id.id == tensor_idx)) ; else __assert_fail ("!(statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && statements[k].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && statements[k].compound_assignment.lvalue.id.id == tensor_idx)" , "ccv_nnc_micro_simplify.c", 624, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 625 | if (statements[k].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT && | ||||
| 626 | statements[k].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && | ||||
| 627 | statements[k].assignment.lvalue.id.id == tensor_idx) | ||||
| 628 | { | ||||
| 629 | // This is a dead store, prepare to remove. | ||||
| 630 | ccv_nnc_micro_loop_statement_free(&statements[k]); | ||||
| 631 | } else { | ||||
| 632 | statements[l] = statements[k]; | ||||
| 633 | ++l; | ||||
| 634 | } | ||||
| 635 | } | ||||
| 636 | if (l < block->loops[j].statement_count) | ||||
| 637 | { | ||||
| 638 | if (l > 0) | ||||
| 639 | block->loops[j].statements = (ccv_nnc_micro_loop_statement_t*)ccreallocrealloc(block->loops[j].statements, sizeof(ccv_nnc_micro_loop_statement_t) * l); | ||||
| 640 | else { | ||||
| 641 | ccfreefree(block->loops[j].statements); | ||||
| 642 | block->loops[j].statements = 0; | ||||
| 643 | } | ||||
| 644 | block->loops[j].statement_count = 0; | ||||
| 645 | } | ||||
| 646 | if (block->loops[j].statement_count > 0) | ||||
| 647 | flag = 1; | ||||
| 648 | } | ||||
| 649 | if (!flag) // No statement for this block, remove this whole block. | ||||
| 650 | { | ||||
| 651 | ccv_nnc_micro_loops_free(block->loops, block->loop_count); | ||||
| 652 | ccfreefree(block->loops); | ||||
| 653 | block->loops = 0; | ||||
| 654 | block->loop_count = 0; | ||||
| 655 | } | ||||
| 656 | } | ||||
| 657 | } | ||||
| 658 | |||||
| 659 | static void _ccv_nnc_loop_merging(ccv_nnc_micro_loop_block_dependency_t* const block_dependencies, const ccv_nnc_micro_tensor_dependency_t* const tensor_dependencies, ccv_array_t* const blocks, const int max_loop_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 660 | { | ||||
| 661 | int i, j; | ||||
| 662 | int left_loop_idx[max_loop_count]; | ||||
| 663 | int right_loop_idx[max_loop_count]; | ||||
| 664 | ccv_nnc_micro_loop_t loops[max_loop_count]; | ||||
| 665 | // Merge loops from blocks. | ||||
| 666 | for (i = 0; i < blocks->rnum - 1; i++) | ||||
| 667 | { | ||||
| 668 | ccv_nnc_micro_loop_block_t* const left_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(i))); | ||||
| 669 | if (left_block->loop_count == 0) | ||||
| 670 | continue; | ||||
| 671 | for (j = i + 1; j < blocks->rnum; j++) | ||||
| 672 | { | ||||
| 673 | // We always merge from right block to left block. Thus, the right block will always be | ||||
| 674 | // in the original form. | ||||
| 675 | ccv_nnc_micro_loop_block_t* const right_block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(j))); | ||||
| 676 | if (right_block->loop_count == 0) | ||||
| 677 | continue; | ||||
| 678 | int merge_to_right = 0; | ||||
| 679 | // First check whether between left and right, there are any blocks that the right block | ||||
| 680 | // depends on. If that is the case, we cannot merge the right block into the left block. | ||||
| 681 | if (j > i + 1 && block_dependencies[j].reads) | ||||
| 682 | { | ||||
| 683 | const int block_idx = _ccv_nnc_tensor_reads_in_y_from_writes_after_x(block_dependencies, tensor_dependencies, i, j); | ||||
| 684 | // Cannot merge because we have dependencies in between. Merging will violate that | ||||
| 685 | // dependency relationship. | ||||
| 686 | if (block_idx) | ||||
| 687 | { | ||||
| 688 | // Now check to see if left can be merged into right? If so, we are lucky. | ||||
| 689 | if (_ccv_nnc_tensor_writes_in_x_reads_before_y(block_dependencies, tensor_dependencies, i, j)) | ||||
| 690 | continue; | ||||
| 691 | merge_to_right = 1; | ||||
| 692 | } | ||||
| 693 | } | ||||
| 694 | // This method not only compares whether they have the same loop or not, but also gives indexes that | ||||
| 695 | // to match the loop start / end index, where they should move to. For example, if: | ||||
| 696 | // left_loop_idx[2] = 3 | ||||
| 697 | // right_loop_idx[0] = 3 | ||||
| 698 | // That means right now, loop at index 2 on the left is the same as loop at index 0 on the right. | ||||
| 699 | // And to match exactly, they both need to move to index 3. | ||||
| 700 | if (_ccv_nnc_same_loop(left_block, right_block, groups, axis_id_groups, left_loop_idx, right_loop_idx)) | ||||
| 701 | { | ||||
| 702 | // Make sure if we have extra loop, they are on the left. | ||||
| 703 | if (right_block->loop_count > left_block->loop_count) | ||||
| 704 | { | ||||
| 705 | ccv_nnc_micro_loop_block_t t; | ||||
| 706 | CCV_SWAP(*left_block, *right_block, t)((t) = (*left_block), (*left_block) = (*right_block), (*right_block ) = (t)); | ||||
| 707 | } | ||||
| 708 | assert(left_block->loop_count == right_block->loop_count || left_block->loop_count == right_block->loop_count + 1)((void) sizeof ((left_block->loop_count == right_block-> loop_count || left_block->loop_count == right_block->loop_count + 1) ? 1 : 0), __extension__ ({ if (left_block->loop_count == right_block->loop_count || left_block->loop_count == right_block->loop_count + 1) ; else __assert_fail ("left_block->loop_count == right_block->loop_count || left_block->loop_count == right_block->loop_count + 1" , "ccv_nnc_micro_simplify.c", 708, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 709 | _ccv_nnc_loop_order_by(left_block, left_loop_idx, loops); | ||||
| 710 | _ccv_nnc_loop_order_by(right_block, right_loop_idx, loops); | ||||
| 711 | const int left_start_idx = left_block->loop_count - right_block->loop_count; | ||||
| 712 | if (left_block->carried_count > 0) | ||||
| 713 | _ccv_nnc_loop_rename_carrieds(right_block, left_block->carried_count); | ||||
| 714 | left_block->carried_count += right_block->carried_count; | ||||
| 715 | int k; | ||||
| 716 | for (k = 0; k < right_block->loop_count; k++) // Merge loops. | ||||
| 717 | { | ||||
| 718 | const int left_idx = left_start_idx + k; | ||||
| 719 | if (right_block->loops[k].carried_count > 0) | ||||
| 720 | { | ||||
| 721 | if (left_block->loops[left_idx].carried_count > 0) | ||||
| 722 | { | ||||
| 723 | left_block->loops[left_idx].carrieds = (ccv_nnc_micro_loop_carried_t*)ccreallocrealloc(left_block->loops[left_idx].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * (left_block->loops[left_idx].carried_count + right_block->loops[k].carried_count)); | ||||
| 724 | memcpy(left_block->loops[left_idx].carrieds + left_block->loops[left_idx].carried_count, right_block->loops[k].carrieds, sizeof(ccv_nnc_micro_loop_carried_t) * right_block->loops[k].carried_count); | ||||
| 725 | ccfreefree(right_block->loops[k].carrieds); | ||||
| 726 | } else | ||||
| 727 | left_block->loops[left_idx].carrieds = right_block->loops[k].carrieds; | ||||
| 728 | left_block->loops[left_idx].carried_count += right_block->loops[k].carried_count; | ||||
| 729 | right_block->loops[k].carrieds = 0; | ||||
| 730 | right_block->loops[k].carried_count = 0; | ||||
| 731 | } | ||||
| 732 | if (right_block->loops[k].statement_count > 0) | ||||
| 733 | { | ||||
| 734 | if (left_block->loops[left_idx].statement_count > 0) | ||||
| 735 | { | ||||
| 736 | left_block->loops[left_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccreallocrealloc(left_block->loops[left_idx].statements, sizeof(ccv_nnc_micro_loop_statement_t) * (left_block->loops[left_idx].statement_count + right_block->loops[k].statement_count)); | ||||
| 737 | memcpy(left_block->loops[left_idx].statements + left_block->loops[left_idx].statement_count, right_block->loops[k].statements, sizeof(ccv_nnc_micro_loop_statement_t) * right_block->loops[k].statement_count); | ||||
| 738 | ccfreefree(right_block->loops[k].statements); | ||||
| 739 | } else | ||||
| 740 | left_block->loops[left_idx].statements = right_block->loops[k].statements; | ||||
| 741 | left_block->loops[left_idx].statement_count += right_block->loops[k].statement_count; | ||||
| 742 | right_block->loops[k].statements = 0; | ||||
| 743 | right_block->loops[k].statement_count = 0; | ||||
| 744 | } | ||||
| 745 | } | ||||
| 746 | // Once merged, free the loop. | ||||
| 747 | ccfreefree(right_block->loops); | ||||
| 748 | right_block->loops = 0; | ||||
| 749 | right_block->loop_count = 0; | ||||
| 750 | int x = i, y = j; | ||||
| 751 | if (merge_to_right) // If this is merge to right. | ||||
| 752 | { | ||||
| 753 | ccv_nnc_micro_loop_block_t t; | ||||
| 754 | CCV_SWAP(*left_block, *right_block, t)((t) = (*left_block), (*left_block) = (*right_block), (*right_block ) = (t)); | ||||
| 755 | x = j, y = i; | ||||
| 756 | } | ||||
| 757 | // Merge all reads and writes tensors into block dependency. | ||||
| 758 | if (block_dependencies[y].writes && block_dependencies[y].writes->rnum) | ||||
| 759 | { | ||||
| 760 | if (!block_dependencies[x].writes) | ||||
| 761 | block_dependencies[x].writes = ccv_array_new(sizeof(int), 1, 0); | ||||
| 762 | for (k = 0; k < block_dependencies[y].writes->rnum; k++) | ||||
| 763 | ccv_array_push(block_dependencies[x].writes, ccv_array_get(block_dependencies[y].writes, k)((void*)(((char*)((block_dependencies[y].writes)->data)) + (size_t)(block_dependencies[y].writes)->rsize * (size_t)( k)))); | ||||
| 764 | } | ||||
| 765 | if (block_dependencies[y].reads && block_dependencies[y].reads->rnum) | ||||
| 766 | { | ||||
| 767 | if (!block_dependencies[x].reads) | ||||
| 768 | block_dependencies[x].reads = ccv_array_new(sizeof(int), 1, 0); | ||||
| 769 | for (k = 0; k < block_dependencies[y].reads->rnum; k++) | ||||
| 770 | ccv_array_push(block_dependencies[x].reads, ccv_array_get(block_dependencies[y].reads, k)((void*)(((char*)((block_dependencies[y].reads)->data)) + ( size_t)(block_dependencies[y].reads)->rsize * (size_t)(k)) )); | ||||
| 771 | } | ||||
| 772 | // Merged, mark the proper merging dependency. | ||||
| 773 | block_dependencies[y].merge_to = x; | ||||
| 774 | if (merge_to_right) // If this is merge to right, now left is empty, break. | ||||
| 775 | break; | ||||
| 776 | } | ||||
| 777 | } | ||||
| 778 | } | ||||
| 779 | } | ||||
| 780 | |||||
| 781 | static void _ccv_nnc_var_subst(ccv_nnc_micro_tensor_t* const vars, const int var_count, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, ccv_array_t* const blocks, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 782 | { | ||||
| 783 | int i, j; | ||||
| 784 | // These are simple programs, so we are going to loop over all blocks to see whether a non-output-input | ||||
| 785 | // var only write / read in one loop. If that is the case, we are going to remove that var. | ||||
| 786 | // We have to do this replacement from bottom to top though. | ||||
| 787 | for (i = 0; i < var_count; i++) | ||||
| 788 | { | ||||
| 789 | int flag = 0; | ||||
| 790 | for (j = 0; !flag && j < input_size; j++) | ||||
| 791 | flag = (inputs[j]->id == i); | ||||
| 792 | for (j = 0; !flag && j < output_size; j++) | ||||
| 793 | flag = (outputs[j]->id == i); | ||||
| 794 | if (flag) // This is in outputs or inputs. | ||||
| 795 | continue; | ||||
| 796 | int count_var = 0; | ||||
| 797 | ccv_nnc_micro_loop_variable_t lvalue; | ||||
| 798 | ccv_nnc_micro_loop_expression_t rvalue; | ||||
| 799 | int block_idx, loop_idx, statement_idx; | ||||
| 800 | for (j = 0; j < blocks->rnum; j++) | ||||
| 801 | { | ||||
| 802 | const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(j))); | ||||
| 803 | int k, l; | ||||
| 804 | const int loop_count = block->loop_count; | ||||
| 805 | const ccv_nnc_micro_loop_t* const loops = block->loops; | ||||
| 806 | int var_per_block = 0; | ||||
| 807 | for (k = 0; k < loop_count; k++) | ||||
| 808 | { | ||||
| 809 | int flag = 0; | ||||
| 810 | const int statement_count = loops[k].statement_count; | ||||
| 811 | ccv_nnc_micro_loop_statement_t* const statements = loops[k].statements; | ||||
| 812 | for (l = 0; l < statement_count; l++) | ||||
| 813 | if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT && | ||||
| 814 | statements[l].assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && | ||||
| 815 | statements[l].assignment.lvalue.id.id == i) | ||||
| 816 | { | ||||
| 817 | lvalue = statements[l].assignment.lvalue; | ||||
| 818 | if (_ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups)) | ||||
| 819 | flag = 2; | ||||
| 820 | else { | ||||
| 821 | // If the variable not showing up on the right-side, we can continue. | ||||
| 822 | rvalue = statements[l].assignment.rvalue; | ||||
| 823 | block_idx = j; | ||||
| 824 | loop_idx = k; | ||||
| 825 | statement_idx = l; | ||||
| 826 | ++flag; | ||||
| 827 | } | ||||
| 828 | } else if (statements[l].type == CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT && | ||||
| 829 | statements[l].compound_assignment.lvalue.id.type == CCV_NNC_MICRO_TENSOR_ID && | ||||
| 830 | statements[l].compound_assignment.lvalue.id.id == i) { | ||||
| 831 | // This is compound assignment, automatically increase by 2. | ||||
| 832 | flag += 2; | ||||
| 833 | } | ||||
| 834 | if (flag > 1) // We have more than 1 assignment for this id, it is not good. We cannot remove it. | ||||
| 835 | { | ||||
| 836 | var_per_block += flag; | ||||
| 837 | continue; | ||||
| 838 | } | ||||
| 839 | for (l = 0; l < statement_count; l++) | ||||
| 840 | flag = ccv_max(flag, _ccv_nnc_only_var_in_rvalue(i, lvalue, statements[l], groups, axis_id_groups))({ typeof (flag) _a = (flag); typeof (_ccv_nnc_only_var_in_rvalue (i, lvalue, statements[l], groups, axis_id_groups)) _b = (_ccv_nnc_only_var_in_rvalue (i, lvalue, statements[l], groups, axis_id_groups)); (_a > _b) ? _a : _b; }); | ||||
| 841 | // If flag == 2, meaning it found a var with a different index. This is a bad news. | ||||
| 842 | var_per_block += flag; | ||||
| 843 | } | ||||
| 844 | count_var += var_per_block; | ||||
| 845 | } | ||||
| 846 | // If this is used more than one place (write multiple times, have different index, or used in different blocks), | ||||
| 847 | // I cannot get rid of it. | ||||
| 848 | if (count_var != 1) | ||||
| 849 | continue; | ||||
| 850 | // Otherwise, now loop again and prepare to get rid of it. | ||||
| 851 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, block_idx)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(block_idx))); | ||||
| 852 | ccv_nnc_micro_loop_statement_t* statements = block->loops[loop_idx].statements; | ||||
| 853 | ccv_nnc_micro_loop_statement_t statement = statements[statement_idx]; | ||||
| 854 | // First, remove the assignment. | ||||
| 855 | if (statement_idx < block->loops[loop_idx].statement_count - 1) | ||||
| 856 | memmove(statements + statement_idx, statements + statement_idx + 1, sizeof(ccv_nnc_micro_loop_statement_t) * (block->loops[loop_idx].statement_count - statement_idx - 1)); | ||||
| 857 | --block->loops[loop_idx].statement_count; | ||||
| 858 | const int statement_count = block->loops[loop_idx].statement_count; | ||||
| 859 | statements = block->loops[loop_idx].statements = (ccv_nnc_micro_loop_statement_t*)ccreallocrealloc(statements, sizeof(ccv_nnc_micro_loop_statement_t) * statement_count); | ||||
| 860 | int k = 0; | ||||
| 861 | for (j = 0; j < statement_count; j++) | ||||
| 862 | _ccv_nnc_replacing_id_in_rvalue(&statements[j], i, rvalue, &k); | ||||
| 863 | if (k == 0) // If nothing to replace, free up everything. | ||||
| 864 | ccv_nnc_micro_loop_statement_free(&statement); | ||||
| 865 | else | ||||
| 866 | ccv_nnc_micro_loop_statement_lvalue_free(&statement); | ||||
| 867 | // No need to allocate for this var. It is not used, only useful for shape computation. | ||||
| 868 | vars[i].no_alloc = 1; | ||||
| 869 | } | ||||
| 870 | } | ||||
| 871 | |||||
| 872 | static int _ccv_nnc_index_binary_size(const ccv_nnc_micro_loop_index_term_t index) | ||||
| 873 | { | ||||
| 874 | switch (index.type) | ||||
| 875 | { | ||||
| 876 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: | ||||
| 877 | return 0; | ||||
| 878 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: | ||||
| 879 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: | ||||
| 880 | return 1; | ||||
| 881 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: | ||||
| 882 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS) | ||||
| 883 | return _ccv_nnc_index_binary_size(index.binary->left) + _ccv_nnc_index_binary_size(index.binary->right); | ||||
| 884 | else | ||||
| 885 | return 1; | ||||
| 886 | } | ||||
| 887 | return 0; | ||||
| 888 | } | ||||
| 889 | |||||
| 890 | typedef struct { | ||||
| 891 | int sign:7; | ||||
| 892 | int ignore:1; | ||||
| 893 | ccv_nnc_micro_loop_index_term_t term; | ||||
| 894 | } ccv_nnc_micro_loop_binary_term_t; | ||||
| 895 | |||||
| 896 | static void _ccv_nnc_index_term_flatten(ccv_nnc_micro_loop_binary_term_t* const binary_terms, const ccv_nnc_micro_loop_index_term_t index, const int sign, int* const i) | ||||
| 897 | { | ||||
| 898 | switch (index.type) | ||||
| 899 | { | ||||
| 900 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: // No need to occupy. | ||||
| 901 | break; | ||||
| 902 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: | ||||
| 903 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: | ||||
| 904 | binary_terms[*i].term = index; | ||||
| 905 | binary_terms[*i].sign = sign; | ||||
| 906 | binary_terms[*i].ignore = 0; | ||||
| 907 | ++(*i); | ||||
| 908 | break; | ||||
| 909 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: | ||||
| 910 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS) | ||||
| 911 | { | ||||
| 912 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->left, sign, i); | ||||
| 913 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_MINUS) // Switch sign. | ||||
| 914 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->right, sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? CCV_NNC_MICRO_BINARY_OP_MINUS : CCV_NNC_MICRO_BINARY_OP_PLUS, i); | ||||
| 915 | else | ||||
| 916 | _ccv_nnc_index_term_flatten(binary_terms, index.binary->right, sign, i); | ||||
| 917 | } else { | ||||
| 918 | binary_terms[*i].term = index; | ||||
| 919 | binary_terms[*i].sign = sign; | ||||
| 920 | binary_terms[*i].ignore = 0; | ||||
| 921 | ++(*i); | ||||
| 922 | } | ||||
| 923 | break; | ||||
| 924 | } | ||||
| 925 | } | ||||
| 926 | |||||
| 927 | // 0 is we don't understand, -1 is false, 1 is true. | ||||
| 928 | static int _ccv_nnc_index_less_than_or_equal_to(const ccv_nnc_micro_loop_index_term_t left, const ccv_nnc_micro_loop_index_term_t right, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 929 | { | ||||
| 930 | // Special case 1. | ||||
| 931 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL) | ||||
| 932 | return left.immediate_value <= right.immediate_value ? 1 : -1; | ||||
| 933 | // Special case 2. | ||||
| 934 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && left.immediate_value == 0 && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && right.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID) | ||||
| 935 | return 1; | ||||
| 936 | // Special case 3. | ||||
| 937 | if (left.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && left.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID && right.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL && right.immediate_value == 0) | ||||
| 938 | return -1; | ||||
| 939 | // Now, we only have one variable in both left and right, need to flat the binary tree (if possible) and reduce it to constant if possible. | ||||
| 940 | // We can only flatten if it is + / - at the moment. | ||||
| 941 | const int left_binary_size = _ccv_nnc_index_binary_size(left); | ||||
| 942 | assert(left_binary_size >= 1)((void) sizeof ((left_binary_size >= 1) ? 1 : 0), __extension__ ({ if (left_binary_size >= 1) ; else __assert_fail ("left_binary_size >= 1" , "ccv_nnc_micro_simplify.c", 942, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 943 | const int right_binary_size = _ccv_nnc_index_binary_size(right); | ||||
| 944 | assert(right_binary_size >= 1)((void) sizeof ((right_binary_size >= 1) ? 1 : 0), __extension__ ({ if (right_binary_size >= 1) ; else __assert_fail ("right_binary_size >= 1" , "ccv_nnc_micro_simplify.c", 944, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 945 | ccv_nnc_micro_loop_binary_term_t* const left_binary_terms = (ccv_nnc_micro_loop_binary_term_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_binary_term_t) * (left_binary_size + right_binary_size)); | ||||
| 946 | ccv_nnc_micro_loop_binary_term_t* const right_binary_terms = left_binary_terms + left_binary_size; | ||||
| 947 | int i, j; | ||||
| 948 | i = 0; | ||||
| 949 | _ccv_nnc_index_term_flatten(left_binary_terms, left, CCV_NNC_MICRO_BINARY_OP_PLUS, &i); | ||||
| 950 | assert(i == left_binary_size)((void) sizeof ((i == left_binary_size) ? 1 : 0), __extension__ ({ if (i == left_binary_size) ; else __assert_fail ("i == left_binary_size" , "ccv_nnc_micro_simplify.c", 950, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 951 | i = 0; | ||||
| 952 | _ccv_nnc_index_term_flatten(right_binary_terms, right, CCV_NNC_MICRO_BINARY_OP_PLUS, &i); | ||||
| 953 | assert(i == right_binary_size)((void) sizeof ((i == right_binary_size) ? 1 : 0), __extension__ ({ if (i == right_binary_size) ; else __assert_fail ("i == right_binary_size" , "ccv_nnc_micro_simplify.c", 953, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 954 | // Matching signs in left terms. | ||||
| 955 | for (i = 0; i < left_binary_size - 1; i++) | ||||
| 956 | for (j = i + 1; j < left_binary_size; j++) | ||||
| 957 | if (!left_binary_terms[i].ignore && !left_binary_terms[j].ignore && | ||||
| 958 | _ccv_nnc_same_index_term(left_binary_terms[i].term, left_binary_terms[j].term, groups, axis_id_groups) && | ||||
| 959 | left_binary_terms[i].sign != left_binary_terms[j].sign) | ||||
| 960 | { | ||||
| 961 | left_binary_terms[i].ignore = -1; | ||||
| 962 | left_binary_terms[j].ignore = -1; | ||||
| 963 | } | ||||
| 964 | // Matching signs in right terms. | ||||
| 965 | for (i = 0; i < right_binary_size - 1; i++) | ||||
| 966 | for (j = i + 1; j < right_binary_size; j++) | ||||
| 967 | if (!right_binary_terms[i].ignore && !right_binary_terms[j].ignore && | ||||
| 968 | _ccv_nnc_same_index_term(right_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups) && | ||||
| 969 | right_binary_terms[i].sign != right_binary_terms[j].sign) | ||||
| 970 | { | ||||
| 971 | right_binary_terms[i].ignore = -1; | ||||
| 972 | right_binary_terms[j].ignore = -1; | ||||
| 973 | } | ||||
| 974 | // Matching left to right. | ||||
| 975 | for (i = 0; i < left_binary_size; i++) | ||||
| 976 | for (j = 0; j < right_binary_size; j++) | ||||
| 977 | // If they are the same, we can ignore now. | ||||
| 978 | if (!left_binary_terms[i].ignore && !right_binary_terms[j].ignore && | ||||
| 979 | _ccv_nnc_same_index_term(left_binary_terms[i].term, right_binary_terms[j].term, groups, axis_id_groups) && | ||||
| 980 | left_binary_terms[i].sign == right_binary_terms[j].sign) | ||||
| 981 | { | ||||
| 982 | left_binary_terms[i].ignore = -1; | ||||
| 983 | right_binary_terms[j].ignore = -1; | ||||
| 984 | } | ||||
| 985 | // After reduced, we should only have immediate values left, otherwise we cannot progress. | ||||
| 986 | int left_val = 0; | ||||
| 987 | for (i = 0; i < left_binary_size; i++) | ||||
| 988 | if (!left_binary_terms[i].ignore) | ||||
| 989 | { | ||||
| 990 | if (left_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL) | ||||
| 991 | { | ||||
| 992 | ccfreefree(left_binary_terms); | ||||
| 993 | return 0; | ||||
| 994 | } else | ||||
| 995 | left_val += left_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_binary_terms[i].term.immediate_value : -left_binary_terms[i].term.immediate_value; | ||||
| 996 | } | ||||
| 997 | int right_val = 0; | ||||
| 998 | for (i = 0; i < right_binary_size; i++) | ||||
| 999 | if (!right_binary_terms[i].ignore) | ||||
| 1000 | { | ||||
| 1001 | if (right_binary_terms[i].term.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL) | ||||
| 1002 | { | ||||
| 1003 | ccfreefree(left_binary_terms); | ||||
| 1004 | return 0; | ||||
| 1005 | } else | ||||
| 1006 | right_val += right_binary_terms[i].sign == CCV_NNC_MICRO_BINARY_OP_PLUS ? right_binary_terms[i].term.immediate_value : -right_binary_terms[i].term.immediate_value; | ||||
| 1007 | } | ||||
| 1008 | ccfreefree(left_binary_terms); | ||||
| 1009 | return left_val <= right_val ? 1 : -1; | ||||
| 1010 | } | ||||
| 1011 | |||||
| 1012 | // If this index term refers to an axis size that actually has a expression, refer to that instead (like for reindex operation). | ||||
| 1013 | static ccv_nnc_micro_loop_index_term_t _ccv_nnc_micro_index_shape_merging(const ccv_nnc_micro_loop_index_term_t index, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 1014 | { | ||||
| 1015 | ccv_nnc_micro_loop_index_term_t result = index; | ||||
| 1016 | for (;;) | ||||
| 1017 | { | ||||
| 1018 | if (!(result.type == CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID && result.id.type == CCV_NNC_MICRO_AXIS_SIZE_ID)) | ||||
| 1019 | return result; | ||||
| 1020 | int root = groups[result.id.id]; | ||||
| 1021 | while (groups[root] != root) | ||||
| 1022 | root = groups[root]; | ||||
| 1023 | if (vars[root].shape == 0) | ||||
| 1024 | return result; | ||||
| 1025 | assert(result.id.d >= 0 && result.id.d < vars[root].dimensions)((void) sizeof ((result.id.d >= 0 && result.id.d < vars[root].dimensions) ? 1 : 0), __extension__ ({ if (result .id.d >= 0 && result.id.d < vars[root].dimensions ) ; else __assert_fail ("result.id.d >= 0 && result.id.d < vars[root].dimensions" , "ccv_nnc_micro_simplify.c", 1025, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 1026 | result = vars[root].shape[result.id.d]; | ||||
| 1027 | } | ||||
| 1028 | } | ||||
| 1029 | |||||
| 1030 | static int _ccv_nnc_micro_low_high_bound_from_index(const ccv_nnc_micro_loop_index_term_t index, ccv_nnc_micro_loop_index_term_t* const low_ref, ccv_nnc_micro_loop_index_term_t* const high_ref, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 1031 | { | ||||
| 1032 | switch (index.type) | ||||
| 1033 | { | ||||
| 1034 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE: | ||||
| 1035 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1036 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1037 | .immediate_value = 0 | ||||
| 1038 | }; | ||||
| 1039 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1040 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1041 | .immediate_value = 0 | ||||
| 1042 | }; | ||||
| 1043 | return 1; | ||||
| 1044 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: | ||||
| 1045 | if (index.id.type == CCV_NNC_MICRO_LOOP_ID) | ||||
| 1046 | { | ||||
| 1047 | int loop_idx = -1; | ||||
| 1048 | int i; | ||||
| 1049 | for (i = 0; loop_idx < 0 && i < loop_count; i++) | ||||
| 1050 | if (loops[i].id.id == index.id.id) | ||||
| 1051 | loop_idx = i; | ||||
| 1052 | assert(loop_idx >= 0)((void) sizeof ((loop_idx >= 0) ? 1 : 0), __extension__ ({ if (loop_idx >= 0) ; else __assert_fail ("loop_idx >= 0" , "ccv_nnc_micro_simplify.c", 1052, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 1053 | const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups); | ||||
| 1054 | const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups); | ||||
| 1055 | *low_ref = ccv_nnc_micro_loop_index_deep_copy(&start_index); | ||||
| 1056 | *high_ref = ccv_nnc_micro_loop_index_deep_copy(&end_index); | ||||
| 1057 | } else { | ||||
| 1058 | *low_ref = index; | ||||
| 1059 | *high_ref = index; | ||||
| 1060 | } | ||||
| 1061 | return 1; | ||||
| 1062 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: | ||||
| 1063 | *low_ref = index; | ||||
| 1064 | *high_ref = index; | ||||
| 1065 | return 1; | ||||
| 1066 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { | ||||
| 1067 | // Get low, high from both left and right, and then construct new low / high. | ||||
| 1068 | ccv_nnc_micro_loop_index_term_t left_low, left_high; | ||||
| 1069 | if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->left, &left_low, &left_high, loops, loop_count, vars, var_count, groups, axis_id_groups)) | ||||
| 1070 | return 0; | ||||
| 1071 | ccv_nnc_micro_loop_index_term_t right_low, right_high; | ||||
| 1072 | if (!_ccv_nnc_micro_low_high_bound_from_index(index.binary->right, &right_low, &right_high, loops, loop_count, vars, var_count, groups, axis_id_groups)) | ||||
| 1073 | { | ||||
| 1074 | ccv_nnc_micro_loop_index_free(&left_low); | ||||
| 1075 | ccv_nnc_micro_loop_index_free(&left_high); | ||||
| 1076 | return 0; | ||||
| 1077 | } | ||||
| 1078 | // If left is not a range, or right is not a range, it is simple, just copy over. | ||||
| 1079 | if (_ccv_nnc_same_index_term(left_low, left_high, groups, axis_id_groups) || _ccv_nnc_same_index_term(right_low, right_high, groups, axis_id_groups)) | ||||
| 1080 | { | ||||
| 1081 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1082 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1083 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1084 | }; | ||||
| 1085 | low_ref->binary->op = index.binary->op; | ||||
| 1086 | low_ref->binary->left = left_low; | ||||
| 1087 | low_ref->binary->right = right_low; | ||||
| 1088 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1089 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1090 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1091 | }; | ||||
| 1092 | high_ref->binary->op = index.binary->op; | ||||
| 1093 | high_ref->binary->left = left_high; | ||||
| 1094 | high_ref->binary->right = right_high; | ||||
| 1095 | return 1; | ||||
| 1096 | } | ||||
| 1097 | // Cannot handle -, because lower bound will go to negative, similar for /. Only can handle + and *. | ||||
| 1098 | if (!(index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS || index.binary->op == CCV_NNC_MICRO_BINARY_OP_MUL) || | ||||
| 1099 | // If lower bound is not a non-negative integer, we cannot compute interesting low / high bound, abort. | ||||
| 1100 | (left_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || left_low.immediate_value < 0) || | ||||
| 1101 | (right_low.type != CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL || right_low.immediate_value < 0)) | ||||
| 1102 | { | ||||
| 1103 | ccv_nnc_micro_loop_index_free(&left_low); | ||||
| 1104 | ccv_nnc_micro_loop_index_free(&left_high); | ||||
| 1105 | ccv_nnc_micro_loop_index_free(&right_low); | ||||
| 1106 | ccv_nnc_micro_loop_index_free(&right_high); | ||||
| 1107 | return 0; | ||||
| 1108 | } | ||||
| 1109 | *low_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1110 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1111 | .immediate_value = index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS ? left_low.immediate_value + right_low.immediate_value : left_low.immediate_value * right_low.immediate_value, | ||||
| 1112 | }; | ||||
| 1113 | // higher bound is not inclusive, hence, we need to minus extra 1 for this. | ||||
| 1114 | if (index.binary->op == CCV_NNC_MICRO_BINARY_OP_PLUS) | ||||
| 1115 | { | ||||
| 1116 | // (left - 1) + (right - 1) + 1 | ||||
| 1117 | ccv_nnc_micro_loop_index_term_t sum = { | ||||
| 1118 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1119 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1120 | }; | ||||
| 1121 | sum.binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS; | ||||
| 1122 | sum.binary->left = left_high; | ||||
| 1123 | sum.binary->right = right_high; | ||||
| 1124 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1125 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1126 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1127 | }; | ||||
| 1128 | high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; | ||||
| 1129 | high_ref->binary->left = sum; | ||||
| 1130 | high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1131 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1132 | .immediate_value = 1 | ||||
| 1133 | }; | ||||
| 1134 | } else { | ||||
| 1135 | // (left - 1) * (right - 1) + 1 | ||||
| 1136 | ccv_nnc_micro_loop_index_term_t prod = { | ||||
| 1137 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1138 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1139 | }; | ||||
| 1140 | prod.binary->op = CCV_NNC_MICRO_BINARY_OP_MUL; | ||||
| 1141 | prod.binary->left = left_high; | ||||
| 1142 | prod.binary->right = right_high; | ||||
| 1143 | ccv_nnc_micro_loop_index_term_t minus_left = { | ||||
| 1144 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1145 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1146 | }; | ||||
| 1147 | minus_left.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; | ||||
| 1148 | minus_left.binary->left = prod; | ||||
| 1149 | minus_left.binary->right = left_high; | ||||
| 1150 | ccv_nnc_micro_loop_index_term_t minus_right = { | ||||
| 1151 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1152 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1153 | }; | ||||
| 1154 | minus_right.binary->op = CCV_NNC_MICRO_BINARY_OP_MINUS; | ||||
| 1155 | minus_right.binary->left = minus_left; | ||||
| 1156 | minus_right.binary->right = right_high; | ||||
| 1157 | *high_ref = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1158 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY, | ||||
| 1159 | .binary = (ccv_nnc_micro_loop_index_binary_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t)) | ||||
| 1160 | }; | ||||
| 1161 | high_ref->binary->op = CCV_NNC_MICRO_BINARY_OP_PLUS; | ||||
| 1162 | high_ref->binary->left = minus_right; | ||||
| 1163 | high_ref->binary->right = (ccv_nnc_micro_loop_index_term_t){ | ||||
| 1164 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1165 | .immediate_value = 2 | ||||
| 1166 | }; | ||||
| 1167 | } | ||||
| 1168 | return 1; | ||||
| 1169 | } | ||||
| 1170 | } | ||||
| 1171 | return 0; | ||||
| 1172 | } | ||||
| 1173 | |||||
| 1174 | static void _ccv_nnc_micro_check_bound_for_variable(ccv_nnc_micro_loop_variable_t* const variable, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 1175 | { | ||||
| 1176 | if (variable->id.type != CCV_NNC_MICRO_TENSOR_ID) | ||||
| 1177 | return; | ||||
| 1178 | int i, j; | ||||
| 1179 | assert(variable->id.id >= 0 && variable->id.id < var_count)((void) sizeof ((variable->id.id >= 0 && variable ->id.id < var_count) ? 1 : 0), __extension__ ({ if (variable ->id.id >= 0 && variable->id.id < var_count ) ; else __assert_fail ("variable->id.id >= 0 && variable->id.id < var_count" , "ccv_nnc_micro_simplify.c", 1179, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 1180 | ccv_nnc_micro_loop_index_term_t index_zero = { | ||||
| 1181 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL, | ||||
| 1182 | .immediate_value = 0 | ||||
| 1183 | }; | ||||
| 1184 | for (i = 0; i < variable->index_count; i++) | ||||
| 1185 | { | ||||
| 1186 | const ccv_nnc_micro_loop_index_term_t shape = _ccv_nnc_micro_index_shape_merging((ccv_nnc_micro_loop_index_term_t){ | ||||
| 1187 | .type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID, | ||||
| 1188 | .id = { | ||||
| 1189 | .type = CCV_NNC_MICRO_AXIS_SIZE_ID, | ||||
| 1190 | .id = variable->id.id, | ||||
| 1191 | .d = i | ||||
| 1192 | } | ||||
| 1193 | }, vars, var_count, groups, axis_id_groups); | ||||
| 1194 | switch (variable->index[i].type) | ||||
| 1195 | { | ||||
| 1196 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID: | ||||
| 1197 | // For loop id, we can check the range to see if it is within the shape. | ||||
| 1198 | if (variable->index[i].id.type == CCV_NNC_MICRO_LOOP_ID) | ||||
| 1199 | { | ||||
| 1200 | int loop_idx = -1; | ||||
| 1201 | for (j = 0; loop_idx < 0 && j < loop_count; j++) | ||||
| 1202 | if (loops[j].id.id == variable->index[i].id.id) | ||||
| 1203 | loop_idx = j; | ||||
| 1204 | assert(loop_idx >= 0)((void) sizeof ((loop_idx >= 0) ? 1 : 0), __extension__ ({ if (loop_idx >= 0) ; else __assert_fail ("loop_idx >= 0" , "ccv_nnc_micro_simplify.c", 1204, __extension__ __PRETTY_FUNCTION__ ); })); | ||||
| 1205 | const ccv_nnc_micro_loop_index_term_t start_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].start_index, vars, var_count, groups, axis_id_groups); | ||||
| 1206 | const ccv_nnc_micro_loop_index_term_t end_index = _ccv_nnc_micro_index_shape_merging(loops[loop_idx].end_index, vars, var_count, groups, axis_id_groups); | ||||
| 1207 | if (_ccv_nnc_index_less_than_or_equal_to(index_zero, start_index, vars, var_count, groups, axis_id_groups) == 1 && | ||||
| 1208 | _ccv_nnc_index_less_than_or_equal_to(end_index, shape, vars, var_count, groups, axis_id_groups) == 1) | ||||
| 1209 | variable->no_check_bound[i] = 1; | ||||
| 1210 | else | ||||
| 1211 | variable->no_check_bound[i] = 0; | ||||
| 1212 | } else // If it is anything other than loop id, we have to check the bound. | ||||
| 1213 | variable->no_check_bound[i] = 0; | ||||
| 1214 | break; | ||||
| 1215 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY: { | ||||
| 1216 | // Compute higher / lower bounds along the expression. | ||||
| 1217 | ccv_nnc_micro_loop_index_term_t low, high; | ||||
| 1218 | // Cannot find high low, mark no_check_bound[i] = 0 | ||||
| 1219 | if (!_ccv_nnc_micro_low_high_bound_from_index(variable->index[i], &low, &high, loops, loop_count, vars, var_count, groups, axis_id_groups)) | ||||
| 1220 | { | ||||
| 1221 | variable->no_check_bound[i] = 0; | ||||
| 1222 | break; | ||||
| 1223 | } | ||||
| 1224 | if (_ccv_nnc_index_less_than_or_equal_to(index_zero, low, vars, var_count, groups, axis_id_groups) == 1 && | ||||
| 1225 | _ccv_nnc_index_less_than_or_equal_to(high, shape, vars, var_count, groups, axis_id_groups) == 1) | ||||
| 1226 | variable->no_check_bound[i] = 1; | ||||
| 1227 | else | ||||
| 1228 | variable->no_check_bound[i] = 0; | ||||
| 1229 | ccv_nnc_micro_loop_index_free(&low); | ||||
| 1230 | ccv_nnc_micro_loop_index_free(&high); | ||||
| 1231 | break; | ||||
| 1232 | } | ||||
| 1233 | case CCV_NNC_MICRO_LOOP_INDEX_TYPE_VAL: | ||||
| 1234 | // If the index is an integer, and it is bigger than 0, we need to check bound (there is no assertion the end index is larger than anything other than 0). | ||||
| 1235 | if (variable->index[i].immediate_value == 0) | ||||
| 1236 | variable->no_check_bound[i] = 1; | ||||
| 1237 | else | ||||
| 1238 | variable->no_check_bound[i] = 0; | ||||
| 1239 | break; | ||||
| 1240 | } | ||||
| 1241 | } | ||||
| 1242 | } | ||||
| 1243 | |||||
| 1244 | static void _ccv_nnc_micro_check_bound_for_expression(ccv_nnc_micro_loop_expression_t* const expression, const ccv_nnc_micro_loop_t* const loops, const int loop_count, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 1245 | { | ||||
| 1246 | switch (expression->type) | ||||
| 1247 | { | ||||
| 1248 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR: | ||||
| 1249 | _ccv_nnc_micro_check_bound_for_variable(&expression->variable, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1250 | break; | ||||
| 1251 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_TERNAY: | ||||
| 1252 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.pivot, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1253 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.left, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1254 | _ccv_nnc_micro_check_bound_for_expression(expression->ternary.right, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1255 | break; | ||||
| 1256 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_BINARY: | ||||
| 1257 | _ccv_nnc_micro_check_bound_for_expression(expression->binary.left, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1258 | _ccv_nnc_micro_check_bound_for_expression(expression->binary.right, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1259 | break; | ||||
| 1260 | case CCV_NNC_MICRO_LOOP_EXPR_TYPE_UNARY: | ||||
| 1261 | _ccv_nnc_micro_check_bound_for_expression(expression->unary.x, loops, loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1262 | break; | ||||
| 1263 | } | ||||
| 1264 | } | ||||
| 1265 | |||||
| 1266 | static void _ccv_nnc_micro_check_bound_for_block(ccv_nnc_micro_loop_block_t* const block, const ccv_nnc_micro_tensor_t* const vars, const int var_count, const int* const groups, khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups) | ||||
| 1267 | { | ||||
| 1268 | int i, j; | ||||
| 1269 | for (i = 0; i < block->loop_count; i++) | ||||
| 1270 | { | ||||
| 1271 | const int statement_count = block->loops[i].statement_count; | ||||
| 1272 | ccv_nnc_micro_loop_statement_t* const statements = block->loops[i].statements; | ||||
| 1273 | for (j = 0; j < statement_count; j++) | ||||
| 1274 | { | ||||
| 1275 | switch (statements[j].type) | ||||
| 1276 | { | ||||
| 1277 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_ASSIGNMENT: | ||||
| 1278 | _ccv_nnc_micro_check_bound_for_variable(&statements[j].assignment.lvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1279 | _ccv_nnc_micro_check_bound_for_expression(&statements[j].assignment.rvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1280 | break; | ||||
| 1281 | case CCV_NNC_MICRO_LOOP_STATEMENT_TYPE_COMPOUND_ASSIGNMENT: | ||||
| 1282 | if (statements[j].compound_assignment.lvalue.type == CCV_NNC_MICRO_LOOP_EXPR_TYPE_VAR) | ||||
| 1283 | _ccv_nnc_micro_check_bound_for_variable(&statements[j].compound_assignment.lvalue.variable, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1284 | _ccv_nnc_micro_check_bound_for_expression(&statements[j].compound_assignment.rvalue, block->loops, block->loop_count, vars, var_count, groups, axis_id_groups); | ||||
| 1285 | break; | ||||
| 1286 | } | ||||
| 1287 | } | ||||
| 1288 | } | ||||
| 1289 | } | ||||
| 1290 | |||||
| 1291 | void ccv_nnc_micro_program_simplify(ccv_nnc_micro_program_t* const program, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_array_t* const equal_assertions) | ||||
| 1292 | { | ||||
| 1293 | // Nothing to simplify for. | ||||
| 1294 | if (program->function_count < 1) | ||||
| |||||
| 1295 | return; | ||||
| 1296 | // Only one block, nothing to simplify for. | ||||
| 1297 | if (program->function_count == 1 && program->functions[0].block_count == 1) | ||||
| 1298 | return; | ||||
| 1299 | if (input_size == 0 || output_size == 0) | ||||
| 1300 | return; | ||||
| 1301 | // Union-find to group all variables with the same shape. | ||||
| 1302 | ccv_nnc_micro_tensor_t* const vars = program->vars; | ||||
| 1303 | const int var_count = program->var_count; | ||||
| 1304 | int* const groups = (int*)ccmallocmalloc(sizeof(int) * var_count); | ||||
| 1305 | int i, j; | ||||
| 1306 | for (i = 0; i < var_count; i++) | ||||
| 1307 | groups[i] = i; | ||||
| 1308 | // If no shape, they should match these input. | ||||
| 1309 | for (i = 0; i
| ||||
| 1310 | if (vars[i].input >= 0 && !vars[i].shape) | ||||
| 1311 | { | ||||
| 1312 | int root = vars[i].input; | ||||
| 1313 | while (groups[root] != root) | ||||
| 1314 | root = groups[root]; | ||||
| 1315 | groups[i] = root; | ||||
| 1316 | } | ||||
| 1317 | for (i = 0; i
| ||||
| 1318 | { | ||||
| 1319 | // If this is input (no other tensor as the input), we skip. | ||||
| 1320 | if (vars[i].input < 0) | ||||
| 1321 | continue; | ||||
| 1322 | int root = i; | ||||
| 1323 | while (groups[root] != root) | ||||
| 1324 | root = groups[root]; | ||||
| 1325 | // If the sibling exists and we haven't visited yet, mark them has the same group as us. | ||||
| 1326 | if (vars[i].sibling >= 0 && vars[i].sibling < i && groups[vars[i].sibling] < 0) | ||||
| 1327 | groups[vars[i].sibling] = root; | ||||
| 1328 | } | ||||
| 1329 | for (i = var_count - 1; i > 0; i--) | ||||
| 1330 | { | ||||
| 1331 | // Now matching the shape. | ||||
| 1332 | if (vars[i].input < 0 || !vars[i].shape) | ||||
| 1333 | continue; | ||||
| 1334 | int root = i; | ||||
| 1335 | while (groups[root] != root) | ||||
| 1336 | root = groups[root]; | ||||
| 1337 | for (j = i - 1; j >= 0; j--) | ||||
| 1338 | if (vars[j].shape && vars[j].dimensions == vars[i].dimensions && | ||||
| 1339 | _ccv_nnc_same_shape(vars[j].shape, vars[i].shape, vars[i].dimensions)) | ||||
| 1340 | groups[j] = root; | ||||
| 1341 | } | ||||
| 1342 | // Group equal assertions on axis together. | ||||
| 1343 | khash_t(ccv_nnc_axis_id_group)kh_ccv_nnc_axis_id_group_t* const axis_id_groups = kh_init(ccv_nnc_axis_id_group)kh_init_ccv_nnc_axis_id_group(); | ||||
| 1344 | for (i = 0; i < equal_assertions->rnum; i++) | ||||
| 1345 | { | ||||
| 1346 | const ccv_nnc_micro_id_equal_assertion_t* const equal_assertion = (ccv_nnc_micro_id_equal_assertion_t*)ccv_array_get(equal_assertions, i)((void*)(((char*)((equal_assertions)->data)) + (size_t)(equal_assertions )->rsize * (size_t)(i))); | ||||
| 1347 | ccv_nnc_micro_id_t left = equal_assertion->left; | ||||
| 1348 | while (groups[left.id] != left.id) | ||||
| 1349 | left.id = groups[left.id]; | ||||
| 1350 | int left_root = MICRO_ID_TO_INT(left)(((left).id << 8) | ((left).d)); | ||||
| 1351 | khiter_t k; | ||||
| 1352 | for (;;) { | ||||
| 1353 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, left_root)kh_get_ccv_nnc_axis_id_group(axis_id_groups, left_root); | ||||
| 1354 | if (k
| ||||
| 1355 | break; | ||||
| 1356 | left_root = kh_val(axis_id_groups, k)((axis_id_groups)->vals[k]); | ||||
| 1357 | } | ||||
| 1358 | ccv_nnc_micro_id_t right = equal_assertion->right; | ||||
| 1359 | while (groups[right.id] != right.id) | ||||
| 1360 | left.id = groups[right.id]; | ||||
| 1361 | int right_root = MICRO_ID_TO_INT(equal_assertion->right)(((equal_assertion->right).id << 8) | ((equal_assertion ->right).d)); | ||||
| 1362 | for (;;) { | ||||
| 1363 | k = kh_get(ccv_nnc_axis_id_group, axis_id_groups, right_root)kh_get_ccv_nnc_axis_id_group(axis_id_groups, right_root); | ||||
| 1364 | if (k
| ||||
| 1365 | break; | ||||
| 1366 | right_root = kh_val(axis_id_groups, k)((axis_id_groups)->vals[k]); | ||||
| |||||
| 1367 | } | ||||
| 1368 | if (left_root != right_root) // k is the right root at the moment. | ||||
| 1369 | { | ||||
| 1370 | int ret; | ||||
| 1371 | k = kh_put(ccv_nnc_axis_id_group, axis_id_groups, right_root, &ret)kh_put_ccv_nnc_axis_id_group(axis_id_groups, right_root, & ret); | ||||
| 1372 | assert(ret != 0)((void) sizeof ((ret != 0) ? 1 : 0), __extension__ ({ if (ret != 0) ; else __assert_fail ("ret != 0", "ccv_nnc_micro_simplify.c" , 1372, __extension__ __PRETTY_FUNCTION__); })); | ||||
| 1373 | kh_val(axis_id_groups, k)((axis_id_groups)->vals[k]) = left_root; | ||||
| 1374 | } | ||||
| 1375 | } | ||||
| 1376 | // First, flatten all functions into blocks. | ||||
| 1377 | ccv_array_t* const blocks = ccv_array_new(sizeof(ccv_nnc_micro_loop_block_t), 0, 0); | ||||
| 1378 | ccv_nnc_micro_function_t* const functions = program->functions; | ||||
| 1379 | const int function_count = program->function_count; | ||||
| 1380 | int max_loop_count = 0; | ||||
| 1381 | for (i = 0; i < function_count; i++) | ||||
| 1382 | { | ||||
| 1383 | const int block_count = functions[i].block_count; | ||||
| 1384 | ccv_nnc_micro_loop_block_t* const function_blocks = block_count == 1 ? &functions[i].one_block : functions[i].blocks; | ||||
| 1385 | for (j = 0; j < block_count; j++) | ||||
| 1386 | { | ||||
| 1387 | max_loop_count = ccv_max(function_blocks[j].loop_count, max_loop_count)({ typeof (function_blocks[j].loop_count) _a = (function_blocks [j].loop_count); typeof (max_loop_count) _b = (max_loop_count ); (_a > _b) ? _a : _b; }); | ||||
| 1388 | ccv_array_push(blocks, &function_blocks[j]); | ||||
| 1389 | } | ||||
| 1390 | } | ||||
| 1391 | // Next, find dependencies between these function blocks and mark those that are dependencies for the final outputs. | ||||
| 1392 | // We need to build our connections between blocks <-> r/w vars. | ||||
| 1393 | ccv_nnc_micro_loop_block_dependency_t* block_dependencies; | ||||
| 1394 | ccv_nnc_micro_tensor_dependency_t* tensor_dependencies; | ||||
| 1395 | const int block_size = blocks->rnum; | ||||
| 1396 | _ccv_nnc_micro_block_dependencies((ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(0))), block_size, var_count, &block_dependencies, &tensor_dependencies); | ||||
| 1397 | ccv_array_t* const in_use = ccv_array_new(sizeof(int), output_size, 0); | ||||
| 1398 | // Use the dependencies to mark blocks / vars that are in use. | ||||
| 1399 | for (i = 0; i < output_size; i++) | ||||
| 1400 | { | ||||
| 1401 | tensor_dependencies[outputs[i]->id].flag = 1; // Mark them as in use. | ||||
| 1402 | ccv_array_push(in_use, &outputs[i]->id); | ||||
| 1403 | } | ||||
| 1404 | for (i = 0; i < input_size; i++) | ||||
| 1405 | tensor_dependencies[inputs[i]->id].flag = 1; // Mark inputs as in use so we don't go pass them. | ||||
| 1406 | for (i = 0; i < in_use->rnum; i++) | ||||
| 1407 | { | ||||
| 1408 | const int tensor_idx = *(int*)ccv_array_get(in_use, i)((void*)(((char*)((in_use)->data)) + (size_t)(in_use)-> rsize * (size_t)(i))); | ||||
| 1409 | if (tensor_dependencies[tensor_idx].writes) | ||||
| 1410 | for (j = 0; j < tensor_dependencies[tensor_idx].writes->rnum; j++) | ||||
| 1411 | { | ||||
| 1412 | const int block_idx = *(int*)ccv_array_get(tensor_dependencies[tensor_idx].writes, j)((void*)(((char*)((tensor_dependencies[tensor_idx].writes)-> data)) + (size_t)(tensor_dependencies[tensor_idx].writes)-> rsize * (size_t)(j))); | ||||
| 1413 | block_dependencies[block_idx].flag = 1; | ||||
| 1414 | int k; | ||||
| 1415 | if (block_dependencies[block_idx].reads) | ||||
| 1416 | for (k = 0; k < block_dependencies[block_idx].reads->rnum; k++) | ||||
| 1417 | { | ||||
| 1418 | const int read_idx = *(int*)ccv_array_get(block_dependencies[block_idx].reads, k)((void*)(((char*)((block_dependencies[block_idx].reads)->data )) + (size_t)(block_dependencies[block_idx].reads)->rsize * (size_t)(k))); | ||||
| 1419 | if (!tensor_dependencies[read_idx].flag) | ||||
| 1420 | { | ||||
| 1421 | tensor_dependencies[read_idx].flag = 1; | ||||
| 1422 | ccv_array_push(in_use, &read_idx); | ||||
| 1423 | } | ||||
| 1424 | } | ||||
| 1425 | } | ||||
| 1426 | } | ||||
| 1427 | ccv_array_free(in_use); | ||||
| 1428 | for (i = 0; i < block_size; i++) | ||||
| 1429 | if (!block_dependencies[i].flag) | ||||
| 1430 | { | ||||
| 1431 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(i))); | ||||
| 1432 | ccv_nnc_micro_loops_free(block->loops, block->loop_count); | ||||
| 1433 | ccfreefree(block->loops); | ||||
| 1434 | block->loops = 0; | ||||
| 1435 | block->loop_count = 0; | ||||
| 1436 | } | ||||
| 1437 | for (i = 0; i < var_count; i++) | ||||
| 1438 | if (!tensor_dependencies[i].flag) // If this tensor is not visited, there is no need to alloc. | ||||
| 1439 | { | ||||
| 1440 | _ccv_nnc_tensor_remove_dead_store(&tensor_dependencies[i], i, blocks); | ||||
| 1441 | vars[i].no_alloc = 1; | ||||
| 1442 | } | ||||
| 1443 | _ccv_nnc_loop_merging(block_dependencies, tensor_dependencies, blocks, max_loop_count, groups, axis_id_groups); | ||||
| 1444 | _ccv_nnc_micro_dependencies_free(block_dependencies, block_size, tensor_dependencies, var_count); | ||||
| 1445 | // Culling out empty blocks. | ||||
| 1446 | for (i = 0, j = 0; i < blocks->rnum; i++) | ||||
| 1447 | { | ||||
| 1448 | const ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(i))); | ||||
| 1449 | if (block->loop_count > 0) | ||||
| 1450 | { | ||||
| 1451 | *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, j)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(j))) = *block; | ||||
| 1452 | ++j; | ||||
| 1453 | } | ||||
| 1454 | } | ||||
| 1455 | // Now we moved everything, set the proper block size. | ||||
| 1456 | ccv_array_resize(blocks, j); | ||||
| 1457 | // Substitute variables. | ||||
| 1458 | _ccv_nnc_var_subst(vars, var_count, inputs, input_size, outputs, output_size, blocks, groups, axis_id_groups); | ||||
| 1459 | // Mark whether we need to check bound for a particular variable or not. | ||||
| 1460 | for (i = 0; i < blocks->rnum; i++) | ||||
| 1461 | { | ||||
| 1462 | ccv_nnc_micro_loop_block_t* const block = (ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, i)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(i))); | ||||
| 1463 | _ccv_nnc_micro_check_bound_for_block(block, vars, var_count, groups, axis_id_groups); | ||||
| 1464 | } | ||||
| 1465 | free(groups); | ||||
| 1466 | kh_destroy(ccv_nnc_axis_id_group, axis_id_groups)kh_destroy_ccv_nnc_axis_id_group(axis_id_groups); | ||||
| 1467 | // Reallocate function to be 1. | ||||
| 1468 | for (i = 0; i < function_count; i++) | ||||
| 1469 | if (functions[i].block_count > 1) | ||||
| 1470 | ccfreefree(functions[i].blocks); | ||||
| 1471 | program->functions = (ccv_nnc_micro_function_t*)ccreallocrealloc(program->functions, sizeof(ccv_nnc_micro_function_t)); | ||||
| 1472 | program->functions[0].block_count = blocks->rnum; | ||||
| 1473 | if (blocks->rnum > 1) | ||||
| 1474 | { | ||||
| 1475 | program->functions[0].blocks = (ccv_nnc_micro_loop_block_t*)ccmallocmalloc(sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum); | ||||
| 1476 | memcpy(program->functions[0].blocks, ccv_array_get(blocks, 0)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(0))), sizeof(ccv_nnc_micro_loop_block_t) * blocks->rnum); | ||||
| 1477 | } else | ||||
| 1478 | program->functions[0].one_block = *(ccv_nnc_micro_loop_block_t*)ccv_array_get(blocks, 0)((void*)(((char*)((blocks)->data)) + (size_t)(blocks)-> rsize * (size_t)(0))); | ||||
| 1479 | program->function_count = 1; | ||||
| 1480 | ccv_array_free(blocks); | ||||
| 1481 | } |