MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
fuse.cpp
Go to the documentation of this file.
2
3#include "mim/def.h"
4#include "mim/lam.h"
5
6#include "mim/util/types.h"
7
11
13
14// Fuses an outer `tensor.map_reduce` with any number of its inputs — and, recursively, any
15// fusible inputs of those inputs — whenever each such input is itself a `tensor.map_reduce`
16// whose subs only reference output indices (i.e. the inner reduction is empty, so reading the
17// inner tensor at a position is just a single call to the inner combination function).
18//
19// Outer: map_reduce nis_o (To, Ro) So (Tis_o, Ris_o, Sis_o) (f_o, init_o) subs_o is_o
20// Inner: map_reduce nis_k (To_k, Ro_k) So_k (Tis_k, Ris_k, Sis_k) (f_k, init_k) subs_k is_k
21// for every fusible input — possibly nested inside another fusible input
22//
23// Result: map_reduce nis_new (To, Ro) So (Tis_new, Ris_new, Sis_new) (f_new, init_o) subs_new
24// is_new
25//
26// The collection phase walks the tree of fusible inner map_reduces below `app` once, producing
27// a flat list of *leaves* (the surviving tensor inputs of the fused mr) and *inner nodes* (the
28// inner combiners that must run before `f_o`). Each fusible input is replaced by its inner's
29// inputs, with subs remapped through the outer subs at that position; the remapping composes
30// across nested levels. The new combination function `f_new` invokes every inner combiner in
31// post-order — each starting from its own init — and finally invokes `f_o`, threading inner
32// results into the corresponding outer input slots.
33const Def* Fuse::fuse_map_reduce(const App* app) {
34 auto outer_callee = rewrite(app->callee())->as<App>();
35
36 auto [nis, ToRo, So, TisRisSis, comb_init, subs] = outer_callee->uncurry_args<6>();
37
38 auto [comb, init] = comb_init->projs<2>();
39 auto [Tis, Ris, Sis] = TisRisSis->projs<3>();
40 auto [To, Ro] = ToRo->projs<2>();
41 auto is = rewrite(app->arg());
42
43 DLOG("considering map_reduce for fusion:");
44 DLOG(" subs = {} : {}", subs, subs->type());
45 DLOG(" comb = {} : {}", comb, comb->type());
46 DLOG(" init = {} : {}", init, init->type());
47 DLOG(" Tis = {} : {}", Tis, Tis->type());
48 DLOG(" Ris = {} : {}", Ris, Ris->type());
49 DLOG(" Sis = {} : {}", Sis, Sis->type());
50 DLOG(" To = {} : {}", To, To->type());
51 DLOG(" Ro = {} : {}", Ro, Ro->type());
52 DLOG(" nis = {} : {}", nis, nis->type());
53 DLOG(" is = {} : {}", is, is->type());
54
55 auto nis_lit = Lit::isa<u64>(nis);
56 if (!nis_lit) return nullptr;
57 auto nis_nat = *nis_lit;
58
59 struct InnerInfo {
60 bool fusible = false;
61 const Def* subs = nullptr;
62 const Def* comb = nullptr;
63 const Def* init = nullptr;
64 const Def* Tis = nullptr;
65 const Def* Ris = nullptr;
66 const Def* Sis = nullptr;
67 const Def* To = nullptr;
68 u64 nis = 0;
69 const Def* is = nullptr;
70 Vector<u64> Ris_nats;
71 u64 outer_Ris_nat = 0;
72 const Def* outer_subs = nullptr;
73 };
74
75 Vector<InnerInfo> infos(nis_nat);
76 bool any_fusible = false;
77
78 for (u64 k = 0; k < nis_nat; ++k) {
79 auto input_k = is->proj(nis_nat, k);
80 auto inner = Axm::isa<tensor::map_reduce>(input_k);
81 if (!inner) continue;
82
83 auto [inner_nis, inner_ToRo, inner_So, inner_TisRisSis, inner_comb_init, inner_subs, inner_is] = inner->uncurry_args<7>();
84 auto [inner_comb, inner_init] = inner_comb_init->projs<2>();
85 auto [inner_Tis, inner_Ris, inner_Sis] = inner_TisRisSis->projs<3>();
86 auto [inner_To, inner_Ro] = inner_ToRo->projs<2>();
87
88 auto inner_nis_nat = Lit::isa<u64>(inner_nis);
89 auto inner_Ro_nat = Lit::isa<u64>(inner_Ro);
90 if (!inner_nis_nat || !inner_Ro_nat) continue;
91
92 // We can only fuse when the inner has no reduction dimensions, i.e. all subs are < Ro_i.
93 // In that case the inner tensor at any position is just a single call of `inner_comb`.
94 Vector<u64> inner_Ris_nats(*inner_nis_nat);
95 bool fusible = true;
96 for (u64 l = 0; l < *inner_nis_nat && fusible; ++l) {
97 auto Ris_l_lit = Lit::isa<u64>(inner_Ris->proj(*inner_nis_nat, l));
98 if (!Ris_l_lit) {
99 fusible = false;
100 break;
101 }
102 inner_Ris_nats[l] = *Ris_l_lit;
103 auto inner_subs_l = inner_subs->proj(*inner_nis_nat, l);
104 for (u64 j = 0; j < inner_Ris_nats[l]; ++j) {
105 auto idx_lit = Lit::isa<u64>(inner_subs_l->proj(inner_Ris_nats[l], j));
106 if (!idx_lit || *idx_lit >= *inner_Ro_nat) {
107 fusible = false;
108 break;
109 }
110 }
111 }
112 if (!fusible) continue;
113
114 // We need the outer subs for input `k` to be indexable by literal positions.
115 auto Ris_k_lit = Lit::isa<u64>(Ris->proj(nis_nat, k));
116 if (!Ris_k_lit) continue;
117
118 auto& info = infos[k];
119 info.fusible = true;
120 info.subs = inner_subs;
121 info.comb = inner_comb;
122 info.init = inner_init;
123 info.Tis = inner_Tis;
124 info.Ris = inner_Ris;
125 info.Sis = inner_Sis;
126 info.To = inner_To;
127 info.nis = *inner_nis_nat;
128 info.is = inner_is;
129 info.Ris_nats = std::move(inner_Ris_nats);
130 info.outer_Ris_nat = *Ris_k_lit;
131 info.outer_subs = subs->proj(nis_nat, k);
132 any_fusible = true;
133 }
134
135 if (!any_fusible) return nullptr;
136
137 auto& w = new_world();
138
139 // Each fusible outer input k is replaced by `infos[k].nis` slots in the fused input list;
140 // every non-fusible input retains exactly one slot. `new_pos[i]` is the start of input i's
141 // slot range in the fused list.
142 Vector<u64> new_pos(nis_nat);
143 u64 new_nis_nat = 0;
144 for (u64 i = 0; i < nis_nat; ++i) {
145 new_pos[i] = new_nis_nat;
146 new_nis_nat += infos[i].fusible ? infos[i].nis : 1;
147 }
148
149 DefVec new_Tis_vec(new_nis_nat);
150 DefVec new_Ris_vec(new_nis_nat);
151 DefVec new_Sis_vec(new_nis_nat);
152 DefVec new_subs_vec(new_nis_nat);
153 DefVec new_is_vec(new_nis_nat);
154
155 for (u64 i = 0; i < nis_nat; ++i) {
156 if (infos[i].fusible) {
157 const auto& info = infos[i];
158 for (u64 l = 0; l < info.nis; ++l) {
159 auto pos = new_pos[i] + l;
160 new_Tis_vec[pos] = info.Tis->proj(info.nis, l);
161 new_Ris_vec[pos] = info.Ris->proj(info.nis, l);
162 new_Sis_vec[pos] = info.Sis->proj(info.nis, l);
163 new_is_vec[pos] = info.is->proj(info.nis, l);
164
165 auto inner_subs_l = info.subs->proj(info.nis, l);
166 DefVec subs_l_vec(info.Ris_nats[l]);
167 for (u64 j = 0; j < info.Ris_nats[l]; ++j) {
168 // The inner refers to one of its own output indices; remap that into the
169 // outer index space via the outer subs at position `i`.
170 auto inner_idx = *Lit::isa(inner_subs_l->proj(info.Ris_nats[l], j));
171 subs_l_vec[j] = rewrite(info.outer_subs->proj(info.outer_Ris_nat, inner_idx));
172 }
173 new_subs_vec[pos] = w.tuple(subs_l_vec);
174 }
175 } else {
176 auto pos = new_pos[i];
177 new_Tis_vec[pos] = Tis->proj(nis_nat, i);
178 new_Ris_vec[pos] = Ris->proj(nis_nat, i);
179 new_Sis_vec[pos] = Sis->proj(nis_nat, i);
180 new_subs_vec[pos] = subs->proj(nis_nat, i);
181 new_is_vec[pos] = is->proj(nis_nat, i);
182 }
183 }
184
185 auto new_Tis = w.tuple(new_Tis_vec);
186 auto new_Ris = w.tuple(new_Ris_vec);
187 auto new_Sis = w.tuple(new_Sis_vec);
188 auto new_subs = w.tuple(new_subs_vec);
189 auto new_is = w.tuple(new_is_vec);
190
191 auto new_nis_def = w.lit_nat(new_nis_nat);
192
193 // Build the fused combination function:
194 //
195 // cn f_new(data: [To, [new_Tis ...]], ret: cn To) =
196 // cn inner_ret_<r>(value_<r>: inner_To_<r>) = ...
197 // f_<fused[0]>((init_<fused[0]>, inner_inputs_<fused[0]>), inner_ret_0)
198 //
199 // inner_ret_<r>(value_<r>):
200 // if r is not the last fused input:
201 // f_<fused[r+1]>((init_<fused[r+1]>, inner_inputs_<fused[r+1]>), inner_ret_<r+1>)
202 // else:
203 // f_o((acc, outer_inputs), ret)
204 //
205 // `outer_inputs[i]` is `value_<r>` when input i is the r-th fused input, and the
206 // corresponding `new_in` slot otherwise. Each `inner_ret_<r>` closes over the prior
207 // `value_<j>`s as free variables — those are bound by the dynamic call chain.
208 auto inputs_sigma = w.sigma(new_Tis_vec);
209 auto data_sigma = w.sigma({To, inputs_sigma});
210 auto ret_cn_type = w.cn(To);
211 auto new_comb = w.mut_con({data_sigma, ret_cn_type})->set("fused_comb");
212 auto new_data = new_comb->var(0);
213 auto new_ret = new_comb->var(1);
214 auto new_acc = new_data->proj(2, 0);
215 auto new_in = new_data->proj(2, 1);
216
217 Vector<u64> fused_indices;
218 for (u64 i = 0; i < nis_nat; ++i)
219 if (infos[i].fusible) fused_indices.emplace_back(i);
220
221 Vector<Lam*> inner_rets(fused_indices.size());
222 Vector<const Def*> inner_values(fused_indices.size());
223 for (size_t r = 0; r < fused_indices.size(); ++r) {
224 auto new_inner_To = infos[fused_indices[r]].To;
225 inner_rets[r] = w.mut_con(new_inner_To)->set("inner_ret");
226 inner_values[r] = inner_rets[r]->var(0);
227 }
228
229 // Map each outer input position to its value at the f_o call site.
230 DefVec outer_inputs_vec(nis_nat);
231 {
232 size_t r = 0;
233 for (u64 i = 0; i < nis_nat; ++i)
234 if (infos[i].fusible)
235 outer_inputs_vec[i] = inner_values[r++];
236 else
237 outer_inputs_vec[i] = new_in->proj(new_nis_nat, new_pos[i]);
238 }
239
240 // Chain: caller for fused step r is new_comb (r==0) or inner_rets[r-1] (otherwise).
241 for (size_t r = 0; r < fused_indices.size(); ++r) {
242 auto k = fused_indices[r];
243 auto new_inner_comb = infos[k].comb;
244 auto new_inner_init = infos[k].init;
245
246 DefVec inner_inputs_vec(infos[k].nis);
247 for (u64 l = 0; l < infos[k].nis; ++l)
248 inner_inputs_vec[l] = new_in->proj(new_nis_nat, new_pos[k] + l);
249
250 Lam* caller = (r == 0) ? new_comb : inner_rets[r - 1];
251 caller->app(true, new_inner_comb, {w.tuple({new_inner_init, w.tuple(inner_inputs_vec)}), inner_rets[r]});
252 }
253
254 // After every inner combiner has produced its value, call the outer combiner.
255 inner_rets.back()->app(true, comb, {w.tuple({new_acc, w.tuple(outer_inputs_vec)}), new_ret});
256
257 // Construct the fused map_reduce.
258 auto mr = w.annex<tensor::map_reduce>();
259 mr = w.app(mr, new_nis_def);
260 mr = w.app(mr, {To, Ro});
261 mr = w.app(mr, So);
262 mr = w.app(mr, {new_Tis, new_Ris, new_Sis});
263 mr = w.app(mr, {new_comb, init});
264 mr = w.app(mr, new_subs);
265 mr = w.app(mr, new_is);
266
267 return mr;
268}
269
270const Def* Fuse::rewrite_imm_App(const App* app) {
271 if (auto mr = Axm::isa<tensor::map_reduce>(app)) {
272 if (auto res = fuse_map_reduce(mr)) {
273 DLOG("Fused map_reduce at {} into a new map_reduce {}", app, res);
274 return res;
275 }
276 }
277 return RWPhase::rewrite_imm_App(app);
278}
279
280} // namespace mim::plug::tensor::phase
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:386
static std::optional< T > isa(const Def *def)
Definition def.h:838
World & new_world()
Create new Defs into this.
Definition phase.h:243
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
const Def * rewrite_imm_App(const App *) final
Definition fuse.cpp:270
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:94
Vector< const Def * > DefVec
Definition def.h:79
uint64_t u64
Definition types.h:27
@ Lam
Definition def.h:109
@ App
Definition def.h:109
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >