MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_for.cpp
Go to the documentation of this file.
2
3#include <mim/lam.h>
4#include <mim/tuple.h>
5
6#include <mim/plug/mem/mem.h>
7
9
11
12namespace {
13
14const Def* merge_s(const Def* elem, const Def* sigma, const Def* mem) {
15 auto& w = elem->world();
16 if (mem) {
17 auto elems = sigma->projs();
18 return cat_sigma(elem, elems);
19 }
20 return w.sigma({elem, sigma});
21}
22
23const Def* merge_t(const Def* elem, const Def* tuple, const Def* mem) {
24 auto& w = elem->world();
25 if (mem) {
26 auto elems = tuple->projs();
27 return cat_tuple(elem, elems);
28 }
29 return w.tuple({elem, tuple});
30}
31
32} // namespace
33
34const Def* LowerFor::rewrite_imm_App(const App* app) {
35 if (is_bootstrapping()) return RWPhase::rewrite_imm_App(app);
36
37 if (auto for_ax = Axm::isa<affine::For>(app)) {
38 DLOG("rewriting for axm: `{}`", for_ax);
39 auto [old_body, old_exit, args] = for_ax->uncurry_args<3>();
40 auto [new_begin, new_end, new_step, new_init] = args->projs<4>([this](const Def* def) { return rewrite(def); });
41
42 auto old_body_lam = old_body->isa_mut<Lam>();
43 auto old_exit_lam = old_exit->isa_mut<Lam>();
44 if (!old_body_lam) old_body_lam = Lam::eta_expand(old_body);
45 if (!old_exit_lam) old_exit_lam = Lam::eta_expand(old_exit);
46
47 auto new_mem = mem::mem_def(new_init);
48 auto new_head_lam = new_world().mut_con(merge_s(new_begin->type(), new_init->type(), new_mem))->set("head");
49 auto new_phis = new_head_lam->vars();
50 auto new_iter = new_phis.front();
51 auto new_acc = new_world().tuple(new_phis.view().subspan(1));
52 new_mem = mem::mem_var(new_head_lam);
53 auto new_bb_dom = new_mem ? new_mem->type() : new_world().sigma();
54
55 auto new_body = new_world().mut_con(new_bb_dom)->set("new_body");
56 auto new_exit = new_world().mut_con(new_bb_dom)->set("new_exit");
57 auto new_yield = new_world().mut_con(new_init->type())->set("new_yield");
58 auto new_cmp = new_world().call(core::icmp::ul, Defs{new_iter, new_end});
59 auto new_inc = new_world().call(core::wrap::add, core::Mode::nsuw, Defs{new_iter, new_step});
60
61 new_head_lam->branch(false, new_cmp, new_body, new_exit, new_mem);
62 new_yield->app(false, new_head_lam, merge_t(new_inc, new_yield->var(), new_mem));
63
64 push();
65 map(old_body_lam->var(), {new_iter, new_acc, new_yield});
66 auto new_body_filter = rewrite(old_body_lam->filter());
67 auto new_body_value = rewrite(old_body_lam->body());
68 new_body->set({new_body_filter, new_body_value});
69 pop();
70
71 push();
72 map(old_exit_lam->var(), new_acc);
73 auto new_exit_filter = rewrite(old_exit_lam->filter());
74 auto new_exit_value = rewrite(old_exit_lam->body());
75 new_exit->set({new_exit_filter, new_exit_value});
76 pop();
77
78 return new_world().app(new_head_lam, merge_t(new_begin, new_init, new_mem));
79 }
80
81 return RWPhase::rewrite_imm_App(app);
82}
83
84} // namespace mim::plug::affine::phase
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:493
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:425
auto vars(F f) noexcept
Definition def.h:425
A function.
Definition lam.h:110
const Def * filter() const
Definition lam.h:122
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
static Lam * eta_expand(Filter, const Def *f)
Definition lam.cpp:51
const Def * body() const
Definition lam.h:123
World & new_world()
Create new Defs in this World.
Definition phase.h:243
bool is_bootstrapping() const
Returns whether we are currently bootstrapping (rewriting annexes).
Definition phase.h:231
virtual void push()
Definition rewrite.h:38
virtual void pop()
Definition rewrite.h:39
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
const Def * sigma(Defs ops)
Definition world.cpp:287
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
const Def * tuple(Defs ops)
Definition world.cpp:297
const Def * call(const Def *callee, T &&arg, Args &&... args)
Definition world.h:659
Lam * mut_con(const Def *dom)
Definition world.h:418
const Def * rewrite_imm_App(const App *) final
Definition lower_for.cpp:34
#define DLOG(...)
Vaporizes to nothingness in Release builds.
Definition log.h:94
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
const Def * mem_def(const Def *def)
Returns the (first) element of type mem.M a from the given tuple.
Definition mem.h:25
View< const Def * > Defs
Definition def.h:78
const Def * cat_tuple(nat_t n, nat_t m, const Def *a, const Def *b)
Definition tuple.cpp:156
const Def * cat_sigma(nat_t n, nat_t m, const Def *a, const Def *b)
Definition tuple.cpp:157