MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
sym_expr_opt.cpp
Go to the documentation of this file.
2
3#include <absl/container/fixed_array.h>
4
5namespace mim {
6
7const Def* SymExprOpt::Analysis::propagate(const Def* var, const Def* def) {
8 auto [i, ins] = lattice_.emplace(var, def);
9 if (ins) {
10 invalidate();
11 DLOG("propagate: {} → {}", var, def);
12 return def;
13 }
14
15 auto cur = i->second;
16 if (!cur || def->isa<Bot>() || cur == def || cur == var || cur->isa<Proxy>()) return cur;
17
18 invalidate();
19 DLOG("cannot propagate {}, trying GVN", var);
20 if (cur->isa<Bot>()) return i->second = def;
21 return i->second = nullptr; // we reached top for propagate; nullptr marks this to bundle for GVN
22}
23
24static nat_t get_index(const Def* def) { return Lit::as(def->as<Extract>()->index()); }
25
// Analysis of a call site `app`.
// If the callee is a mutable Lam accepted by isa_optimizable(), three phases run over its n = app->num_targs() params:
//   1. propagate (SCCP): each argument is rewritten to its abstract value and joined into the lattice entry of the
//      corresponding tvar via propagate(); a conflict yields nullptr.
//   2. GVN bundle: params that propagate() marked with nullptr and that received the same abstract argument are
//      grouped into one Proxy; a param without partner goes to top (its lattice value becomes the tvar itself).
//   3. GVN split: every bundle is refined so that it only contains params whose abstract arguments agree at this call.
// Afterwards the abstract value of lam->var() is set to the tuple of the per-param results, and the call is rebuilt
// with the abstract arguments. Any other App is delegated to mim::Analysis::rewrite_imm_App.
26const Def* SymExprOpt::Analysis::rewrite_imm_App(const App* app) {
27 if (auto lam = app->callee()->isa_mut<Lam>(); isa_optimizable(lam)) {
28 auto n = app->num_targs();
// abstr_args[i]: abstract value of the i-th argument; abstr_vars[i]: lattice value of the i-th tvar of lam.
29 auto abstr_args = absl::FixedArray<const Def*>(n);
30 auto abstr_vars = absl::FixedArray<const Def*>(n);
31
32 // propagate
33 for (size_t i = 0; i != n; ++i) {
34 auto abstr = rewrite(app->targ(i));
35 abstr_vars[i] = propagate(lam->tvar(i), abstr);
36 abstr_args[i] = abstr;
37 }
38
39 // GVN bundle: All things marked as top (nullptr) by propagate are now treated as one entity by bundling them
40 // into one proxy
41 for (size_t i = 0; i != n; ++i) {
// Skip params that are not GVN candidates or that an earlier i already put into a bundle.
42 if (abstr_vars[i]) continue;
43
44 auto vars = DefVec();
45 auto vi = lam->tvar(i);
46 auto ai = abstr_args[i];
47 vars.emplace_back(vi);
48
// Collect the remaining GVN candidates with the same abstract argument; vi stays first (smallest index).
49 for (size_t j = i + 1; j != n; ++j) {
50 auto vj = lam->tvar(j);
51 if (!abstr_vars[j] && abstr_args[j] == ai) vars.emplace_back(vj);
52 }
53
54 if (vars.size() == 1) {
55 lattice_[vi] = abstr_vars[i] = vi; // top
56 } else {
57 auto proxy = world().proxy(vi->type(), vars, 0, 0);
58
// Every member of the bundle gets the Proxy as its lattice value.
59 for (auto p : proxy->ops()) {
60 auto j = get_index(p);
61 auto vj = lam->tvar(j);
62 lattice_[vj] = abstr_vars[j] = proxy;
63 }
64
65 DLOG("bundle: {}", proxy);
66 }
67 }
68
69 // GVN split: We have to prove that all incoming args for all vars in a bundle are the same value.
70 // Otherwise we have to refine the bundle by splitting off contradictions.
71 // E.g.: Say we started with `{a, b, c, d, e}` as a single bundle for all tvars of `lam`.
72 // Now, we see `lam (x, y, x, y, z)`. Then we have to build:
73 // a -> {a, c}
74 // b -> {b, d}
75 // c -> {a, c}
76 // d -> {b, d}
77 // e -> e (top)
// Order matters: members split off at index i keep the old proxy in abstr_vars until their own index is visited,
// where they are split again against their own argument.
78 for (size_t i = 0; i != n; ++i) {
79 if (auto proxy = abstr_vars[i]->isa<Proxy>()) {
80 auto num = proxy->num_ops();
81 auto vars = DefVec();
82 auto ai = abstr_args[i];
83
// Keep only those bundle members whose argument agrees with the i-th one (this includes tvar(i) itself).
84 for (auto p : proxy->ops()) {
85 auto j = get_index(p);
86 auto vj = lam->tvar(j);
// NOTE(review): the proxy's ops are tvars of lam (see bundle phase), so p == vj is expected to always hold;
// this looks like a defensive check - confirm.
87 if (p == vj) {
88 if (ai == abstr_args[j]) vars.emplace_back(vj);
89 }
90 }
91
92 auto new_num = vars.size();
93 if (new_num == 1) {
// No partner left: this param goes to top.
94 invalidate();
95 auto vi = lam->tvar(i);
96 lattice_[vi] = abstr_vars[i] = vi;
97 DLOG("single: {}", vi);
98 } else if (new_num != num) {
// Some members disagree: form a smaller bundle for the agreeing ones.
99 invalidate();
// NOTE(review): the bundle phase types its Proxy with vi->type(), this one with ai->type(); presumably both
// agree - confirm.
100 auto new_proxy = world().proxy(ai->type(), vars, 0, 0);
101 DLOG("split: {}", new_proxy);
102
103 for (auto p : new_proxy->ops()) {
104 auto j = get_index(p);
105 auto vj = lam->tvar(j);
106 if (p == vj) lattice_[vj] = abstr_vars[j] = new_proxy;
107 }
108 }
109 // if new_num == num: do nothing
110 }
111 }
112
113 set(lam->var(), world().tuple(abstr_vars)); // set new abstract var
114 return world().app(rewrite_deps(lam), abstr_args);
115 }
116
117 return mim::Analysis::rewrite_imm_App(app);
118}
119
120static bool keep(const Def* old_var, const Def* abstr) {
121 if (old_var == abstr) return true; // top
122 auto proxy = abstr->isa<Proxy>();
123 return proxy && proxy->op(0) == old_var; // first in GVN bundle?
124}
125
126const Def* SymExprOpt::rewrite_imm_App(const App* old_app) {
127 if (auto old_lam = old_app->callee()->isa_mut<Lam>()) {
128 if (auto l = lattice(old_lam->var()); l && l != old_lam->var()) {
129 invalidate();
130
131 size_t num_old = old_lam->num_tvars();
132 Lam* new_lam;
133 if (auto i = lam2lam_.find(old_lam); i != lam2lam_.end())
134 new_lam = i->second;
135 else {
136 // build new dom
137 auto new_doms = DefVec();
138 for (size_t i = 0; i != num_old; ++i) {
139 auto old_var = old_lam->var(num_old, i);
140 auto abstr = lattice(old_var);
141 if (keep(old_var, abstr)) new_doms.emplace_back(rewrite(old_lam->dom(num_old, i)));
142 }
143
144 // build new lam
145 size_t num_new = new_doms.size();
146 auto new_vars = absl::FixedArray<const Def*>(num_old);
147 new_lam = new_world().mut_lam(new_doms, rewrite(old_lam->codom()))->set(old_lam->dbg());
148 lam2lam_[old_lam] = new_lam;
149
150 // build new var
151 for (size_t i = 0, j = 0; i != num_old; ++i) {
152 auto old_var = old_lam->var(num_old, i);
153 auto abstr = lattice(old_var);
154
155 if (keep(old_var, abstr)) {
156 auto v = new_lam->var(num_new, j++);
157 new_vars[i] = v;
158 if (abstr != old_var) map(abstr, v); // GVN bundle
159 } else {
160 new_vars[i] = rewrite(abstr); // SCCP propagate
161 }
162 }
163
164 map(old_lam->var(), new_vars);
165 auto new_filter = rewrite(old_lam->filter());
166 auto new_body = rewrite(old_lam->body());
167 new_lam->set(new_filter, new_body);
168 }
169
170 // build new app
171 size_t num_new = new_lam->num_vars();
172 auto new_args = absl::FixedArray<const Def*>(num_new);
173 for (size_t i = 0, j = 0; i != num_old; ++i) {
174 auto old_var = old_lam->var(num_old, i);
175 auto abstr = lattice(old_var);
176 if (keep(old_var, abstr)) new_args[j++] = rewrite(old_app->targ(i));
177 }
178
179 return map(old_app, new_world().app(new_lam, new_args));
180 }
181 }
182
183 return RWPhase::rewrite_imm_App(old_app);
184}
185
186} // namespace mim
Def2Def lattice_
Definition phase.h:165
const Def * callee() const
Definition lam.h:276
Base class for all Defs.
Definition def.h:246
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:493
const Def * op(size_t i) const noexcept
Definition def.h:304
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:425
nat_t num_vars() noexcept
Definition def.h:425
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
const Def * index() const
Definition tuple.h:217
A function.
Definition lam.h:110
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
static T as(const Def *def)
Definition def.h:844
void invalidate(bool todo=true)
Signals that another round of fixed-point iteration is required.
Definition phase.h:48
World & new_world()
Create new Defs into this.
Definition phase.h:243
const Def * lattice(const Def *old_def)
Returns the abstract value computed by the associated Analysis for the given old-world Def,...
Definition phase.h:211
World & world()=delete
Hides both and forbids direct access.
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:45
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
World & world()
Definition pass.h:77
const Def * rewrite_imm_App(const App *) final
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
const Proxy * proxy(const Def *type, Defs ops, u32 index, u32 tag)
Definition world.h:342
Lam * mut_lam(const Pi *pi)
Definition world.h:406
#define DLOG(...)
Vaporizes to nothingness in Release builds; only active in Debug builds.
Definition log.h:94
Definition ast.h:14
u64 nat_t
Definition types.h:37
static bool keep(const Def *old_var, const Def *abstr)
Vector< const Def * > DefVec
Definition def.h:79
static nat_t get_index(const Def *def)
Lam * isa_optimizable(Lam *lam)
These are Lams that are.
Definition lam.h:352
TExt< false > Bot
Definition lattice.h:171
@ Lam
Definition def.h:109
@ App
Definition def.h:109
@ Proxy
Definition def.h:109