MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
5#include <mim/plug/vec/vec.h>
6
7#include "mim/def.h"
8#include "mim/plugin.h"
9#include "mim/world.h"
10
11#include "mim/util/sets.h"
12
14
15namespace mim::plug::tensor {
16
17// There's no good reason keeping 1s around for get/set indices.
18// So this just skips relevant dimensions in the index and shape, and reduces the rank accordingly.
19std::tuple<u64, const Def*, const Def*> fold_shape_and_index(const Def* shape, const Def* index) {
20 auto& w = shape->world();
21
22 DefVec dims;
23 DefVec index_dims;
24 auto r = shape->num_projs();
25 for (size_t i = 0, e = r; i != e; ++i) {
26 auto dim = shape->proj(r, i);
27 if (auto dim_lit = Lit::isa<u64>(dim))
28 if (dim_lit == 1) continue;
29
30 dims.push_back(dim);
31 index_dims.push_back(index->proj(r, i));
32 }
33
34 assert(dims.size() == index_dims.size());
35 return std::make_tuple(dims.size(), w.tuple(dims), w.tuple(index_dims));
36}
37
38const Def* normalize_get(const Def*, const Def* c, const Def* arg) {
39 auto& w = c->world();
40
41 auto [arr, index] = arg->projs<2>();
42 auto callee = c->as<App>();
43 auto [T, r, s] = callee->args<3>();
44
45 w.DLOG("normalize_get");
46 w.DLOG(" arr = {} : {}", arr, arr->type());
47 w.DLOG(" index = {} : {}", index, index->type());
48 w.DLOG(" T = {} : {}", T, T->type());
49 w.DLOG(" r = {} : {}", r, r->type());
50 w.DLOG(" s = {} : {}", s, s->type());
51
52 if (r->isa<Lit>()) {
53 auto [new_r, new_s, new_index] = fold_shape_and_index(s, index);
54 w.DLOG(" new_index = {} : {}", new_index, new_index->type());
55 w.DLOG(" new_s = {} : {}", new_s, new_s->type());
56 w.DLOG(" new_r = {} : {}", w.lit_nat(new_r), w.lit_nat(new_r)->type());
57 if (new_r == 0) return arr;
58 if (new_s != s || new_index != index) return op_get(T, w.lit_nat(new_r), new_s, arr, new_index);
59 }
60
61 if (Axm::isa<tensor::set>(arr)) {
62 w.DLOG("get after set, try to bypass");
63 auto set = arr->as<App>();
64 auto [_, target_index, x] = set->args<3>();
65 if (target_index == index) {
66 w.DLOG("bypass successful");
67 return x;
68 }
69 }
70 if (Axm::isa<tensor::get>(arr)) {
71 w.DLOG("get after get, try to bypass");
72 auto get = arr->as<App>();
73 auto [outer_arr, outer_index] = get->args<2>();
74 auto [o_T, o_r, o_s] = get->callee()->as<App>()->args<3>();
75 w.DLOG(" outer_arr = {} : {}", outer_arr, outer_arr->type());
76 w.DLOG(" outer_index = {} : {}", outer_index, outer_index->type());
77 w.DLOG(" o_T = {} : {}", o_T, o_T->type());
78 w.DLOG(" o_r = {} : {}", o_r, o_r->type());
79 w.DLOG(" o_s = {} : {}", o_s, o_s->type());
80
81 auto new_r = w.call(core::nat::add, DefVec{r, o_r});
82 auto new_s = w.call<tuple::cat>(DefVec{o_s, s});
83 auto new_index = w.call<tuple::cat>(DefVec{outer_index, index});
84
85 return op_get(T, new_r, new_s, outer_arr, new_index);
86 }
87
88 return nullptr;
89}
90
91const Def* normalize_set(const Def*, const Def* c, const Def* arg) {
92 auto& w = c->world();
93
94 auto [arr, index, x] = arg->projs<3>();
95 w.DLOG("normalize_set");
96 w.DLOG(" arr = {} : {}", arr, arr->type());
97 w.DLOG(" index = {} : {}", index, index->type());
98 w.DLOG(" x = {} : {}", x, x->type());
99
100 auto callee = c->as<App>();
101 auto [T, r, s] = callee->args<3>();
102
103 if (r->isa<Lit>()) {
104 auto [new_r, new_s, new_index] = fold_shape_and_index(s, index);
105 w.DLOG(" new_index = {} : {}", new_index, new_index->type());
106 w.DLOG(" new_s = {} : {}", new_s, new_s->type());
107 w.DLOG(" new_r = {} : {}", w.lit_nat(new_r), w.lit_nat(new_r)->type());
108 if (new_r == 0) return x;
109 if (new_s != s || new_index != index) return op_set(T, w.lit_nat(new_r), new_s, arr, new_index, x);
110 }
111
112 if (Axm::isa<tensor::get>(x)) {
113 w.DLOG("set after get, try to bypass");
114 auto get = x->as<App>();
115 auto [inner_arr, inner_index] = get->args<2>();
116 if (inner_arr == arr && inner_index == index) {
117 w.DLOG("bypass successful");
118 return inner_arr;
119 }
120 }
121
122 if (Axm::isa<tensor::set>(x)) {
123 w.DLOG("set after set, try to bypass");
124 auto inner_set = x->as<App>();
125 auto [inner_arr, inner_index, inner_x] = inner_set->args<3>();
126 auto [i_T, i_r, i_s] = inner_set->callee()->as<App>()->args<3>();
127
128 w.DLOG(" inner_arr = {} : {}", inner_arr, inner_arr->type());
129 w.DLOG(" inner_index = {} : {}", inner_index, inner_index->type());
130 w.DLOG(" inner_x = {} : {}", inner_x, inner_x->type());
131 w.DLOG(" i_T = {} : {}", i_T, i_T->type());
132 w.DLOG(" i_r = {} : {}", i_r, i_r->type());
133 w.DLOG(" i_s = {} : {}", i_s, i_s->type());
134
135 if (auto inner_get = Axm::isa<tensor::get>(inner_arr)) {
136 auto [g_arr, g_index] = inner_get->args<2>();
137 if (g_arr == arr && g_index == index) {
138 auto new_r = w.call(core::nat::add, DefVec{r, i_r});
139 auto new_s = w.call<tuple::cat>(DefVec{s, i_s});
140 auto new_index = w.call<tuple::cat>(DefVec{index, inner_index});
141
142 return op_set(i_T, new_r, new_s, arr, new_index, inner_x);
143 }
144 }
145 w.DLOG("set after set bypass not applicable: inner_arr is not get(arr, index)");
146 }
147 w.DLOG("no normalization applicable");
148 return nullptr;
149}
150
151const Def* normalize_broadcast(const Def*, const Def* c, const Def* arg) {
152 auto& w = c->world();
153
154 auto [s_in, s_out, input] = arg->projs<3>();
155 auto callee = c->as<App>();
156 auto [T, r] = callee->args<2>();
157 w.DLOG("normalize_broadcast");
158 w.DLOG(" s_out = {} : {}", s_out, s_out->type());
159 w.DLOG(" input = {} : {}", input, input->type());
160 w.DLOG(" T = {} : {}", T, T->type());
161 w.DLOG(" r = {} : {}", r, r->type());
162 w.DLOG(" s_in = {} : {}", s_in, s_in->type());
163
164 if (s_in == s_out) return input;
165
166 auto r_nat = Lit::isa<u64>(r);
167 if (!r_nat) return nullptr;
168 if (r_nat == 0) return input;
169
170 return nullptr;
171}
172
173const Def* normalize_broadcast_in_dim(const Def*, const Def*, const Def*) { return nullptr; }
174
175const Def* normalize_map_reduce(const Def*, const Def*, const Def*) {
176 // TODO: is there anything we can normalize here?
177 return nullptr;
178}
179
180const Def* normalize_map_reduce_aff(const Def*, const Def*, const Def*) {
181 // TODO: fold size-1 loop dimensions / identity access maps.
182 return nullptr;
183}
184
186
187} // namespace mim::plug::tensor
const Def * callee() const
Definition lam.h:276
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or a std::array (otherwise).
Definition def.h:386
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
static std::optional< T > isa(const Def *def)
Definition def.h:838
The tensor Plugin
Definition fuse.h:5
const Def * normalize_broadcast(const Def *, const Def *c, const Def *arg)
const Def * op_set(const Def *T, const Def *r, const Def *s, const Def *arr, const Def *index, const Def *x)
Definition tensor.h:17
const Def * normalize_get(const Def *, const Def *c, const Def *arg)
const Def * normalize_broadcast_in_dim(const Def *, const Def *, const Def *)
std::tuple< u64, const Def *, const Def * > fold_shape_and_index(const Def *shape, const Def *index)
const Def * normalize_set(const Def *, const Def *c, const Def *arg)
const Def * normalize_map_reduce(const Def *, const Def *, const Def *)
const Def * op_get(const Def *T, const Def *r, const Def *s, const Def *arr, const Def *index)
Definition tensor.h:9
const Def * normalize_map_reduce_aff(const Def *, const Def *, const Def *)
Vector< const Def * > DefVec
Definition def.h:79
#define MIM_tensor_NORMALIZER_IMPL
Definition autogen.h:216