// Tries to fuse a map_reduce application with the map_reduce applications that
// feed its inputs, producing a single new map_reduce with a combined input list
// and a combined ("fused_comb") combiner.
//
// @param app  the application under consideration; its callee is rewritten and
//             split into the curried argument groups below.
// @return     nullptr when `nis_lit` is empty or when no input was marked
//             fusible (both visible below).
//
// NOTE(review): this listing is incomplete. The leading integers look like the
// original line numbers and they skip values, so several names used below
// (`is`, `nis_lit`, `inner`, `infos`, `fusible`, `w`, `mr`, some loop headers
// and all closing braces) are declared on lines that are not shown here. The
// final return statement is also not visible; confirm against the full file.
33const Def* Fuse::fuse_map_reduce(
const App* app) {
// Rewrite the callee first and view the result as an App so its curried
// arguments can be taken apart.
34 auto outer_callee =
rewrite(app->callee())->as<
App>();
// Six curried argument groups of the outer callee.
36 auto [nis, ToRo, So, TisRisSis, comb_init, subs] = outer_callee->uncurry_args<6>();
// Unpack the grouped arguments into their components.
38 auto [comb,
init] = comb_init->
projs<2>();
39 auto [Tis, Ris, Sis] = TisRisSis->projs<3>();
40 auto [To, Ro] = ToRo->projs<2>();
// Debug dump of every component together with its type.
43 DLOG(
"considering map_reduce for fusion:");
44 DLOG(
" subs = {} : {}", subs, subs->type());
45 DLOG(
" comb = {} : {}", comb, comb->type());
46 DLOG(
" init = {} : {}", init,
init->type());
47 DLOG(
" Tis = {} : {}", Tis, Tis->type());
48 DLOG(
" Ris = {} : {}", Ris, Ris->type());
49 DLOG(
" Sis = {} : {}", Sis, Sis->type());
50 DLOG(
" To = {} : {}", To, To->type());
51 DLOG(
" Ro = {} : {}", Ro, Ro->type());
52 DLOG(
" nis = {} : {}", nis, nis->type());
// NOTE(review): `is` is declared on a line not shown in this listing.
53 DLOG(
" is = {} : {}", is, is->type());
// Give up when `nis_lit` holds no value; otherwise dereference it into
// `nis_nat`, which bounds the loops over the outer inputs below.
// NOTE(review): `nis_lit` is presumably the literal value of `nis`; its
// definition is not shown.
56 if (!nis_lit)
return nullptr;
57 auto nis_nat = *nis_lit;
// NOTE(review): the following default-initialized declarations look like the
// fields of a per-input record (they are accessed later as `info.subs`,
// `info.comb`, `info.outer_subs`, ...); the enclosing declaration is not shown,
// so confirm before relying on this.
61 const Def* subs =
nullptr;
62 const Def* comb =
nullptr;
63 const Def*
init =
nullptr;
64 const Def* Tis =
nullptr;
65 const Def* Ris =
nullptr;
66 const Def* Sis =
nullptr;
67 const Def* To =
nullptr;
69 const Def* is =
nullptr;
71 u64 outer_Ris_nat = 0;
72 const Def* outer_subs =
nullptr;
// Checked after the analysis loop: when it is still false the function
// returns nullptr.
76 bool any_fusible =
false;
// Analysis pass: inspect each of the `nis_nat` outer inputs.
78 for (
u64 k = 0; k < nis_nat; ++k) {
79 auto input_k = is->proj(nis_nat, k);
// Take the candidate `inner` apart into seven curried argument groups and
// unpack them the same way as for the outer callee.
// NOTE(review): `inner` is defined on a line not shown.
83 auto [inner_nis, inner_ToRo, inner_So, inner_TisRisSis, inner_comb_init, inner_subs, inner_is] = inner->uncurry_args<7>();
84 auto [inner_comb, inner_init] = inner_comb_init->projs<2>();
85 auto [inner_Tis, inner_Ris, inner_Sis] = inner_TisRisSis->projs<3>();
86 auto [inner_To, inner_Ro] = inner_ToRo->projs<2>();
// Skip this input when either value is unavailable.
90 if (!inner_nis_nat || !inner_Ro_nat)
continue;
// Walk the inner inputs; the loop stops early once `fusible` becomes false.
96 for (
u64 l = 0;
l < *inner_nis_nat && fusible; ++
l) {
97 auto Ris_l_lit =
Lit::isa<u64>(inner_Ris->proj(*inner_nis_nat, l));
// Remember the literal taken from the l-th entry of `inner_Ris`.
102 inner_Ris_nats[
l] = *Ris_l_lit;
103 auto inner_subs_l = inner_subs->proj(*inner_nis_nat, l);
// Examine each of the `inner_Ris_nats[l]` entries of the l-th inner
// subscript tuple.
104 for (
u64 j = 0; j < inner_Ris_nats[
l]; ++j) {
105 auto idx_lit =
Lit::isa<u64>(inner_subs_l->proj(inner_Ris_nats[l], j));
// Entry is either not a u64 literal or not below `*inner_Ro_nat`.
// NOTE(review): the body of this branch is not shown.
106 if (!idx_lit || *idx_lit >= *inner_Ro_nat) {
112 if (!fusible)
continue;
116 if (!Ris_k_lit)
continue;
// Record the inner components for input k so the rebuild passes below can
// use them.
118 auto&
info = infos[k];
120 info.subs = inner_subs;
121 info.comb = inner_comb;
122 info.init = inner_init;
123 info.Tis = inner_Tis;
124 info.Ris = inner_Ris;
125 info.Sis = inner_Sis;
127 info.nis = *inner_nis_nat;
129 info.Ris_nats = std::move(inner_Ris_nats);
130 info.outer_Ris_nat = *Ris_k_lit;
131 info.outer_subs = subs->proj(nis_nat, k);
// Nothing was marked fusible: report failure to the caller.
135 if (!any_fusible)
return nullptr;
// Layout pass: `new_pos[i]` is the first slot of outer input i in the new
// input list. A fusible input occupies `infos[i].nis` slots, any other input
// occupies one.
144 for (
u64 i = 0; i < nis_nat; ++i) {
145 new_pos[i] = new_nis_nat;
146 new_nis_nat += infos[i].fusible ? infos[i].nis : 1;
// One entry per slot of the new input list.
149 DefVec new_Tis_vec(new_nis_nat);
150 DefVec new_Ris_vec(new_nis_nat);
151 DefVec new_Sis_vec(new_nis_nat);
152 DefVec new_subs_vec(new_nis_nat);
153 DefVec new_is_vec(new_nis_nat);
// Fill pass over the outer inputs.
155 for (
u64 i = 0; i < nis_nat; ++i) {
156 if (infos[i].fusible) {
157 const auto&
info = infos[i];
// Fusible input: copy the l-th inner component into slot `new_pos[i] + l`.
// NOTE(review): the loop header that declares `l` is not shown.
159 auto pos = new_pos[i] +
l;
160 new_Tis_vec[pos] =
info.Tis->proj(
info.nis, l);
161 new_Ris_vec[pos] =
info.Ris->proj(
info.nis, l);
162 new_Sis_vec[pos] =
info.Sis->proj(
info.nis, l);
163 new_is_vec[pos] =
info.is->proj(
info.nis, l);
165 auto inner_subs_l =
info.subs->proj(
info.nis, l);
// Compose subscripts: the literal found at position j of the inner
// subscript selects an entry of the outer subscript, which is rewritten and
// stored as the new entry j.
167 for (
u64 j = 0; j <
info.Ris_nats[
l]; ++j) {
170 auto inner_idx = *
Lit::isa(inner_subs_l->proj(
info.Ris_nats[l], j));
171 subs_l_vec[j] =
rewrite(
info.outer_subs->proj(
info.outer_Ris_nat, inner_idx));
173 new_subs_vec[pos] =
w.tuple(subs_l_vec);
// Remaining case: the outer input's own entries are copied through unchanged
// into its single slot.
// NOTE(review): the `else` introducing this part is not shown.
176 auto pos = new_pos[i];
177 new_Tis_vec[pos] = Tis->proj(nis_nat, i);
178 new_Ris_vec[pos] = Ris->proj(nis_nat, i);
179 new_Sis_vec[pos] = Sis->proj(nis_nat, i);
180 new_subs_vec[pos] = subs->proj(nis_nat, i);
181 new_is_vec[pos] = is->proj(nis_nat, i);
// Pack the per-slot vectors into tuples and the new input count into a
// literal.
185 auto new_Tis =
w.tuple(new_Tis_vec);
186 auto new_Ris =
w.tuple(new_Ris_vec);
187 auto new_Sis =
w.tuple(new_Sis_vec);
188 auto new_subs =
w.tuple(new_subs_vec);
189 auto new_is =
w.tuple(new_is_vec);
191 auto new_nis_def =
w.lit_nat(new_nis_nat);
// Build the combiner "fused_comb": its first variable is the pair
// (value of type To, sigma over the new input types) and its second variable
// has type `w.cn(To)`.
208 auto inputs_sigma =
w.sigma(new_Tis_vec);
209 auto data_sigma =
w.sigma({To, inputs_sigma});
210 auto ret_cn_type =
w.cn(To);
211 auto new_comb =
w.mut_con({data_sigma, ret_cn_type})->
set(
"fused_comb");
212 auto new_data = new_comb->var(0);
213 auto new_ret = new_comb->var(1);
// The two halves of the data pair.
214 auto new_acc = new_data->proj(2, 0);
215 auto new_in = new_data->proj(2, 1);
// Collect the indices of the fusible outer inputs, in increasing order.
218 for (
u64 i = 0; i < nis_nat; ++i)
219 if (infos[i].fusible) fused_indices.emplace_back(i);
// For every fused input create an "inner_ret" lambda; its first variable is
// kept in `inner_values` as the value produced for that input.
222 Vector<const Def*> inner_values(fused_indices.size());
223 for (
size_t r = 0;
r < fused_indices.size(); ++
r) {
224 auto new_inner_To = infos[fused_indices[
r]].To;
225 inner_rets[
r] =
w.mut_con(new_inner_To)->set(
"inner_ret");
226 inner_values[
r] = inner_rets[
r]->var(0);
// Inputs handed to the original outer combiner: a fused input takes the next
// entry of `inner_values`, any other input takes its slot of `new_in`.
// NOTE(review): the declaration of `r` and the `else` are not shown.
230 DefVec outer_inputs_vec(nis_nat);
233 for (
u64 i = 0; i < nis_nat; ++i)
234 if (infos[i].fusible)
235 outer_inputs_vec[i] = inner_values[
r++];
237 outer_inputs_vec[i] = new_in->proj(new_nis_nat, new_pos[i]);
// Chain the inner combiners: the first call is placed in `new_comb`, each
// later one in the previous "inner_ret". Every call passes the pair
// (inner init, tuple of that input's slots of `new_in`) plus `inner_rets[r]`.
241 for (
size_t r = 0;
r < fused_indices.size(); ++
r) {
242 auto k = fused_indices[
r];
243 auto new_inner_comb = infos[k].comb;
244 auto new_inner_init = infos[k].init;
246 DefVec inner_inputs_vec(infos[k].nis);
247 for (
u64 l = 0;
l < infos[k].nis; ++
l)
248 inner_inputs_vec[l] = new_in->proj(new_nis_nat, new_pos[k] + l);
250 Lam* caller = (
r == 0) ? new_comb : inner_rets[
r - 1];
251 caller->app(
true, new_inner_comb, {
w.tuple({new_inner_init,
w.tuple(inner_inputs_vec)}), inner_rets[r]});
// The last "inner_ret" calls the original outer combiner with the accumulator
// and the assembled outer inputs, passing `new_ret` along.
255 inner_rets.back()->app(
true, comb, {
w.tuple({new_acc,
w.tuple(outer_inputs_vec)}), new_ret});
// Apply `mr` to the new curried argument groups one after another.
// NOTE(review): the initial value of `mr` is set on a line not shown, and `So`
// is not passed in any of the visible applications (original line 261 is
// missing), so confirm against the full file.
259 mr =
w.app(mr, new_nis_def);
260 mr =
w.app(mr, {To, Ro});
262 mr =
w.app(mr, {new_Tis, new_Ris, new_Sis});
263 mr =
w.app(mr, {new_comb,
init});
264 mr =
w.app(mr, new_subs);
265 mr =
w.app(mr, new_is);
272 if (
auto res = fuse_map_reduce(mr)) {
273 DLOG(
"Fused map_reduce at {} into a new map_reduce {}", app, res);
277 return RWPhase::rewrite_imm_App(app);
static auto isa(const Def *def)
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (otherwise).
static std::optional< T > isa(const Def *def)
World & new_world()
Create new Defs into this.
virtual const Def * rewrite(const Def *)
const Def * rewrite_imm_App(const App *) final
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Vector< const Def * > DefVec
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >