13#include "absl/container/flat_hash_map.h"
17const Def* LowerMapReduce::lower_get(
const App* app) {
22 auto [arr,
index] = arg->projs<2>();
23 auto callee =
c->as<
App>();
24 auto [T,
r,
s] = callee->args<3>();
27 DLOG(
" arr = {} : {}", arr, arr->type());
28 if (
auto arr_seq = arr->type()->isa<Seq>())
DLOG(
" arr shape = {}", arr_seq->arity());
29 DLOG(
" index = {} : {}", index,
index->type());
30 DLOG(
" T = {} : {}", T, T->type());
31 DLOG(
" r = {} : {}", r,
r->type());
32 DLOG(
" s = {} : {}", s,
s->type());
36 WLOG(
"{} doesn't have a lowering-time known rank: {}", app, r);
40 DLOG(
"index of size 1, extract");
41 return w.extract(arr, index);
44 for (
auto ri = 0_u64; ri < *r_nat; ++ri) {
46 DLOG(
" idx = {} : {}", idx,
idx->type());
47 curr_arr =
w.extract(curr_arr, idx);
52const Def* LowerMapReduce::lower_set(
const App* app) {
57 auto [arr,
index, x] = arg->projs<3>();
60 DLOG(
" arr = {} : {}", arr, arr->type());
61 DLOG(
" index = {} : {}", index,
index->type());
62 DLOG(
" x = {} : {}", x, x->type());
64 auto callee =
c->as<
App>();
65 auto [T,
r,
s] = callee->args<3>();
66 DLOG(
" T = {} : {}", T, T->type());
67 DLOG(
" r = {} : {}", r,
r->type());
68 DLOG(
" s = {} : {}", s,
s->type());
72 WLOG(
"{} doesn't have a lowering-time known rank: {}", app, r);
76 DLOG(
"index of size 1, insert");
77 return w.insert(arr, index, x);
81 DefVec arrs_to_insert_into(*r_nat);
82 arrs_to_insert_into[0] = arr;
83 for (
auto ri = 0_u64; ri < *r_nat - 1; ++ri) {
85 DLOG(
" extract idx = {} : {}", idx,
idx->type());
86 arrs_to_insert_into[ri + 1] =
w.extract(arrs_to_insert_into[ri], idx);
90 for (
auto ri =
static_cast<s64>(*r_nat - 1); ri >= 0; --ri) {
92 DLOG(
" idx = {} : {}", idx,
idx->type());
93 DLOG(
" arr_to_insert_into = {} : {}", arrs_to_insert_into[ri], arrs_to_insert_into[ri]->
type());
95 new_arr =
w.insert(arrs_to_insert_into[ri], idx, new_arr);
100const Def* LowerMapReduce::rec_broadcast(
const Def* s_in,
const Def* s_out,
const Def* input,
u64 r,
u64 i) {
103 if (i == r)
return input;
105 auto s_in_ri = s_in->proj(r, i), s_out_ri = s_out->proj(r, i);
106 DLOG(
"rec_broadcast");
109 DLOG(
" s_in_ri = {} : {}", s_in_ri, s_in_ri->type());
110 DLOG(
" s_out_ri = {} : {}", s_out_ri, s_out_ri->type());
111 DLOG(
" input = {} : {}", input, input->type());
113 if (s_in_ri == s_out_ri) {
115 DefVec inputs(*s_in_lit, [&](
size_t j) {
return rec_broadcast(s_in, s_out, input->proj(j), r, i + 1); });
116 return w.tuple(inputs);
120 WLOG(
"dimension {} of the input and output are equal but not literal: {} : {}", i, s_in_ri,
126 if (
auto s_in_lit =
Lit::isa<u64>(s_in_ri); s_in_lit && *s_in_lit == 1) {
127 DLOG(
"dimension {} of the input is 1, can be broadcasted to dimension {} of the output", i, s_out_ri);
128 return w.pack(s_out_ri, rec_broadcast(s_in, s_out, input, r, i + 1));
131 WLOG(
"cannot broadcast dimension {} of size {} to size {}", i, s_in_ri, s_out_ri);
135const Def* LowerMapReduce::lower_broadcast(
const App* app) {
138 auto arg =
rewrite(app->arg());
140 auto [s_in, s_out, input] = arg->projs<3>();
141 auto callee =
c->as<
App>();
142 auto [T,
r] = callee->args<2>();
143 DLOG(
"lower_broadcast");
144 DLOG(
" s_out = {} : {}", s_out, s_out->type());
145 DLOG(
" input = {} : {}", input, input->type());
146 DLOG(
" T = {} : {}", T, T->type());
147 DLOG(
" r = {} : {}", r,
r->type());
148 DLOG(
" s_in = {} : {}", s_in, s_in->type());
152 WLOG(
"{} doesn't have a lowering-time known rank: {}", app, r);
156 if (s_in == s_out)
return input;
160 assert(*s_in_lit == 1 &&
"input dimensions must be 1 or equal to the output dimension");
161 return w.pack(s_out, input);
165 auto result = rec_broadcast(s_in, s_out, input, *r_nat, 0);
166 DLOG(
"result of rec_broadcast = {} : {}", result, result->type());
171 auto& w = bound->
world();
172 auto acc_ty = acc->
type();
173 auto body = w.mut_con({ w.type_i64(), acc_ty, w.cn(acc_ty)})->
set(name);
174 auto for_loop = w.call<
affine::For>(body, exit,
Defs{w.lit_i64(0), bound, w.lit_i64(1), acc});
175 return {body, for_loop};
180 auto& w = S->world();
182 absl::flat_hash_map<u64, const Def*> dims;
191 w.DLOG(
"out dims (n) = {}", n_nat);
192 for (
u64 i = 0; i < n_nat; ++i) {
193 auto dim = S->proj(n_nat, i);
194 w.DLOG(
"dim {} = {}", i, dim);
196 output_dims.push_back(dim);
200 w.DLOG(
"matrix count (nis) = {}", nis_nat);
202 for (
u64 i = 0; i < nis_nat; ++i) {
203 auto ni = Ris->
proj(nis_nat, i);
205 if (!ni_lit)
error(
"matrix {} has non-constant dimension count", i);
206 u64 ni_nat = *ni_lit;
207 w.DLOG(
" dims({}) = {}", i, ni_nat);
208 auto Sis_i = Sis->
proj(nis_nat, i);
210 for (
u64 j = 0; j < ni_nat; ++j) {
211 auto dim = Sis_i->proj(ni_nat, j);
212 w.DLOG(
" dim {} {} = {}", i, j, dim);
213 input_dims_i.push_back(dim);
215 input_dims.push_back(input_dims_i);
216 n_input.push_back(ni_nat);
220 for (
u64 i = 0; i < nis_nat; ++i) {
221 w.DLOG(
"investigate {} / {}", i, nis_nat);
222 auto indices = subs->
proj(nis_nat, i);
223 w.DLOG(
" indices {} = {}", i, indices);
225 for (
u64 j = 0; j < n_input[i]; ++j) {
226 auto idx = indices->proj(n_input[i], j);
228 if (!idx_lit)
error(
"index {} {} is not a literal", i, j);
229 u64 idx_nat = *idx_lit;
230 auto dim = input_dims[i][j];
231 w.DLOG(
" index {} = {}", j, idx);
232 w.DLOG(
" dim {} = {}", idx, dim);
233 if (!dims.contains(idx_nat)) {
235 w.DLOG(
" {} ↦ {}", idx_nat, dim);
237 auto prev_dim = dims[idx_nat];
238 w.DLOG(
" prev dim {} = {}", idx_nat, prev_dim);
242 if (dim != prev_dim) {
243 if (!dim_lit)
error(
"dimension {} is not a literal", dim);
244 if (!prev_dim_lit)
error(
"previous dimension {} is not a literal", prev_dim);
245 assert(*dim_lit == *prev_dim_lit &&
"dimensions must be equal");
249 }
else if (dim != prev_dim) {
250 error(
"dimensions {} and {} must be equal", dim, prev_dim);
256 for (
auto [idx, dim] : dims) {
257 w.ILOG(
"dim {} = {}", idx, dim);
259 out_indices.push_back(idx);
261 in_indices.push_back(idx);
264 std::sort(out_indices.begin(), out_indices.end());
265 std::sort(in_indices.begin(), in_indices.end());
267 return {in_indices, out_indices, dims, n_input};
270static std::tuple<const Def*, const Def*, absl::flat_hash_map<u64, const Def*>,
Lam*>
272 auto& w = fun->
world();
275 auto cont = fun->
var(1);
276 auto current_mut = fun;
279 auto init_mat = w.bot(cont->type()->as<
Pi>()->
dom());
280 w.DLOG(
"init_mat {} : {}", init_mat, init_mat->type());
285 absl::flat_hash_map<u64, const Def*> iterator;
287 for (
auto idx : out_indices) {
288 auto for_name = w.sym(
"forIn_" + std::to_string(idx));
289 auto dim_nat_def = dims.at(idx);
291 w.DLOG(
"out_cont {} : {}", cont, cont->type());
293 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
294 auto [iter, new_acc, yield] = body->template vars<3>();
298 current_mut->set(
true, for_call);
301 return {acc, cont, iterator, current_mut};
304const Def* LowerMapReduce::lower_map_reduce(
const App* app) {
323 auto callee = c->as<
App>();
325 auto [nis, ToRo, So, TisRisSis, comb_init, subs] = callee->
uncurry_args<6>();
327 auto [comb, zero] = comb_init->
projs<2>();
328 auto [Tis, Ris, Sis] = TisRisSis->
projs<3>();
329 auto [T, n] = ToRo->
projs<2>();
331 DLOG(
"lower map_reduce");
332 DLOG(
"type : {}", type);
333 DLOG(
"meta variables:");
337 DLOG(
" nis = {}", nis);
338 DLOG(
" Ris = {} : {}", Ris, Ris->type());
339 DLOG(
" Tis = {} : {}", Tis, Tis->type());
340 DLOG(
" Sis = {} : {}", Sis, Sis->type());
342 DLOG(
" zero = {}", zero);
343 DLOG(
" comb = {} : {}", comb, comb->type());
344 DLOG(
" subs = {} : {}", subs, subs->type());
345 DLOG(
" inputs = {} : {}", inputs, inputs->type());
363 if (!n_lit || !nis_lit) {
364 DLOG(
"n or nis is not a literal");
369 auto nis_nat = *nis_lit;
373 auto [in_indices, out_indices, dims, n_input] =
extract_indices(n_nat, nis_nat, So, Ris, Sis, subs);
375 for (
auto idx : out_indices)
376 ILOG(
"output index {} with dim {}", idx, dims[idx]);
377 for (
auto idx : in_indices)
378 ILOG(
"input index {} with dim {}", idx, dims[idx]);
380 auto fun = w.mut_fun(inputs->type(), type)->set(
"mapRed");
381 DLOG(
"fun {} : {}", fun, fun->type());
384 DLOG(
"ds_fun {} : {}", ds_fun, ds_fun->type());
385 auto call = w.app(ds_fun, inputs)->set(
"call");
386 DLOG(
"call {} : {}", call, call->type());
388 auto new_inputs = fun->var(0)->set(
"is");
390 DLOG(
"inputs = {} : {}", inputs, inputs->type());
391 DLOG(
"new_inputs = {} : {}", new_inputs, new_inputs->type());
418 auto [wb_matrix, cont, iterator, current_mut] =
create_outer_loop(fun, out_indices, dims);
424 auto element_acc = zero;
425 element_acc->set(
"acc");
427 DLOG(
"wb_matrix {} : {}", wb_matrix, wb_matrix->type());
430 auto write_back = w.mut_con(T)->set(
"matrixWriteBack");
431 DLOG(
"write_back {} : {}", write_back, write_back->type());
432 auto element_final = write_back->var(0);
435 for (
u64 i = 0; i < n_nat; ++i) {
436 auto idx = out_indices[i];
437 if (idx != i)
error(
"output indices must be consecutive 0..n-1 but {} != {}", idx, i);
440 DLOG(
"dimension {} is 1, no iterator needed", idx);
444 output_iterators.push_back(iterator[idx]);
447 u64 n_oi = output_iterators.size();
448 DefVec output_submatrices;
449 output_submatrices.reserve(n_oi);
450 output_submatrices.push_back(wb_matrix);
451 for (
u64 i = 0; i + 1 < n_oi; ++i)
452 output_submatrices.push_back(
w.extract(output_submatrices[i], output_iterators[i]));
454 auto written_matrix = element_final;
455 for (
u64 i = 1; i <= n_oi; ++i)
456 written_matrix =
w.insert(output_submatrices[n_oi - i], output_iterators[n_oi - i], written_matrix);
458 DLOG(
"written_matrix {} : {}", written_matrix, written_matrix->type());
459 write_back->app(
true, cont, written_matrix);
462 auto acc = element_acc;
465 for (
auto idx : in_indices) {
466 auto for_name =
w.sym(
"forIn_" + std::to_string(idx));
467 auto dim_nat_def = dims[
idx];
469 DLOG(
"in_cont {} : {}", cont, cont->type());
471 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
472 auto [iter, new_acc, yield] = body->vars<3>();
476 current_mut->set(
true, for_call);
482 DefVec input_elements((
size_t)nis_nat);
483 for (
u64 i = 0; i < nis_nat; i++) {
484 auto input_idx_tup = subs->proj(nis_nat, i);
485 auto input_matrix = new_inputs->proj(nis_nat, i);
487 DLOG(
"input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
489 auto indices = input_idx_tup->projs(n_input[i]);
490 auto input_iterators =
DefVec(n_input[i], [&](
u64 j) {
491 auto idx = indices[j];
493 DLOG(
" idx {} {} = {}", i, j, *idx_lit);
494 return iterator[*idx_lit];
497 auto curr_mat = input_matrix;
498 for (
auto idx : input_iterators)
499 curr_mat =
w.extract(curr_mat, idx);
501 DLOG(
"read_entry {} : {}", curr_mat, curr_mat->type());
502 auto element_i = curr_mat;
503 input_elements[i] = element_i;
506 DLOG(
" read elements {}", fe::Join(input_elements));
507 DLOG(
" fun {} : {}", fun, fun->type());
508 DLOG(
" current_mut {} : {}", current_mut, current_mut->type());
513 current_mut->app(
true, comb, {
w.tuple({element_acc,
w.tuple(input_elements)}), cont});
514 DLOG(
"final call {} : {}", call, call->type());
516 }
catch (
const std::exception& e) {
517 ELOG(
"error during lowering map_reduce: {}",
e.what());
522const Def* LowerMapReduce::lower_map_reduce_aff(
const App* app) {
538 auto inputs =
rewrite(app->arg());
541 auto [nis, meta, shapes, TisRisSis, comb_init, acc_out, accs] =
c->uncurry_args<7>();
542 auto [To, Ro, Rr] = meta->projs<3>();
543 auto [So, Sr] = shapes->projs<2>();
544 auto [Tis, Ris, Sis] = TisRisSis->projs<3>();
545 auto [comb,
init] = comb_init->projs<2>();
549 if (!nis_l || !ro_l || !rr_l) {
550 WLOG(
"{} doesn't have lowering-time known rank counts (nis/Ro/Rr)", app);
553 auto nis_nat = *nis_l;
554 auto ro = *ro_l, rr = *rr_l;
555 auto nloops = ro + rr;
556 auto n =
w.lit_nat(nloops);
560 for (
u64 i = 0; i < nis_nat; ++i) {
563 WLOG(
"input {} of {} has a non-literal rank", i, app);
571 auto affine_map = [&](
const Def*
f,
const Def* m,
const Def* n,
const Def*
sin,
const Def* sout,
const Def* idxs) {
573 a =
w.app(a,
w.tuple({sin, sout}));
575 return w.app(a, idxs);
577 auto nested_extract = [&](
const Def* matrix,
const Def* coords,
u64 r) {
579 for (
u64 k = 0; k <
r; ++k)
580 cur =
w.extract(cur, coords->proj(r, k));
583 auto nested_insert = [&](
const Def* matrix,
const Def* coords,
u64 r,
const Def* elem) ->
const Def* {
584 if (r == 0)
return elem;
587 for (
u64 k = 0; k + 1 <
r; ++k)
588 subs[k + 1] =
w.extract(subs[k], coords->proj(r, k));
590 for (
auto k =
static_cast<s64>(r) - 1; k >= 0; --k)
591 cur =
w.insert(subs[k], coords->proj(r, k), cur);
596 auto fun =
w.mut_fun(inputs->type(), type)->set(
"mapRedAff");
598 auto call =
w.app(ds_fun, inputs)->set(
"call");
600 auto new_inputs = fun->var(0)->set(
"is");
603 auto cont = fun->var(1);
604 auto init_mat =
w.bot(cont->type()->as<
Pi>()->dom());
606 auto current_mut = fun;
608 out_iters.reserve(ro);
609 for (
u64 i = 0; i < ro; ++i) {
610 auto dim = Sr->proj(nloops, i);
612 auto [body, for_call] =
counting_for(bound, acc, cont,
w.sym(
"forOut_" + std::to_string(i)));
613 auto [iter, new_acc, yield] = body->vars<3>();
617 current_mut->set(
true, for_call);
620 auto wb_matrix = acc;
625 auto write_back =
w.mut_con(To)->set(
"writeBack");
626 auto element_final = write_back->var(0);
627 DefVec wb_iters = out_iters;
628 for (
u64 j = 0; j < rr; ++j)
629 wb_iters.push_back(
w.call(
core::conv::u, Sr->proj(nloops, ro + j),
w.lit(
w.type_i64(), 0)));
630 auto write_coords = affine_map(acc_out, Ro, n, Sr, So,
w.tuple(wb_iters));
631 write_back->app(
true, cont, nested_insert(wb_matrix, write_coords, ro, element_final));
637 red_iters.reserve(rr);
638 for (
u64 j = 0; j < rr; ++j) {
639 auto dim = Sr->proj(nloops, ro + j);
641 auto [body, for_call] =
counting_for(bound, acc, cont,
w.sym(
"forIn_" + std::to_string(j)));
642 auto [iter, new_acc, yield] = body->vars<3>();
646 current_mut->set(
true, for_call);
649 auto element_acc = acc;
652 DefVec iters_v = out_iters;
653 iters_v.insert(iters_v.end(), red_iters.begin(), red_iters.end());
654 auto iters =
w.tuple(iters_v);
657 DefVec input_elements(nis_nat);
658 for (
u64 i = 0; i < nis_nat; ++i) {
659 auto input_matrix = new_inputs->proj(nis_nat, i);
661 = affine_map(accs->proj(nis_nat, i), Ris->proj(nis_nat, i), n, Sr, Sis->proj(nis_nat, i), iters);
662 input_elements[i] = nested_extract(input_matrix, coords, ris_nat[i]);
666 current_mut->app(
true, comb, {
w.tuple({element_acc,
w.tuple(input_elements)}), cont});
668 }
catch (
const std::exception& e) {
669 error(
"error during lowering map_reduce_aff: {}",
e.what());
675 if (
auto res = lower_get(
get))
return res;
677 if (
auto res = lower_set(
set))
return res;
679 if (
auto res = lower_broadcast(bc))
return res;
683 if (
auto res = lower_map_reduce_aff(mra))
return res;
685 return RWPhase::rewrite_imm_App(app);
const Def * callee() const
static auto uncurry_args(const Def *def)
static auto isa(const Def *def)
const Def * proj(nat_t a, nat_t i) const
Similar to World::extract while assuming an arity of a, but also works on Sigmas and Arrays.
Def * set(size_t i, const Def *)
Successively set from left to right.
World & world() const noexcept
const Def * var(nat_t a, nat_t i) noexcept
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
static std::optional< T > isa(const Def *def)
A dependent function type.
World & new_world()
Create new Defs into this.
virtual const Def * rewrite(const Def *)
This is a thin wrapper for absl::InlinedVector<T, N, A> which is a drop-in replacement for std::vecto...
const Def * rewrite_imm_App(const App *) final
#define DLOG(...)
Vaporizes to nothingness in Debug build.
const Def * op_cps2ds_dep(const Def *k)
static std::tuple< const Def *, const Def *, absl::flat_hash_map< u64, const Def * >, Lam * > create_outer_loop(Lam *fun, const Vector< u64 > &out_indices, const absl::flat_hash_map< u64, const Def * > &dims)
static std::pair< Lam *, const Def * > counting_for(const Def *bound, const Def *acc, const Def *exit, Sym name)
static std::tuple< Vector< u64 >, Vector< u64 >, absl::flat_hash_map< u64, const Def * >, Vector< u64 > > extract_indices(const u64 n_nat, const u64 nis_nat, const Def *S, const Def *Ris, const Def *Sis, const Def *subs)
Vector< const Def * > DefVec
void error(std::format_string< Args... > fmt, Args &&... args)
Wraps std::format to throw T with a formatted message.
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >