MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.h
Go to the documentation of this file.
1#pragma once
2
3#include <memory>
4
5#include "mim/check.h"
6#include "mim/def.h"
7#include "mim/lam.h"
8#include "mim/lattice.h"
9#include "mim/rule.h"
10#include "mim/tuple.h"
11
12namespace mim {
13
14class World;
15
16/// Recursively rebuilds part of a program **into** the provided World w.r.t.\ Rewriter::map.
17/// This World may be different from the World we started with.
/// Mappings live in a stack of Def2Def maps (old2news_): push() opens a new innermost map,
/// map() writes into the innermost one, and lookup() searches from the innermost to the outermost.
18/// @see @ref rewriter
19class Rewriter {
20public:
21 /// @name Construction & Destruction
22 ///@{
 // Defined out of line. Presumably takes ownership of ptr (kept in ptr_) and rewrites into *ptr -- TODO confirm.
23 Rewriter(std::unique_ptr<World>&& ptr);
25 virtual ~Rewriter();
26
 // Both reset() overloads are defined out of line; presumably they re-initialize the map stack
 // (the first one also switching to the World owned by ptr) -- TODO confirm.
27 void reset(std::unique_ptr<World>&& ptr);
28 void reset();
29 ///@}
30
31 /// @name Getters
32 ///@{
 /// @returns the World that new Def%s are built into (`*world_`).
33 World& world() { return *world_; }
34 ///@}
35
36 /// @name Push / Pop
 /// push() opens a new, empty innermost map; pop() discards the innermost map together with all mappings made in it.
 /// @pre pop() requires a non-empty map stack (`std::deque::pop_back` on an empty deque is undefined behavior).
37 ///@{
38 virtual void push() { old2news_.emplace_back(Def2Def{}); }
39 virtual void pop() { old2news_.pop_back(); }
40 ///@}
41
42 /// @name Map / Lookup
43 /// Maps @p old_def to @p new_def in the innermost map and returns @p new_def.
 /// @pre The map stack must be non-empty, since `old2news_.back()` is used unconditionally.
44 ///@{
45 virtual const Def* map(const Def* old_def, const Def* new_def) { return old2news_.back()[old_def] = new_def; }
46
 // Overloads relating a single Def and/or sequences of Defs; all three are defined out of line.
47 // clang-format off
48 const Def* map(const Def* old_def , Defs new_defs);
49 const Def* map(Defs old_defs, const Def* new_def );
50 const Def* map(Defs old_defs, Defs new_defs);
51 // clang-format on
52
53 /// Looks up `old_def` by searching in reverse through the stack of maps.
 /// The first hit wins, so inner (more recently pushed) maps shadow outer ones.
54 /// @returns `nullptr` if nothing was found.
55 virtual const Def* lookup(const Def* old_def) {
56 for (const auto& old2new : old2news_ | std::views::reverse)
57 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
58 return nullptr;
59 }
60 ///@}
61
62 /// @name rewrite
63 /// Recursively rewrite old Def%s.
64 ///@{
65 virtual const Def* rewrite(const Def*);
66 virtual const Def* rewrite_imm(const Def*);
67 virtual const Def* rewrite_mut(Def*);
68 virtual const Def* rewrite_stub(Def*, Def*);
69 virtual DefVec rewrite(Defs);
70
 // CODE_IMM / CODE_MUT declare one rewrite_imm_<N>(const N*) / rewrite_mut_<N>(N*) hook per node kind N;
 // presumably they are applied to MIM_IMM_NODE / MIM_MUT_NODE before being undefined again.
71#define CODE_IMM(N) virtual const Def* rewrite_imm_##N(const N*);
72#define CODE_MUT(N) virtual const Def* rewrite_mut_##N(N*);
75#undef CODE_IMM
76#undef CODE_MUT
77
 // Hooks for Seq, declared explicitly.
78 virtual const Def* rewrite_imm_Seq(const Seq* seq);
79 virtual const Def* rewrite_mut_Seq(Seq* seq);
80 ///@}
81
 /// Swaps only the map stacks; each Rewriter keeps the World it builds into.
82 friend void swap(Rewriter& rw1, Rewriter& rw2) noexcept {
83 using std::swap;
84 swap(rw1.old2news_, rw2.old2news_);
85 // Do NOT swap ptr_ and world_: they are back pointers!
86 }
87
88private:
 // Presumably owns the target World when constructed from a unique_ptr -- TODO confirm.
89 std::unique_ptr<World> ptr_;
 // World that new Defs are built into; dereferenced by world().
90 World* world_;
91
92protected:
 /// Stack of old-to-new maps; the innermost scope is at the back.
 // NOTE(review): std::deque (here) and std::views::reverse (in lookup()) are used, but <deque> and <ranges>
 // are not included directly by this header -- presumably they arrive transitively; consider including them.
93 std::deque<Def2Def> old2news_;
94};
95
96/// Extends Rewriter for variable substitution.
/// Each Var registered via add() is mapped to its argument, and the registered Vars are tracked
/// in vars_, which has_intersection() consults.
97/// @see @ref rewriter
98class VarRewriter : public Rewriter {
99public:
100 /// @name Construction
101 ///@{
 /// Rewrites into the World of @p arg and registers the initial substitution @p var -> @p arg via add().
104 VarRewriter(const Var* var, const Def* arg)
105 : Rewriter(arg->world()) {
106 add(var, arg);
107 }
108
109 /// Adds the mapping @p var -> @p arg to the innermost map and records @p var in vars_.
 /// @returns `*this`, so calls can be chained.
 // NOTE(review): this appends a new vars_ entry rather than inserting into vars_.back(), so vars_ need not
 // stay in lockstep with the map stack across push()/add()/pop() -- confirm this is intended.
110 VarRewriter& add(const Var* var, const Def* arg) {
111 map(var, arg);
112 vars_.emplace_back(var);
113 return *this;
114 }
115 ///@}
116
117 /// @name push / pop
 /// push() opens a new map scope plus a new, empty vars_ entry; pop() removes the last vars_ entry and then the innermost map.
118 ///@{
119 void push() final { Rewriter::push(), vars_.emplace_back(Vars()); }
120 void pop() final { vars_.pop_back(), Rewriter::pop(); }
121 ///@}
122
123 /// @name rewrite
 // Both are defined out of line; presumably they only rebuild Defs whose free variables intersect a
 // substituted Var (see has_intersection()) -- TODO confirm.
124 ///@{
125 const Def* rewrite(const Def*) final;
126 const Def* rewrite_mut(Def*) final;
127 ///@}
128
 /// Swaps the base-class map stacks and the tracked vars_; the target Worlds are not swapped.
129 friend void swap(VarRewriter& vrw1, VarRewriter& vrw2) noexcept {
130 using std::swap;
131 swap(static_cast<Rewriter&>(vrw1), static_cast<Rewriter&>(vrw2));
132 swap(vrw1.vars_, vrw2.vars_);
133 }
134
135private:
 /// @returns `true` iff some entry of vars_ intersects the free variables of @p old_def
 /// (free_vars() also follows mutables transitively).
136 bool has_intersection(const Def* old_def) {
137 for (const auto& vars : vars_ | std::views::reverse)
138 if (vars.has_intersection(old_def->free_vars())) return true;
139 return false;
140 }
141
 // One entry per add() call and per push() scope.
142 Vector<Vars> vars_;
143};
144
// NOTE(review): the name suggests "zonking", i.e. replacing resolved holes / inference variables by their
// solutions -- confirm against the out-of-line definitions.
/// Rewriter that keeps mutables as they are (rewrite_mut() maps a mutable to itself) and overrides
/// map(), lookup() and rewrite().
145class Zonker : public Rewriter {
146public:
147 /// @name C'tor
148 ///@{
151 ///@}
152
153 /// @name Stack of Maps
 // Both overrides are defined out of line.
154 ///@{
155 const Def* map(const Def* old_def, const Def* new_def) final;
156 const Def* lookup(const Def* old_def) final;
157 ///@}
158
159 /// @name rewrite
160 ///@{
161 const Def* rewrite(const Def*) final;
 /// Mutables are not rebuilt: @p mut is mapped to itself via map(), and map()'s result is returned.
162 const Def* rewrite_mut(Def* mut) final { return map(mut, mut); }
 /// Non-virtual helper; defined out of line.
 // Presumably rebuilds the operands of a mutable in place -- TODO confirm.
163 const Def* rewire_mut(Def*);
164 ///@}
165
 /// Swaps only the Rewriter base state (the map stacks).
166 friend void swap(Zonker& z1, Zonker& z2) noexcept {
167 using std::swap;
168 swap(static_cast<Rewriter&>(z1), static_cast<Rewriter&>(z2));
169 }
170
171private:
 /// Looks up @p old_def in the innermost map only (unlike Rewriter::lookup(), which searches the whole stack).
 /// @returns `nullptr` if @p old_def is not mapped there.
 /// @pre The map stack must be non-empty, since `old2news_.back()` is used unconditionally.
172 const Def* get(const Def* old_def) {
173 auto& old2new = old2news_.back();
174 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
175 return nullptr;
176 }
177};
178
179} // namespace mim
Base class for all Defs.
Definition def.h:246
Vars free_vars() const
Compute a global solution by transitively following mutables as well.
Definition def.cpp:337
friend void swap(Rewriter &rw1, Rewriter &rw2) noexcept
Definition rewrite.h:82
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:230
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:236
World & world()
Definition rewrite.h:33
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:77
virtual void push()
Definition rewrite.h:38
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:261
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:45
virtual void pop()
Definition rewrite.h:39
virtual ~Rewriter()
void reset()
Definition rewrite.cpp:36
std::deque< Def2Def > old2news_
Definition rewrite.h:93
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:68
Rewriter(std::unique_ptr< World > &&ptr)
Definition rewrite.cpp:17
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
virtual const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:55
VarRewriter & add(const Var *var, const Def *arg)
Definition rewrite.h:110
void pop() final
Definition rewrite.h:120
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:288
void push() final
Definition rewrite.h:119
VarRewriter(World &world)
Definition rewrite.h:102
friend void swap(VarRewriter &vrw1, VarRewriter &vrw2) noexcept
Definition rewrite.h:129
const Def * rewrite(const Def *) final
Definition rewrite.cpp:277
VarRewriter(const Var *var, const Def *arg)
Definition rewrite.h:104
A variable introduced by a binder (mutable).
Definition def.h:714
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:36
const Def * rewire_mut(Def *)
Definition rewrite.cpp:343
friend void swap(Zonker &z1, Zonker &z2) noexcept
Definition rewrite.h:166
const Def * lookup(const Def *old_def) final
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:307
const Def * rewrite(const Def *) final
Definition rewrite.cpp:334
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:301
Zonker(World &world)
Definition rewrite.h:149
const Def * rewrite_mut(Def *mut) final
Definition rewrite.h:162
#define MIM_MUT_NODE(X)
Definition def.h:54
#define MIM_IMM_NODE(X)
Definition def.h:38
Definition ast.h:14
View< const Def * > Defs
Definition def.h:78
DefMap< const Def * > Def2Def
Definition def.h:77
Vector< const Def * > DefVec
Definition def.h:79
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
Sets< const Var >::Set Vars
Definition def.h:99
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >
#define CODE_MUT(N)
Definition rewrite.h:72
#define CODE_IMM(N)
Definition rewrite.h:71