MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower.cpp
Go to the documentation of this file.
2
3#include "mim/def.h"
4#include "mim/lam.h"
5
6#include "mim/util/types.h"
7
9
10#include "absl/container/btree_set.h"
11#include "absl/container/flat_hash_set.h"
12
14
15const Def* Lower::lower_via_impl(const App* app, const Def* impl_annex) {
16 auto& w = new_world();
17
18 // Walk the curry chain (innermost App outermost in syntax) to collect the args
19 // in the order they were applied.
20 DefVec args;
21 const Def* head = app;
22 while (auto h = head->isa<App>()) {
23 args.push_back(rewrite(h->arg()));
24 head = h->callee();
25 }
26 std::reverse(args.begin(), args.end());
27
28 auto impl = impl_annex;
29 for (auto a : args)
30 impl = w.app(impl, a);
31
32 // The `_impl` is a `lam`, so applying it triggers beta-reduction. Each `_impl`
33 // body references the `_impl` variants of its dependencies directly, so the
34 // chain bottoms out at the low-level axioms (`map_reduce`, …) in one go.
35 return impl;
36}
37
38const Def* Lower::lower_broadcast_in_dim(const App* app) {
39 auto& w = new_world();
40 auto c = rewrite(app->callee());
41 auto arg = rewrite(app->arg());
42
43 auto [s_in, s_out, input, index] = arg->projs<4>();
44 auto callee = c->as<App>();
45 auto [T, r_in, r_out] = callee->args<3>();
46 DLOG("lower_broadcast_in_dim");
47 DLOG(" s_out = {} : {}", s_out, s_out->type());
48 DLOG(" input = {} : {}", input, input->type());
49 DLOG(" index = {} : {}", index, index->type());
50 DLOG(" T = {} : {}", T, T->type());
51 DLOG(" r_in = {} : {}", r_in, r_in->type());
52 DLOG(" r_out = {} : {}", r_out, r_out->type());
53 DLOG(" s_in = {} : {}", s_in, s_in->type());
54
55 auto r_in_lit = Lit::isa<u64>(r_in);
56 if (!r_in_lit) return nullptr;
57 auto r_out_lit = Lit::isa<u64>(r_out);
58 if (!r_out_lit) return nullptr;
59 auto r_out_nat = *r_out_lit;
60 auto r_in_nat = *r_in_lit;
61
62 auto s_tr_vec = DefVec(r_out_nat, [&](size_t i) {
63 if (i < r_in_nat) return s_in->proj(r_in_nat, i);
64 return w.lit_nat_1()->as<Def>();
65 });
66 auto s_tr = w.tuple(s_tr_vec);
67
68 absl::btree_set<u64> set_perm;
69 absl::flat_hash_map<u64, u64> map_perm;
70 for (u64 i = 0; i < r_out_nat; ++i)
71 set_perm.insert(i);
72 for (u64 i = 0; i < r_in_nat; ++i) {
73 auto idx = index->proj(r_in_nat, i);
74 auto idx_lit = Lit::isa(idx);
75 if (!idx_lit) return nullptr;
76 u64 idx_nat = *idx_lit;
77
78 map_perm[idx_nat] = i;
79
80 set_perm.erase(idx_nat);
81 }
82 for (u64 j = r_in_nat; auto i : set_perm) {
83 map_perm[i] = j;
84 j++;
85 }
86 auto permutation_vec = DefVec(r_out_nat, [&](size_t i) { return w.lit_idx(r_out_nat, map_perm[i]); });
87 auto permutation = w.tuple(permutation_vec);
88
89 // Apply `transpose_impl` directly so the lam expands to `%tensor.map_reduce`
90 // immediately — no further high-level lowering is needed for the transpose.
91 auto tr = w.annex<tensor::transpose_impl>();
92 tr = w.app(tr, {T, r_out, s_tr});
93 tr = w.app(tr, {input, permutation});
94
95 auto s_bc_vec = DefVec(r_out_nat, [&](size_t i) { return s_tr->proj(r_out_nat, map_perm.at(i)); });
96 auto s_bc = w.tuple(s_bc_vec);
97
98 auto bc = w.annex<tensor::broadcast>();
99 bc = w.app(bc, {T, r_out});
100 bc = w.app(bc, {s_bc, s_out, tr});
101
102 // The resulting `%tensor.broadcast` is low-level and is handled by the
103 // `LowerMapReduce` phase.
104 return bc;
105}
106
107const Def* Lower::rewrite_imm_App(const App* app) {
108 auto& w = new_world();
109
110 if (auto bid = Axm::isa<tensor::broadcast_in_dim>(app)) {
111 if (auto res = lower_broadcast_in_dim(bid)) return res;
112 } else if (Axm::isa<tensor::product_2d>(app)) {
113 return lower_via_impl(app, w.annex<tensor::product_2d_impl>());
114 } else if (Axm::isa<tensor::dot_product>(app)) {
115 return lower_via_impl(app, w.annex<tensor::dot_product_impl>());
116 } else if (Axm::isa<tensor::transpose>(app)) {
117 return lower_via_impl(app, w.annex<tensor::transpose_impl>());
118 } else if (Axm::isa<tensor::transpose_2d>(app)) {
119 return lower_via_impl(app, w.annex<tensor::transpose_2d_impl>());
120 } else if (Axm::isa<tensor::map>(app)) {
121 return lower_via_impl(app, w.annex<tensor::map_impl>());
122 } else if (Axm::isa<tensor::map_reduce_ds>(app)) {
123 return lower_via_impl(app, w.annex<tensor::map_reduce_ds_impl>());
124 } else if (Axm::isa<tensor::unary>(app)) {
125 return lower_via_impl(app, w.annex<tensor::unary_impl>());
126 } else if (Axm::isa<tensor::binary>(app)) {
127 return lower_via_impl(app, w.annex<tensor::binary_impl>());
128 } else if (Axm::isa<tensor::select>(app)) {
129 return lower_via_impl(app, w.annex<tensor::select_impl>());
130 }
131 return RWPhase::rewrite_imm_App(app);
132}
133
134} // namespace mim::plug::tensor::phase
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
static std::optional< T > isa(const Def *def)
Definition def.h:838
World & new_world()
Create new Defs into this.
Definition phase.h:243
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
const Def * rewrite_imm_App(const App *) final
Definition lower.cpp:107
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:94
Vector< const Def * > DefVec
Definition def.h:79
uint64_t u64
Definition types.h:27
@ App
Definition def.h:109