MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.cpp
Go to the documentation of this file.
1#include "mim/rewrite.h"
2
3#include <absl/container/fixed_array.h>
4
5#include "mim/world.h"
6
7#include "fe/assert.h"
8
9// Don't use fancy C++-lambdas; it's way too annoying stepping through them in a debugger.
10
11namespace mim {
12
13/*
14 * Rewriter
15 */
16
17Rewriter::Rewriter(std::unique_ptr<World>&& ptr)
18 : ptr_(std::move(ptr))
19 , world_(ptr_.get()) {
20 push(); // create root map
21}
22
// NOTE(review): the signature line (source line 23) is missing from this listing; judging from the
// `world_(&world)` initializer it is presumably `Rewriter::Rewriter(World& world)` — confirm.
// Non-owning constructor: ptr_ is left empty and world_ points at the caller's World; opens the root map.
24 : world_(&world) {
25 push(); // create root map
26}
27
28Rewriter::~Rewriter() = default;
29
30void Rewriter::reset(std::unique_ptr<World>&& ptr) {
31 ptr_ = std::move(ptr);
32 world_ = ptr_.get();
33 reset();
34}
35
// NOTE(review): the opening line (source line 36) is missing from this listing; per the cross-reference it is
// `void Rewriter::reset()`.
// Drops the root map and installs a fresh, empty one. The assert checks that the root was the only map on the
// stack, i.e. every push() since construction has been matched by a pop().
37 pop();
38 assert(old2news_.empty());
39 push();
40}
41
42const Def* Rewriter::map(const Def* old_def, Defs new_defs) {
43 auto new_tuple = world().tuple(new_defs);
44 return old2news_.back()[old_def] = new_tuple;
45}
46const Def* Rewriter::map(Defs old_defs, const Def* new_def) {
47 auto old_tuple = world().tuple(old_defs);
48 return old2news_.back()[old_tuple] = new_def;
49}
50const Def* Rewriter::map(Defs old_defs, Defs new_defs) {
51 auto old_tuple = world().tuple(old_defs);
52 auto new_tuple = world().tuple(new_defs);
53 return old2news_.back()[old_tuple] = new_tuple;
54}
55
56const Def* Rewriter::rewrite(const Def* old_def) {
57 if (auto new_def = lookup(old_def)) return new_def;
58
59 auto new_def = old_def->isa_mut() ? rewrite_mut((Def*)old_def) : rewrite_imm(old_def);
60 return new_def->set(old_def->dbg());
61}
62
63// clang-format off
// Dispatch helpers: each expands to a `case` forwarding to the matching rewrite_imm_<N> / rewrite_mut_<N> handler.
64#define CODE_MUT(N) case Node::N: new_def = rewrite_mut_##N(old_mut->as<N>()); break;
65#define CODE_IMM(N) case Node::N: new_def = rewrite_imm_##N(old_def->as<N>()); break;
66// clang-format on
67
// Rewrites an immutable by dispatching on its node kind, then records old_def -> new_def via map().
// NOTE(review): the case list (source line 71) is missing from this listing; presumably it is
// MIM_IMM_NODE(CODE_IMM), expanding one case per immutable node kind — confirm against def.h.
68const Def* Rewriter::rewrite_imm(const Def* old_def) {
69 const Def* new_def;
70 switch (old_def->node()) {
72 default: fe::unreachable();
73 }
74 return map(old_def, new_def);
75}
76
// Rewrites a mutable by dispatching on its node kind. Unlike rewrite_imm, the result is not mapped here;
// handlers that need it (e.g. via rewrite_stub) call map() themselves.
// NOTE(review): the case list (source line 80) is missing from this listing; presumably it is
// MIM_MUT_NODE(CODE_MUT) — confirm against def.h.
77const Def* Rewriter::rewrite_mut(Def* old_mut) {
78 const Def* new_def;
79 switch (old_mut->node()) {
81 default: fe::unreachable();
82 }
83 return new_def;
84}
85
86#undef CODE_MUT
87#undef CODE_IMM
88
// NOTE(review): the signature line (source line 89) is missing from this listing; from call sites such as
// world().match(rewrite(d->ops())) it is presumably `DefVec Rewriter::rewrite(Defs ops)` — confirm.
// Rewrites each element of ops, front to back, and returns the results in a vector of the same length.
90 auto new_ops = DefVec(ops.size());
91 for (size_t i = 0, e = ops.size(); i != e; ++i)
92 new_ops[i] = rewrite(ops[i]);
93 return new_ops;
94}
95
96#ifndef DOXYGEN
97// clang-format off
98const Def* Rewriter::rewrite_imm_Idx (const Idx* ) { return world().type_idx(); }
99const Def* Rewriter::rewrite_imm_Nat (const Nat* ) { return world().type_nat(); }
100const Def* Rewriter::rewrite_imm_Univ (const Univ* ) { return world().univ(); }
101const Def* Rewriter::rewrite_imm_Lit (const Lit* d) { return world().lit (rewrite(d->type()), d->get()); }
102const Def* Rewriter::rewrite_imm_Match (const Match* d) { return world().match (rewrite(d->ops())); }
103const Def* Rewriter::rewrite_imm_Reform(const Reform* d) { return world().reform(rewrite(d->dom())); }
104const Def* Rewriter::rewrite_imm_Sigma (const Sigma* d) { return world().sigma (rewrite(d->ops())); }
105const Def* Rewriter::rewrite_imm_Type (const Type* d) { return world().type (rewrite(d->level())); }
106const Def* Rewriter::rewrite_imm_UInc (const UInc* d) { return world().uinc (rewrite(d->op()), d->offset()); }
107const Def* Rewriter::rewrite_imm_UMax (const UMax* d) { return world().umax (rewrite(d->ops())); }
108const Def* Rewriter::rewrite_imm_Uniq (const Uniq* d) { return world().uniq (rewrite(d->op())); }
109const Def* Rewriter::rewrite_imm_Var (const Var* d) { return world().var (rewrite(d->mut())->as_mut()); }
110const Def* Rewriter::rewrite_imm_Top (const Top* d) { return world().top (rewrite(d->type())); }
111const Def* Rewriter::rewrite_imm_Bot (const Bot* d) { return world().bot (rewrite(d->type())); }
112const Def* Rewriter::rewrite_imm_Meet (const Meet* d) { return world().meet (rewrite(d->ops())); }
113const Def* Rewriter::rewrite_imm_Join (const Join* d) { return world().join (rewrite(d->ops())); }
114
115const Def* Rewriter::rewrite_imm_Arr (const Arr* d) { return rewrite_imm_Seq(d); }
116const Def* Rewriter::rewrite_imm_Pack(const Pack* d) { return rewrite_imm_Seq(d); }
117const Def* Rewriter::rewrite_mut_Arr ( Arr* d) { return rewrite_mut_Seq(d); }
118const Def* Rewriter::rewrite_mut_Pack( Pack* d) { return rewrite_mut_Seq(d); }
119// clang-format on
120
121const Def* Rewriter::rewrite_imm_App(const App* d) {
122 // Rewrite the arg before the callee:
123 // the callee may be a recursive mutable that, when rewritten first, eagerly expands its body before the concrete
124 // argument is available to specialize it, breaking partial-evaluation termination
125 auto new_arg = rewrite(d->arg());
126 auto new_callee = rewrite(d->callee());
127 return world().app(new_callee, new_arg);
128}
129
130const Def* Rewriter::rewrite_imm_Inj(const Inj* d) {
131 auto new_type = rewrite(d->type());
132 auto new_value = rewrite(d->value());
133 return world().inj(new_type, new_value);
134}
135
136const Def* Rewriter::rewrite_imm_Insert(const Insert* d) {
137 auto new_tuple = rewrite(d->tuple());
138 auto new_index = rewrite(d->index());
139 auto new_value = rewrite(d->value());
140 return world().insert(new_tuple, new_index, new_value);
141}
142
143const Def* Rewriter::rewrite_imm_Lam(const Lam* d) {
144 auto new_type = rewrite(d->type())->as<Pi>();
145 auto new_filter = rewrite(d->filter());
146 auto new_body = rewrite(d->body());
147 return world().lam(new_type, new_filter, new_body);
148}
149
150const Def* Rewriter::rewrite_imm_Merge(const Merge* d) {
151 auto new_type = rewrite(d->type());
152 auto new_ops = rewrite(d->ops());
153 return world().merge(new_type, new_ops);
154}
155
156const Def* Rewriter::rewrite_imm_Pi(const Pi* d) {
157 auto new_dom = rewrite(d->dom());
158 auto new_codom = rewrite(d->codom());
159 return world().pi(new_dom, new_codom, d->is_implicit());
160}
161
162const Def* Rewriter::rewrite_imm_Proxy(const Proxy* d) {
163 auto new_type = rewrite(d->type());
164 auto new_ops = rewrite(d->ops());
165 return world().proxy(new_type, new_ops, d->pass(), d->tag());
166}
167
168const Def* Rewriter::rewrite_imm_Rule(const Rule* d) {
169 auto new_type = rewrite(d->type())->as<Reform>();
170 auto new_lhs = rewrite(d->lhs());
171 auto new_rhs = rewrite(d->rhs());
172 auto new_guard = rewrite(d->guard());
173 return world().rule(new_type, new_lhs, new_rhs, new_guard);
174}
175
176const Def* Rewriter::rewrite_imm_Split(const Split* d) {
177 auto new_type = rewrite(d->type());
178 auto new_value = rewrite(d->value());
179 return world().split(new_type, new_value);
180}
181
182const Def* Rewriter::rewrite_imm_Tuple(const Tuple* d) {
183 auto new_type = rewrite(d->type());
184 auto new_ops = rewrite(d->ops());
185 return world().tuple(new_type, new_ops);
186}
187
188// clang-format on
189const Def* Rewriter::rewrite_mut_Pi(Pi* d) {
190 return rewrite_stub(d, world().mut_pi(rewrite(d->type()), d->is_implicit()));
191}
192const Def* Rewriter::rewrite_mut_Lam(Lam* d) { return rewrite_stub(d, world().mut_lam(rewrite(d->type())->as<Pi>())); }
193const Def* Rewriter::rewrite_mut_Rule(Rule* d) {
194 return rewrite_stub(d, world().mut_rule(rewrite(d->type())->as<Reform>()));
195}
196const Def* Rewriter::rewrite_mut_Sigma(Sigma* d) {
197 return rewrite_stub(d, world().mut_sigma(rewrite(d->type()), d->num_ops()));
198}
199const Def* Rewriter::rewrite_mut_Global(Global* d) {
200 return rewrite_stub(d, world().global(rewrite(d->type()), d->is_mutable()));
201}
202
203const Def* Rewriter::rewrite_imm_Axm(const Axm* a) {
204 if (&a->world() != &world()) {
205 auto type = rewrite(a->type());
206 return world().axm(a->normalizer(), a->curry(), a->trip(), type, a->plugin(), a->tag(), a->sub());
207 }
208 return a;
209}
210
211const Def* Rewriter::rewrite_imm_Extract(const Extract* ex) {
212 auto new_index = rewrite(ex->index());
213 if (auto index = Lit::isa(new_index)) {
214 if (auto tuple = ex->tuple()->isa<Tuple>()) return map(ex, rewrite(tuple->op(*index)));
215 if (auto pack = ex->tuple()->isa_imm<Pack>(); pack && pack->arity()->is_closed())
216 return map(ex, rewrite(pack->body()));
217 }
218
219 auto new_tuple = rewrite(ex->tuple());
220 return world().extract(new_tuple, new_index);
221}
222
223const Def* Rewriter::rewrite_mut_Hole(Hole* hole) {
224 auto [last, op] = hole->find();
225 return op ? rewrite(op) : rewrite_stub(last, world().mut_hole(rewrite(last->type())));
226}
227
228#endif
229
// NOTE(review): the signature line (source line 230) is missing from this listing; per the cross-reference it is
// `const Def* Rewriter::rewrite_imm_Seq(const Seq* seq)`.
// Rewrites an immutable Arr/Pack. If the rewritten arity is the literal 0, the result is the empty product of the
// same flavor (seq->is_intro()); otherwise the sequence is rebuilt from the rewritten arity and body.
231 auto new_arity = rewrite(seq->arity());
232 if (auto l = Lit::isa(new_arity); l && *l == 0) return world().prod(seq->is_intro());
233 return world().seq(seq->is_intro(), new_arity, rewrite(seq->body()));
234}
235
// NOTE(review): the signature line (source line 236) is missing from this listing; per the cross-reference it is
// `const Def* Rewriter::rewrite_mut_Seq(Seq* seq)`.
// Rewrites a mutable Arr/Pack.
237 if (!seq->is_set()) {
// Ops not filled in yet: only a stub with the rewritten type can be produced.
238 auto new_seq = seq->as_mut<Seq>()->stub(world(), rewrite(seq->type()));
239 return map(seq, new_seq);
240 }
241
// zonk() reconstructs the arity without filled Holes (presumably so a filled arity is recognized as a Lit below).
242 auto new_arity = rewrite(seq->arity())->zonk();
243 auto l = Lit::isa(new_arity);
// Zero-length sequence collapses to the empty product.
// NOTE(review): unlike the other exits, this result is not recorded via map() — confirm that is intended.
244 if (l && *l == 0) return world().prod(seq->is_intro());
245
// Scalarization: for a literal arity up to flags().scalarize_threshold and a binder whose Var has been created,
// instantiate the body once per index, mapping the Var to lit_idx(e, i) in a temporary nested scope.
246 if (auto var = seq->has_var(); var && l && *l <= world().flags().scalarize_threshold) {
247 auto new_ops = absl::FixedArray<const Def*>(*l);
248 for (size_t i = 0, e = *l; i != e; ++i) {
249 push();
250 map(var, world().lit_idx(e, i));
251 new_ops[i] = rewrite(seq->body());
252 pop();
253 }
254 return map(seq, world().prod(seq->is_intro(), new_ops));
255 }
256
// No Var was ever created, so the body cannot refer to the index: build an immutable sequence directly.
257 if (!seq->has_var()) return map(seq, world().seq(seq->is_intro(), new_arity, rewrite(seq->body())));
// General case: a fresh mutable sequence whose ops are filled in by rewrite_stub.
258 return rewrite_stub(seq->as_mut(), world().mut_seq(seq->is_intro(), rewrite(seq->type())));
259}
260
261const Def* Rewriter::rewrite_stub(Def* old_mut, Def* new_mut) {
262 map(old_mut, new_mut);
263
264 if (old_mut->is_set()) {
265 for (size_t i = 0, e = old_mut->num_ops(); i != e; ++i)
266 new_mut->set(i, rewrite(old_mut->op(i)));
267 if (auto new_imm = new_mut->immutabilize()) return map(old_mut, new_imm);
268 }
269
270 return new_mut;
271}
272
273/*
274 * VarRewriter
275 */
276
277const Def* VarRewriter::rewrite(const Def* old_def) {
278 if (auto new_def = lookup(old_def)) return new_def;
279
280 if (auto old_mut = old_def->isa_mut())
281 return has_intersection(old_mut) ? rewrite_mut(old_mut)->set(old_mut->dbg()) : old_mut;
282
283 if (old_def->local_vars().empty() && old_def->local_muts().empty()) return old_def; // safe to skip
284
285 return has_intersection(old_def) ? rewrite_imm(old_def)->set(old_def->dbg()) : old_def;
286}
287
// NOTE(review): the signature line (source line 288) is missing from this listing; per the cross-reference it is
// `const Def* VarRewriter::rewrite_mut(Def* mut)`.
// If mut's Var has been created, insert it into the innermost vars_ set (presumably so has_intersection() also
// covers references to this binder while its body is rewritten — confirm), then delegate to Rewriter::rewrite_mut.
289 if (auto var = mut->has_var()) {
290 auto& vars = vars_.back();
291 vars = world().vars().insert(vars, var);
292 }
293
294 return Rewriter::rewrite_mut(mut);
295}
296
297/*
298 * Zonker
299 */
300
301const Def* Zonker::map(const Def* old_def, const Def* new_def) {
302 auto repr = lookup(new_def); // always normalize new_def to its representative
303 if (!repr) repr = new_def;
304 return old2news_.back()[old_def] = repr;
305}
306
307const Def* Zonker::lookup(const Def* old_def) {
308 for (auto& old2new : old2news_ | std::views::reverse) {
309 const Def* repr;
310 auto path = DefVec();
311 while (true) {
312 repr = get(old_def);
313
314 if (repr == nullptr) break;
315
316 path.emplace_back(repr);
317 if (repr == old_def) break; // explicit self-map
318
319 old_def = repr;
320 }
321
322 if (path.empty()) continue;
323
324 // path compression: flatten all visited nodes
325 for (auto def : path)
326 old2new[def] = repr;
327
328 return repr;
329 }
330
331 return nullptr;
332}
333
334const Def* Zonker::rewrite(const Def* def) {
335 if (auto hole = def->isa_mut<Hole>()) {
336 auto [last, op] = hole->find();
337 def = op ? op : last;
338 }
339
340 return def->needs_zonk() ? Rewriter::rewrite(def) : def;
341}
342
// NOTE(review): the signature line (source line 343) is missing from this listing; per the cross-reference it is
// `const Def* Zonker::rewire_mut(Def* mut)`.
// Zonks mut in place: its type and ops are replaced by their rewritten versions.
// The explicit self-map makes mut its own representative, so recursive references resolve to mut itself.
344 map(mut, mut);
345
// Snapshot type and ops before unset() clears the ops.
346 auto old_type = mut->type();
347 auto old_ops = absl::FixedArray<const Def*>(mut->ops().begin(), mut->ops().end());
348
// C++17 sequences the call's postfix-expression (incl. unset()) before its argument rewrite(old_type).
349 mut->unset()->set_type(rewrite(old_type));
350
351 for (size_t i = 0, e = mut->num_ops(); i != e; ++i)
352 mut->set(i, rewrite(old_ops[i]));
353
// If mut can now be expressed as an immutable, that immutable becomes its representative.
354 if (auto new_imm = mut->immutabilize()) return map(mut, new_imm);
355
356 return mut;
357}
358
359} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:117
Definition axm.h:9
Base class for all Defs.
Definition def.h:246
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:298
constexpr Node node() const noexcept
Definition def.h:270
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:502
const Def * zonk() const
If Holes have been filled, reconstruct the program without them.
Definition check.cpp:21
Def * set_type(const Def *)
Update type.
Definition def.cpp:283
bool is_intro() const noexcept
Definition def.h:280
constexpr auto ops() const noexcept
Definition def.h:301
Vars local_vars() const
Vars reachable by following immutable deps().
Definition def.cpp:348
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:493
const Def * op(size_t i) const noexcept
Definition def.h:304
virtual const Def * immutabilize()
Tries to make an immutable from a mutable.
Definition def.h:564
Muts local_muts() const
Mutables reachable by following immutable deps(); mut->local_muts() is by definition the set { mut }...
Definition def.cpp:332
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
virtual const Def * arity() const
Definition def.cpp:558
Def * unset()
Unsets all Def::ops; works even if they are not set at all or only partially set.
Definition def.cpp:289
bool needs_zonk() const
Yields true if Def::local_muts() contains a Hole that is set.
Definition check.cpp:12
Dbg dbg() const
Definition def.h:513
const Var * has_var()
Returns a non-null pointer only if the Var of this mutable has ever been created.
Definition def.h:429
constexpr size_t num_ops() const noexcept
Definition def.h:305
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:14
A built-in constant of type Nat -> *.
Definition def.h:873
Constructs a Join value.
Definition lattice.h:70
Creates a new Tuple / Pack by inserting Insert::value at position Insert::index into Insert::tuple.
Definition tuple.h:233
A function.
Definition lam.h:110
static std::optional< T > isa(const Def *def)
Definition def.h:838
Scrutinize Match::scrutinee() and dispatch to Match::arms.
Definition lattice.h:116
Constructs a Meet value.
Definition lattice.h:53
A (possibly parameterized) Tuple.
Definition tuple.h:166
A dependent function type.
Definition lam.h:14
Type formation of a rewrite Rule.
Definition rule.h:9
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:230
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:236
World & world()
Definition rewrite.h:33
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:77
virtual void push()
Definition rewrite.h:38
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:261
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:45
virtual void pop()
Definition rewrite.h:39
virtual ~Rewriter()
void reset()
Definition rewrite.cpp:36
std::deque< Def2Def > old2news_
Definition rewrite.h:93
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:68
Rewriter(std::unique_ptr< World > &&ptr)
Definition rewrite.cpp:17
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
virtual const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:55
A rewrite rule.
Definition rule.h:43
Base class for Arr and Pack.
Definition tuple.h:84
const Def * body() const
Definition tuple.h:91
constexpr bool empty() const noexcept
Is empty?
Definition sets.h:245
A dependent tuple type.
Definition tuple.h:20
Picks the aspect of a Meet [value](Pick::value) by its [type](Def::type).
Definition lattice.h:93
Data constructor for a Sigma.
Definition tuple.h:68
A singleton wraps a type into a higher order type.
Definition lattice.h:180
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:288
const Def * rewrite(const Def *) final
Definition rewrite.cpp:277
A variable introduced by a binder (mutable).
Definition def.h:714
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:36
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:452
const Def * meet(Defs ops)
Definition world.h:595
const Def * uinc(const Def *op, level_t offset=1)
Definition world.cpp:122
const Lit * lit(const Def *type, u64 val)
Definition world.cpp:534
const Def * seq(bool is_pack, const Def *arity, const Def *body)
Definition world.cpp:501
const Type * type(const Def *level)
Definition world.cpp:112
const Def * sigma(Defs ops)
Definition world.cpp:287
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
const Def * match(Defs)
Definition world.cpp:622
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:377
const Univ * univ()
Definition world.h:326
const Def * bot(const Def *type)
Definition world.h:587
const Idx * type_idx()
Definition world.h:613
const Reform * reform(const Def *dom)
Definition world.h:433
const Nat * type_nat()
Definition world.h:612
const Lam * lam(const Pi *pi, Lam::Filter f, const Def *body)
Definition world.h:405
const Def * tuple(Defs ops)
Definition world.cpp:297
const Def * inj(const Def *type, const Def *value)
Definition world.cpp:607
const Axm * axm(NormalizeFn n, u8 curry, u8 trip, const Def *type, plugin_t p, tag_t t, sub_t s)
Definition world.h:359
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:362
const Def * join(Defs ops)
Definition world.h:594
const Proxy * proxy(const Def *type, Defs ops, u32 index, u32 tag)
Definition world.h:342
const Def * var(Def *mut)
Definition world.cpp:184
const Def * uniq(const Def *inhabitant)
Definition world.cpp:658
const Def * prod(bool term, Defs ops)
Definition world.h:485
const Def * umax(Defs)
Definition world.cpp:142
const Def * merge(const Def *type, Defs ops)
Definition world.cpp:589
const Def * top(const Def *type)
Definition world.h:588
auto & vars()
Definition world.h:692
const Def * split(const Def *type, const Def *value)
Definition world.cpp:615
const Rule * rule(const Reform *type, const Def *lhs, const Def *rhs, const Def *guard)
Definition world.h:435
const Def * rewire_mut(Def *)
Definition rewrite.cpp:343
const Def * lookup(const Def *old_def) final
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:307
const Def * rewrite(const Def *) final
Definition rewrite.cpp:334
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:301
#define MIM_MUT_NODE(X)
Definition def.h:54
#define MIM_IMM_NODE(X)
Definition def.h:38
Definition ast.h:14
View< const Def * > Defs
Definition def.h:78
Vector< const Def * > DefVec
Definition def.h:79
TBound< true > Join
AKA union.
Definition lattice.h:174
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
TExt< true > Top
Definition lattice.h:172
TExt< false > Bot
Definition lattice.h:171
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
@ Nat
Definition def.h:109
@ Pi
Definition def.h:109
@ Pack
Definition def.h:109
@ Reform
Definition def.h:109
@ Tuple
Definition def.h:109
Definition span.h:122
#define CODE_MUT(N)
Definition rewrite.h:72
#define CODE_IMM(N)
Definition rewrite.h:71