MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
cps2ds.cpp
Go to the documentation of this file.
2
3#include <ranges>
4
5#include "mim/def.h"
6#include "mim/rewrite.h"
7#include "mim/schedule.h"
8#include "mim/world.h"
9
11
12#define DEBUG_CPS2DS 0
13
14namespace mim::plug::direct {
15
17#if DEBUG_CPS2DS
18 world().debug_dump();
19#endif
20
21 scheduler_.clear();
22 nests_.clear();
23 lam2lam_.clear();
24 rewritten_.clear();
25
26 world().for_each(true, [this](Def* mut) {
27 if (auto lam = mut->isa_mut<Lam>(); lam && !lam->codom()->isa<Type>())
28 nests_.try_emplace(lam, std::make_unique<Nest>(lam));
29 });
30
31 for (auto external : world().externals().muts())
32 if (auto lam = external->isa_mut<Lam>()) {
33 current_external_ = lam;
34 rewrite_lam(lam);
35 }
36
37#if DEBUG_CPS2DS
38 world().debug_dump();
39#endif
40}
41
42const Def* CPS2DSPhase::rewrite_lam(Lam* lam) {
43 if (auto i = rewritten_.find(lam); i != rewritten_.end()) return i->second;
44 if (lam2lam_.contains(lam)) return lam;
45 if (lam->isa_imm() || !lam->is_set() || lam->codom()->isa<Type>()) {
46 world().DLOG("skipped {}", lam);
47 return lam;
48 }
49
50 lam2lam_[lam] = lam;
51
52 world().DLOG("Rewriting lam: {}", lam->unique_name());
53
54 auto filter = rewrite(lam->filter());
55
56 if (auto body = lam->body()->isa<App>(); !body) {
57 world().DLOG(" non-app body {}, skipped", lam->body());
58 auto new_body = rewrite(lam->body());
59 lam->unset()->set(filter, new_body);
60 return rewritten_[lam] = lam;
61 }
62
63 auto body = lam->body()->as<App>();
64 auto new_arg = rewrite(body->arg());
65
66 auto new_callee = rewrite(body->callee());
67 auto new_lam = result_lam(lam);
68#if DEBUG_CPS2DS
69 world().DLOG("Result of rewrite {} set for {}", lam->unique_name(), new_lam->unique_name());
70#endif
71
72 if (world().log().level() >= mim::Log::Level::Debug) body->dump(1);
73
74 new_lam->unset()->app(filter, new_callee, new_arg);
75
76 return rewritten_[lam] = lam;
77}
78
79const Def* CPS2DSPhase::rewrite(const Def* def) {
80 if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second;
81
82 if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
83 if (auto mut = def->isa_mut()) rewritten_[def] = mut->stub(world(), mut->type())->set(mut->dbg());
84
85 if (auto app = def->isa<App>()) {
86 if (auto cps2ds = Axm::isa<direct::cps2ds_dep>(app->callee())) {
87 auto cps_lam = rewrite(cps2ds->arg())->as<Lam>();
88
89 auto call_arg = rewrite(app->arg());
90
91 if (world().log().level() >= mim::Log::Level::Debug) {
92 cps2ds->dump(2);
93 cps2ds->arg()->dump(2);
94 }
95
96 auto early = scheduler(app).early(app);
97 auto late = scheduler(app).late(current_external_, app);
98 auto node = Nest::lca(early, late);
99#if DEBUG_CPS2DS
100 world().DLOG("scheduling {} between {} (level {}) and {} (level {}) at {}", app,
101 early->mut() ? early->mut()->unique_name() : "root", early->level(),
102 late->mut() ? late->mut()->unique_name() : "root", late->level(),
103 node->mut() ? node->mut()->unique_name() : "root");
104#endif
105 auto lam = result_lam(node->mut()->as_mut<Lam>());
106
107#if DEBUG_CPS2DS
108 world().DLOG("current lam: {} : {}", lam->unique_name(), lam->type());
109#endif
110
111 auto cn_dom = cps_lam->ret_dom();
112 auto cont = make_continuation(cn_dom, app, cps_lam->sym());
113#if DEBUG_CPS2DS
114 world().DLOG("continuation created: {} : {}", cont, cont->type());
115 if (world().log().level() >= mim::Log::Level::Debug) cont->dump(2);
116#endif
117 {
118 auto filter = rewritten_[lam->filter()] = rewrite(lam->filter());
119 auto body = world().app(cps_lam, world().tuple({call_arg, cont}));
120 rewritten_[lam] = lam->unset()->set(filter, body);
121 lam2lam_[lam] = cont;
122 }
123
124#if DEBUG_CPS2DS
125 world().DLOG("point the lam to the cont: {} = {}", lam->unique_name(), lam->body());
126 if (world().log().level() >= mim::Log::Level::Debug) {
127 lam->dump(2);
128 cont->dump(2);
129 }
130#endif
131
132 return cont->var(); // rewritten_[def] = cont->var(); done in make_continuation
133 }
134 }
135
136 DefVec new_ops{def->ops(), [this](const Def* d) { return rewrite(d); }};
137 if (def->isa_mut()) return rewritten_[def]->as_mut()->set(new_ops);
138
139 auto new_def = def->rebuild(def->type(), new_ops);
140 rewritten_[def] = new_def;
141 return new_def;
142}
143
144Lam* CPS2DSPhase::make_continuation(const Def* cn_type, const Def* arg, Sym prefix) {
145#if DEBUG_CPS2DS
146 world().DLOG("make_continuation {} : {} ({})", prefix, cn_type, arg);
147 if (world().log().level() >= mim::Log::Level::Debug) arg->dump(2);
148#endif
149 auto name = world().append_suffix(prefix, "_cps2ds_cont");
150 auto cont = world().mut_con(cn_type)->set(name)->set_filter(false);
151
152 rewritten_[arg] = cont->var();
153
154 return cont;
155}
156
157Lam* CPS2DSPhase::result_lam(Lam* lam) {
158 if (auto i = lam2lam_.find(lam); i != lam2lam_.end())
159 if (i->second != lam) return result_lam(i->second);
160 return lam;
161}
162
163Scheduler& CPS2DSPhase::scheduler(const Def* def) {
164 auto get_or_make = [&](const Def* lam, const Nest& nest) -> Scheduler& {
165 if (auto sched = scheduler_.find(lam); sched != scheduler_.end()) {
166#if DEBUG_CPS2DS
167 world().DLOG("found existing scheduler for {}", lam);
168#endif
169 return sched->second;
170 } else {
171#if DEBUG_CPS2DS
172 world().DLOG("creating new scheduler for {}", lam);
173#endif
174 auto [it, inserted] = scheduler_.insert({lam, Scheduler(nest)});
175 return it->second;
176 }
177 };
178 for (const auto& [lam, nest] : nests_) {
179#if DEBUG_CPS2DS
180 world().DLOG("looking for scheduler in {} for {}", lam, def);
181#endif
182 if (nest->contains(def)) return get_or_make(lam, *nest);
183 }
184#if DEBUG_CPS2DS
185 world().DLOG("no scheduler found for {}, using current external {}", def, current_external_);
186#endif
187 return get_or_make(current_external_, curr_external_nest());
188}
189
190const Nest& CPS2DSPhase::curr_external_nest() const {
191 auto i = nests_.find(current_external_);
192 assert(i != nests_.end());
193 return *i->second;
194}
195
196} // namespace mim::plug::direct
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:298
void dump() const
Definition dump.cpp:549
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:493
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:425
std::string unique_name() const
name + "_" + Def::gid
Definition def.cpp:584
const T * isa_imm() const
Definition def.h:487
A function.
Definition lam.h:110
Lam * unset()
Definition lam.h:179
const Def * filter() const
Definition lam.h:122
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
Lam * set_filter(Filter)
Set filter first.
Definition lam.cpp:28
const Pi * type() const
Definition lam.h:130
const Def * body() const
Definition lam.h:123
const Def * codom() const
Definition lam.h:132
static const Node * lca(const Node *n, const Node *m)
Least common ancestor of n and m.
Definition nest.cpp:105
World & world()
Definition pass.h:77
std::string_view name() const
Definition pass.h:80
Log & log() const
Definition pass.h:79
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:452
const Type * type(const Def *level)
Definition world.cpp:112
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
void for_each(bool elide_empty, std::function< void(Def *)>, bool schedule=false)
Definition world.cpp:699
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:590
Sym append_suffix(Sym name, std::string suffix)
Appends a suffix or an increasing number if the suffix already exists.
Definition world.cpp:663
const Def * var(Def *mut)
Definition world.cpp:184
Lam * mut_con(const Def *dom)
Definition world.h:418
void dump(std::ostream &os)
Dump to os.
Definition dump.cpp:566
void start() final
Actual entry.
Definition cps2ds.cpp:16
The direct style Plugin
Definition direct.h:8
Vector< const Def * > DefVec
Definition def.h:79
@ Lam
Definition def.h:109
@ App
Definition def.h:109