MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.h
Go to the documentation of this file.
1#pragma once
2
#include <cassert>
#include <deque>
#include <memory>
#include <ranges>

#include "mim/check.h"
#include "mim/def.h"
#include "mim/lam.h"
#include "mim/lattice.h"
#include "mim/rule.h"
#include "mim/tuple.h"
11
12namespace mim {
13
14class World;
15
16/// Recurseivly rebuilds part of a program **into** the provided World w.r.t.\ Rewriter::map.
17/// This World may be different than the World we started with.
18/// @see @ref rewriter
19class Rewriter {
20public:
21 /// @name Construction & Destruction
22 ///@{
23 Rewriter(std::unique_ptr<World>&& ptr)
24 : ptr_(std::move(ptr))
25 , world_(ptr_.get()) {
26 push(); // create root map
27 }
29 : world_(&world) {
30 push(); // create root map
31 }
32 virtual ~Rewriter() = default;
33
34 void reset(std::unique_ptr<World>&& ptr) {
35 ptr_ = std::move(ptr);
36 world_ = ptr_.get();
37 reset();
38 }
39 void reset() {
40 pop();
41 assert(old2news_.empty());
42 push();
43 }
44 ///@}
45
46 /// @name Getters
47 ///@{
48 World& world() { return *world_; }
49 ///@}
50
51 /// @name Push / Pop
52 ///@{
53 virtual void push() { old2news_.emplace_back(Def2Def{}); }
54 virtual void pop() { old2news_.pop_back(); }
55 ///@}
56
57 /// @name Map / Lookup
58 /// Map @p old_def to @p new_def and returns @p new_def.
59 ///@{
60 virtual const Def* map(const Def* old_def, const Def* new_def) { return old2news_.back()[old_def] = new_def; }
61
62 // clang-format off
63 const Def* map(const Def* old_def , Defs new_defs);
64 const Def* map(Defs old_defs, const Def* new_def );
65 const Def* map(Defs old_defs, Defs new_defs);
66 // clang-format on
67
68 /// Lookup `old_def` by searching in reverse through the stack of maps.
69 /// @returns `nullptr` if nothing was found.
70 virtual const Def* lookup(const Def* old_def) {
71 for (const auto& old2new : old2news_ | std::views::reverse)
72 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
73 return nullptr;
74 }
75 ///@}
76
77 /// @name rewrite
78 /// Recursively rewrite old Def%s.
79 ///@{
80 virtual const Def* rewrite(const Def*);
81 virtual const Def* rewrite_imm(const Def*);
82 virtual const Def* rewrite_mut(Def*);
83 virtual const Def* rewrite_stub(Def*, Def*);
84 virtual DefVec rewrite(Defs);
85
86#define CODE_IMM(N) virtual const Def* rewrite_imm_##N(const N*);
87#define CODE_MUT(N) virtual const Def* rewrite_mut_##N(N*);
90#undef CODE_IMM
91#undef CODE_MUT
92
93 virtual const Def* rewrite_imm_Seq(const Seq* seq);
94 virtual const Def* rewrite_mut_Seq(Seq* seq);
95 ///@}
96
97 friend void swap(Rewriter& rw1, Rewriter& rw2) noexcept {
98 using std::swap;
99 swap(rw1.old2news_, rw2.old2news_);
100 // Do NOT swap ptr_ and world_: they are back pointers!
101 }
102
103private:
104 std::unique_ptr<World> ptr_;
105 World* world_;
106
107protected:
108 std::deque<Def2Def> old2news_;
109};
110
111/// Extends Rewriter for variable substitution.
112/// @see @ref rewriter
113class VarRewriter : public Rewriter {
114public:
115 /// @name Construction
116 ///@{
119 VarRewriter(const Var* var, const Def* arg)
120 : Rewriter(arg->world()) {
121 add(var, arg);
122 }
123
124 // Add initial mapping from @pvar -> @p arg.
125 VarRewriter& add(const Var* var, const Def* arg) {
126 map(var, arg);
127 vars_.emplace_back(var);
128 return *this;
129 }
130 ///@}
131
132 /// @name push / pop
133 ///@{
134 void push() final { Rewriter::push(), vars_.emplace_back(Vars()); }
135 void pop() final { vars_.pop_back(), Rewriter::pop(); }
136 ///@}
137
138 /// @name rewrite
139 ///@{
140 const Def* rewrite(const Def*) final;
141 const Def* rewrite_mut(Def*) final;
142 ///@}
143
144 friend void swap(VarRewriter& vrw1, VarRewriter& vrw2) noexcept {
145 using std::swap;
146 swap(static_cast<Rewriter&>(vrw1), static_cast<Rewriter&>(vrw2));
147 swap(vrw1.vars_, vrw2.vars_);
148 }
149
150private:
151 bool has_intersection(const Def* old_def) {
152 for (const auto& vars : vars_ | std::views::reverse)
153 if (vars.has_intersection(old_def->free_vars())) return true;
154 return false;
155 }
156
157 Vector<Vars> vars_;
158};
159
160class Zonker : public Rewriter {
161public:
162 /// @name C'tor
163 ///@{
166 ///@}
167
168 /// @name Stack of Maps
169 ///@{
170 const Def* map(const Def* old_def, const Def* new_def) final;
171 const Def* lookup(const Def* old_def) final;
172 ///@}
173
174 /// @name rewrite
175 ///@{
176 const Def* rewrite(const Def*) final;
177 const Def* rewrite_mut(Def* mut) final { return map(mut, mut); }
178 const Def* rewire_mut(Def*);
179 ///@}
180
181 friend void swap(Zonker& z1, Zonker& z2) noexcept {
182 using std::swap;
183 swap(static_cast<Rewriter&>(z1), static_cast<Rewriter&>(z2));
184 }
185
186private:
187 const Def* get(const Def* old_def) {
188 auto& old2new = old2news_.back();
189 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
190 return nullptr;
191 }
192};
193
194} // namespace mim
Base class for all Defs.
Definition def.h:252
Vars free_vars() const
Compute a global solution by transitively following mutables as well.
Definition def.cpp:337
virtual ~Rewriter()=default
friend void swap(Rewriter &rw1, Rewriter &rw2) noexcept
Definition rewrite.h:97
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:135
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:141
World & world()
Definition rewrite.h:48
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:48
virtual void push()
Definition rewrite.h:53
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:166
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:60
virtual void pop()
Definition rewrite.h:54
void reset(std::unique_ptr< World > &&ptr)
Definition rewrite.h:34
Rewriter(World &world)
Definition rewrite.h:28
void reset()
Definition rewrite.h:39
std::deque< Def2Def > old2news_
Definition rewrite.h:108
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:39
Rewriter(std::unique_ptr< World > &&ptr)
Definition rewrite.h:23
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:27
virtual const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:70
VarRewriter & add(const Var *var, const Def *arg)
Definition rewrite.h:125
void pop() final
Definition rewrite.h:135
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:193
void push() final
Definition rewrite.h:134
VarRewriter(World &world)
Definition rewrite.h:117
friend void swap(VarRewriter &vrw1, VarRewriter &vrw2) noexcept
Definition rewrite.h:144
const Def * rewrite(const Def *) final
Definition rewrite.cpp:182
VarRewriter(const Var *var, const Def *arg)
Definition rewrite.h:119
A variable introduced by a binder (mutable).
Definition def.h:719
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:34
const Def * rewire_mut(Def *)
Definition rewrite.cpp:248
friend void swap(Zonker &z1, Zonker &z2) noexcept
Definition rewrite.h:181
const Def * lookup(const Def *old_def) final
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:212
const Def * rewrite(const Def *) final
Definition rewrite.cpp:239
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:206
Zonker(World &world)
Definition rewrite.h:164
const Def * rewrite_mut(Def *mut) final
Definition rewrite.h:177
#define MIM_MUT_NODE(X)
Definition def.h:53
#define MIM_IMM_NODE(X)
Definition def.h:37
Definition ast.h:14
View< const Def * > Defs
Definition def.h:77
DefMap< const Def * > Def2Def
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:78
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
Sets< const Var >::Set Vars
Definition def.h:98
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >
Definition span.h:122
#define CODE_MUT(N)
Definition rewrite.h:87
#define CODE_IMM(N)
Definition rewrite.h:86