MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_map_reduce.cpp
Go to the documentation of this file.
2
3#include "mim/def.h"
4#include "mim/lam.h"
5
6#include "mim/util/types.h"
7
12
13#include "absl/container/flat_hash_map.h"
14
16
17const Def* LowerMapReduce::lower_get(const App* app) {
18 auto& w = new_world();
19 auto c = rewrite(app->callee());
20 auto arg = rewrite(app->arg());
21
22 auto [arr, index] = arg->projs<2>();
23 auto callee = c->as<App>();
24 auto [T, r, s] = callee->args<3>();
25
26 DLOG("lower_get");
27 DLOG(" arr = {} : {}", arr, arr->type());
28 if (auto arr_seq = arr->type()->isa<Seq>()) DLOG(" arr shape = {}", arr_seq->arity());
29 DLOG(" index = {} : {}", index, index->type());
30 DLOG(" T = {} : {}", T, T->type());
31 DLOG(" r = {} : {}", r, r->type());
32 DLOG(" s = {} : {}", s, s->type());
33
34 auto r_nat = Lit::isa<u64>(r);
35 if (!r_nat) {
36 WLOG("{} doesn't have a lowering-time known rank: {}", app, r);
37 return nullptr;
38 }
39 if (r_nat == 1) {
40 DLOG("index of size 1, extract");
41 return w.extract(arr, index);
42 }
43 auto curr_arr = arr;
44 for (auto ri = 0_u64; ri < *r_nat; ++ri) {
45 auto idx = index->proj(*r_nat, ri);
46 DLOG(" idx = {} : {}", idx, idx->type());
47 curr_arr = w.extract(curr_arr, idx);
48 }
49 return curr_arr;
50}
51
52const Def* LowerMapReduce::lower_set(const App* app) {
53 auto& w = new_world();
54 auto c = rewrite(app->callee());
55 auto arg = rewrite(app->arg());
56
57 auto [arr, index, x] = arg->projs<3>();
58
59 DLOG("lower_set");
60 DLOG(" arr = {} : {}", arr, arr->type());
61 DLOG(" index = {} : {}", index, index->type());
62 DLOG(" x = {} : {}", x, x->type());
63
64 auto callee = c->as<App>();
65 auto [T, r, s] = callee->args<3>();
66 DLOG(" T = {} : {}", T, T->type());
67 DLOG(" r = {} : {}", r, r->type());
68 DLOG(" s = {} : {}", s, s->type());
69
70 auto r_nat = Lit::isa<u64>(r);
71 if (!r_nat) {
72 WLOG("{} doesn't have a lowering-time known rank: {}", app, r);
73 return nullptr;
74 }
75 if (r_nat == 1) {
76 DLOG("index of size 1, insert");
77 return w.insert(arr, index, x);
78 }
79
80 // r_nat will never be 0, as we would have normalized this case away already
81 DefVec arrs_to_insert_into(*r_nat);
82 arrs_to_insert_into[0] = arr;
83 for (auto ri = 0_u64; ri < *r_nat - 1; ++ri) {
84 auto idx = index->proj(*r_nat, ri);
85 DLOG(" extract idx = {} : {}", idx, idx->type());
86 arrs_to_insert_into[ri + 1] = w.extract(arrs_to_insert_into[ri], idx);
87 }
88
89 auto new_arr = x;
90 for (auto ri = static_cast<s64>(*r_nat - 1); ri >= 0; --ri) {
91 auto idx = index->proj(*r_nat, ri);
92 DLOG(" idx = {} : {}", idx, idx->type());
93 DLOG(" arr_to_insert_into = {} : {}", arrs_to_insert_into[ri], arrs_to_insert_into[ri]->type());
94
95 new_arr = w.insert(arrs_to_insert_into[ri], idx, new_arr);
96 }
97 return new_arr;
98}
99
100const Def* LowerMapReduce::rec_broadcast(const Def* s_in, const Def* s_out, const Def* input, u64 r, u64 i) {
101 auto& w = new_world();
102 // Base case: all dimensions have been processed; `input` is the final scalar.
103 if (i == r) return input;
104
105 auto s_in_ri = s_in->proj(r, i), s_out_ri = s_out->proj(r, i);
106 DLOG("rec_broadcast");
107 DLOG(" r = {}", r);
108 DLOG(" i = {}", i);
109 DLOG(" s_in_ri = {} : {}", s_in_ri, s_in_ri->type());
110 DLOG(" s_out_ri = {} : {}", s_out_ri, s_out_ri->type());
111 DLOG(" input = {} : {}", input, input->type());
112
113 if (s_in_ri == s_out_ri) {
114 if (auto s_in_lit = Lit::isa<u64>(s_in_ri)) {
115 DefVec inputs(*s_in_lit, [&](size_t j) { return rec_broadcast(s_in, s_out, input->proj(j), r, i + 1); });
116 return w.tuple(inputs);
117 } else {
118 // TODO: we could probably support non-literal sizes as well, but we would need to generate loops to copy
119 // the data instead of just packing it.
120 WLOG("dimension {} of the input and output are equal but not literal: {} : {}", i, s_in_ri,
121 s_in_ri->type());
122 return nullptr;
123 }
124 }
125
126 if (auto s_in_lit = Lit::isa<u64>(s_in_ri); s_in_lit && *s_in_lit == 1) {
127 DLOG("dimension {} of the input is 1, can be broadcasted to dimension {} of the output", i, s_out_ri);
128 return w.pack(s_out_ri, rec_broadcast(s_in, s_out, input, r, i + 1));
129 }
130
131 WLOG("cannot broadcast dimension {} of size {} to size {}", i, s_in_ri, s_out_ri);
132 return nullptr;
133}
134
135const Def* LowerMapReduce::lower_broadcast(const App* app) {
136 auto& w = new_world();
137 auto c = rewrite(app->callee());
138 auto arg = rewrite(app->arg());
139
140 auto [s_in, s_out, input] = arg->projs<3>();
141 auto callee = c->as<App>();
142 auto [T, r] = callee->args<2>();
143 DLOG("lower_broadcast");
144 DLOG(" s_out = {} : {}", s_out, s_out->type());
145 DLOG(" input = {} : {}", input, input->type());
146 DLOG(" T = {} : {}", T, T->type());
147 DLOG(" r = {} : {}", r, r->type());
148 DLOG(" s_in = {} : {}", s_in, s_in->type());
149
150 auto r_nat = Lit::isa<u64>(r);
151 if (!r_nat) {
152 WLOG("{} doesn't have a lowering-time known rank: {}", app, r);
153 return nullptr;
154 }
155 // r_nat will never be 0, as we would have normalized this case away already
156 if (s_in == s_out) return input;
157
158 if (*r_nat == 1) {
159 if (auto s_in_lit = Lit::isa<u64>(s_in)) {
160 assert(*s_in_lit == 1 && "input dimensions must be 1 or equal to the output dimension");
161 return w.pack(s_out, input);
162 }
163 }
164
165 auto result = rec_broadcast(s_in, s_out, input, *r_nat, 0);
166 DLOG("result of rec_broadcast = {} : {}", result, result->type());
167 return result;
168}
169
170static std::pair<Lam*, const Def*> counting_for(const Def* bound, const Def* acc, const Def* exit, Sym name) {
171 auto& w = bound->world();
172 auto acc_ty = acc->type();
173 auto body = w.mut_con({/* iter */ w.type_i64(), /* acc */ acc_ty, /* return */ w.cn(acc_ty)})->set(name);
174 auto for_loop = w.call<affine::For>(body, exit, Defs{w.lit_i64(0), bound, w.lit_i64(1), acc});
175 return {body, for_loop};
176}
177
178static std::tuple<Vector<u64>, Vector<u64>, absl::flat_hash_map<u64, const Def*>, Vector<u64>>
179extract_indices(const u64 n_nat, const u64 nis_nat, const Def* S, const Def* Ris, const Def* Sis, const Def* subs) {
180 auto& w = S->world();
181
182 absl::flat_hash_map<u64, const Def*> dims; // idx ↦ nat (size bound = dimension)
183 Vector<u64> out_indices; // output indices 0..n-1
184 Vector<u64> in_indices; // input indices ≥ n
185
186 Vector<const Def*> output_dims; // i<n ↦ nat (dimension S#i)
187 Vector<DefVec> input_dims; // i<nis ↦ j<Ris#i ↦ nat (dimension Sis#i#j)
188 Vector<u64> n_input; // i<nis ↦ nat (number of dimensions of Sis#i)
189
190 // collect output dimensions
191 w.DLOG("out dims (n) = {}", n_nat);
192 for (u64 i = 0; i < n_nat; ++i) {
193 auto dim = S->proj(n_nat, i);
194 w.DLOG("dim {} = {}", i, dim);
195 dims[i] = dim;
196 output_dims.push_back(dim);
197 }
198
199 // collect other (input) dimensions
200 w.DLOG("matrix count (nis) = {}", nis_nat);
201
202 for (u64 i = 0; i < nis_nat; ++i) {
203 auto ni = Ris->proj(nis_nat, i);
204 auto ni_lit = Lit::isa(ni);
205 if (!ni_lit) error("matrix {} has non-constant dimension count", i);
206 u64 ni_nat = *ni_lit;
207 w.DLOG(" dims({}) = {}", i, ni_nat);
208 auto Sis_i = Sis->proj(nis_nat, i);
209 DefVec input_dims_i;
210 for (u64 j = 0; j < ni_nat; ++j) {
211 auto dim = Sis_i->proj(ni_nat, j);
212 w.DLOG(" dim {} {} = {}", i, j, dim);
213 input_dims_i.push_back(dim);
214 }
215 input_dims.push_back(input_dims_i);
216 n_input.push_back(ni_nat);
217 }
218
219 // extracts bounds for each index (in, out)
220 for (u64 i = 0; i < nis_nat; ++i) {
221 w.DLOG("investigate {} / {}", i, nis_nat);
222 auto indices = subs->proj(nis_nat, i);
223 w.DLOG(" indices {} = {}", i, indices);
224
225 for (u64 j = 0; j < n_input[i]; ++j) {
226 auto idx = indices->proj(n_input[i], j);
227 auto idx_lit = Lit::isa(idx);
228 if (!idx_lit) error("index {} {} is not a literal", i, j);
229 u64 idx_nat = *idx_lit;
230 auto dim = input_dims[i][j];
231 w.DLOG(" index {} = {}", j, idx);
232 w.DLOG(" dim {} = {}", idx, dim);
233 if (!dims.contains(idx_nat)) {
234 dims[idx_nat] = dim;
235 w.DLOG(" {} ↦ {}", idx_nat, dim);
236 } else {
237 auto prev_dim = dims[idx_nat];
238 w.DLOG(" prev dim {} = {}", idx_nat, prev_dim);
239 // override with more precise information
240 if (auto dim_lit = Lit::isa<u64>(dim)) {
241 if (auto prev_dim_lit = Lit::isa<u64>(prev_dim)) {
242 if (dim != prev_dim) {
243 if (!dim_lit) error("dimension {} is not a literal", dim);
244 if (!prev_dim_lit) error("previous dimension {} is not a literal", prev_dim);
245 assert(*dim_lit == *prev_dim_lit && "dimensions must be equal");
246 }
247 } else
248 dims[idx_nat] = dim;
249 } else if (dim != prev_dim) {
250 error("dimensions {} and {} must be equal", dim, prev_dim);
251 }
252 }
253 }
254 }
255
256 for (auto [idx, dim] : dims) {
257 w.ILOG("dim {} = {}", idx, dim);
258 if (idx < n_nat)
259 out_indices.push_back(idx);
260 else
261 in_indices.push_back(idx);
262 }
263 // sort indices to make checks easier later.
264 std::sort(out_indices.begin(), out_indices.end());
265 std::sort(in_indices.begin(), in_indices.end());
266
267 return {in_indices, out_indices, dims, n_input};
268}
269
270static std::tuple<const Def*, const Def*, absl::flat_hash_map<u64, const Def*>, Lam*>
271create_outer_loop(Lam* fun, const Vector<u64>& out_indices, const absl::flat_hash_map<u64, const Def*>& dims) {
272 auto& w = fun->world();
273
274 // The function on where to continue -- return after all output loops.
275 auto cont = fun->var(1);
276 auto current_mut = fun;
277
278 // First create the output matrix.
279 auto init_mat = w.bot(cont->type()->as<Pi>()->dom());
280 w.DLOG("init_mat {} : {}", init_mat, init_mat->type());
281
282 // Each of the outer loops contains the memory and matrix as accumulator (in an inner monad).
283 auto acc = init_mat;
284
285 absl::flat_hash_map<u64, const Def*> iterator; // idx ↦ %Idx (S/NI#i)
286
287 for (auto idx : out_indices) {
288 auto for_name = w.sym("forIn_" + std::to_string(idx));
289 auto dim_nat_def = dims.at(idx);
290 auto dim = w.call<core::bitcast>(w.type_i64(), dim_nat_def);
291 w.DLOG("out_cont {} : {}", cont, cont->type());
292
293 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
294 auto [iter, new_acc, yield] = body->template vars<3>();
295 cont = yield;
296 iterator[idx] = w.call(core::conv::u, dim_nat_def, iter);
297 acc = new_acc;
298 current_mut->set(true, for_call);
299 current_mut = body;
300 }
301 return {acc, cont, iterator, current_mut};
302}
303
// Lowers a `tensor::map_reduce` application into a nest of counting loops wrapped in a function "mapRed".
//
// The callee, argument and type of `app` are rewritten into the new world first. The result is the
// application of the direct-style wrapper (`direct::op_cps2ds_dep`) of "mapRed" to the rewritten inputs.
//
// Returns nullptr (so the caller can fall back) when `n` or `nis` is not a literal, or when any
// std::exception is thrown while building the loops (it is logged via ELOG).
304const Def* LowerMapReduce::lower_map_reduce(const App* app) {
305 // meta arguments:
306 // * n = out-count, (nat)
307 // * S = out-dim, (n*nat)
308 // * T = out-type (*)
309 // * nis = in-count (nat)
310 // * Ris = in-dim-count (nis*nat)
311 // * Tis = types (nis**)
312 // * Sis = dimensions (nis*Ris#i)
313 // arguments:
314 // * mem
315 // * zero = accumulator init (T)
316 // * combination function (mem, acc, inputs) -> (mem, acc)
317 // * input matrices
318
319 auto& w = new_world();
320 auto c = rewrite(app->callee());
321 auto inputs = rewrite(app->arg());
322 auto type = rewrite(app->type());
323 auto callee = c->as<App>();
324
    // The curried callee carries six argument groups; three of them are tuples that are split below.
325 auto [nis, ToRo, So, TisRisSis, comb_init, subs] = callee->uncurry_args<6>();
326
327 auto [comb, zero] = comb_init->projs<2>();
328 auto [Tis, Ris, Sis] = TisRisSis->projs<3>();
329 auto [T, n] = ToRo->projs<2>();
330
331 DLOG("lower map_reduce");
332 DLOG("type : {}", type);
333 DLOG("meta variables:");
334 DLOG(" n = {}", n);
335 DLOG(" S = {}", So);
336 DLOG(" T = {}", T);
337 DLOG(" nis = {}", nis);
338 DLOG(" Ris = {} : {}", Ris, Ris->type());
339 DLOG(" Tis = {} : {}", Tis, Tis->type());
340 DLOG(" Sis = {} : {}", Sis, Sis->type());
341 DLOG("arguments:");
342 DLOG(" zero = {}", zero);
343 DLOG(" comb = {} : {}", comb, comb->type());
344 DLOG(" subs = {} : {}", subs, subs->type());
345 DLOG(" inputs = {} : {}", inputs, inputs->type());
346
347 // Our goal is to generate a call to a function that performs:
348 // ```
349 // matrix = new matrix (n, S, T)
350 // for out_idx { // n for loops
351 // acc = zero
352 // for in_idx { // remaining loops
353 // inps = read from matrices // nis-tuple
354 // acc = comb(mem, acc, inps)
355 // }
356 // write acc to output matrix
357 // }
358 // return matrix
359 // ```
360
    // Both counts must be known now: they determine how many loops and reads are emitted.
361 auto n_lit = Lit::isa<u64>(n);
362 auto nis_lit = Lit::isa<u64>(nis);
363 if (!n_lit || !nis_lit) {
364 DLOG("n or nis is not a literal");
365 return nullptr;
366 }
367
368 auto n_nat = *n_lit; // number of output dimensions (in S)
369 auto nis_nat = *nis_lit; // number of input matrices
370
    // extract_indices and the checks below report problems via error(...); the catch at the end
    // turns any std::exception into a logged failure and a nullptr result.
371 try {
372 // out-indices are loops (potentially parallel) over the output tensor, in-indices are reductions
373 auto [in_indices, out_indices, dims, n_input] = extract_indices(n_nat, nis_nat, So, Ris, Sis, subs);
374
375 for (auto idx : out_indices)
376 ILOG("output index {} with dim {}", idx, dims[idx]);
377 for (auto idx : in_indices)
378 ILOG("input index {} with dim {}", idx, dims[idx]);
379
380 auto fun = w.mut_fun(inputs->type(), type)->set("mapRed");
381 DLOG("fun {} : {}", fun, fun->type());
382
383 auto ds_fun = direct::op_cps2ds_dep(fun)->set("dsFun");
384 DLOG("ds_fun {} : {}", ds_fun, ds_fun->type());
385 auto call = w.app(ds_fun, inputs)->set("call");
386 DLOG("call {} : {}", call, call->type());
387
    // Inside `fun`, the inputs are accessed through its first variable rather than the outer `inputs`.
388 auto new_inputs = fun->var(0)->set("is");
389
390 DLOG("inputs = {} : {}", inputs, inputs->type());
391 DLOG("new_inputs = {} : {}", new_inputs, new_inputs->type());
392
393 // flowchart:
394 // ```
395 // -> init
396 // -> forOut1 with yieldOut1
397 // => exitOut1 = return_cont
398 // -> forOut2 with yieldOut2
399 // => exitOut2 = yieldOut1
400 // -> ...
401 // -> accumulator init
402 // -> forIn1 with yieldIn1
403 // => exitIn1 = writeCont
404 // -> forIn2 with yieldIn2
405 // => exitIn2 = yieldIn1
406 // -> ...
407 // -> read matrices
408 // -> fun
409 // => exitFun = yieldInM
410 //
411 // (return path)
412 // -> ...
413 // -> write
414 // -> yieldOutN
415 // -> ...
416 // ```
417
    // After this call: wb_matrix is the innermost outer loop's matrix accumulator, cont its yield,
    // iterator maps each output index to its loop iterator, current_mut is the mutable still to be filled.
418 auto [wb_matrix, cont, iterator, current_mut] = create_outer_loop(fun, out_indices, dims);
419
420 // Now the inner loops for the inputs:
421 // Each of the inner loops contains the element accumulator and memory as accumulator (in an inner monad).
422
423 // First create the accumulator.
424 auto element_acc = zero;
425 element_acc->set("acc");
426 assert(wb_matrix);
427 DLOG("wb_matrix {} : {}", wb_matrix, wb_matrix->type());
428
429 // Write back element to matrix. Set this as return after all inner loops.
430 auto write_back = w.mut_con(T)->set("matrixWriteBack");
431 DLOG("write_back {} : {}", write_back, write_back->type());
432 auto element_final = write_back->var(0);
433
    // Gather the iterators used to address the output element. Dimensions whose size is the
    // literal 1 are skipped, so no extract/insert is emitted for them.
434 DefVec output_iterators;
435 for (u64 i = 0; i < n_nat; ++i) {
436 auto idx = out_indices[i];
437 if (idx != i) error("output indices must be consecutive 0..n-1 but {} != {}", idx, i);
438 if (auto dim_lit = Lit::isa<u64>(dims[idx])) {
439 if (*dim_lit == 1) {
440 DLOG("dimension {} is 1, no iterator needed", idx);
441 continue;
442 }
443 }
444 output_iterators.push_back(iterator[idx]);
445 }
446
    // output_submatrices[k] is wb_matrix after extracting along the first k output iterators.
447 u64 n_oi = output_iterators.size();
448 DefVec output_submatrices;
449 output_submatrices.reserve(n_oi);
450 output_submatrices.push_back(wb_matrix);
451 for (u64 i = 0; i + 1 < n_oi; ++i)
452 output_submatrices.push_back(w.extract(output_submatrices[i], output_iterators[i]));
453
    // Insert from the innermost sub-matrix outwards; with n_oi == 0 the element itself is the result.
454 auto written_matrix = element_final;
455 for (u64 i = 1; i <= n_oi; ++i)
456 written_matrix = w.insert(output_submatrices[n_oi - i], output_iterators[n_oi - i], written_matrix);
457
458 DLOG("written_matrix {} : {}", written_matrix, written_matrix->type());
459 write_back->app(true, cont, written_matrix);
460
461 // From here on the continuations take the element and memory.
462 auto acc = element_acc;
463 cont = write_back;
464
    // Reduction loops: same wiring as in create_outer_loop, but carrying the element accumulator
    // and with the outermost reduction loop exiting into write_back.
465 for (auto idx : in_indices) {
466 auto for_name = w.sym("forIn_" + std::to_string(idx));
467 auto dim_nat_def = dims[idx];
468 auto dim = w.call<core::bitcast>(w.type_i64(), dim_nat_def);
469 DLOG("in_cont {} : {}", cont, cont->type());
470
471 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
472 auto [iter, new_acc, yield] = body->vars<3>();
473 cont = yield;
474 iterator[idx] = w.call(core::conv::u, dim_nat_def, iter);
475 acc = new_acc;
476 current_mut->set(true, for_call);
477 current_mut = body;
478 }
479 element_acc = acc;
480
481 // Read element from input matrix.
482 DefVec input_elements((size_t)nis_nat);
483 for (u64 i = 0; i < nis_nat; i++) {
484 auto input_idx_tup = subs->proj(nis_nat, i);
485 auto input_matrix = new_inputs->proj(nis_nat, i);
486
487 DLOG("input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
488
    // idx_lit is dereferenced unchecked here; extract_indices has already rejected non-literal
    // indices (note it uses Lit::isa without the <u64> argument — NOTE(review): confirm both agree).
489 auto indices = input_idx_tup->projs(n_input[i]);
490 auto input_iterators = DefVec(n_input[i], [&](u64 j) {
491 auto idx = indices[j];
492 auto idx_lit = Lit::isa<u64>(idx);
493 DLOG(" idx {} {} = {}", i, j, *idx_lit);
494 return iterator[*idx_lit];
495 });
496
497 auto curr_mat = input_matrix;
498 for (auto idx : input_iterators)
499 curr_mat = w.extract(curr_mat, idx);
500
501 DLOG("read_entry {} : {}", curr_mat, curr_mat->type());
502 auto element_i = curr_mat;
503 input_elements[i] = element_i;
504 }
505
506 DLOG(" read elements {}", fe::Join(input_elements));
507 DLOG(" fun {} : {}", fun, fun->type());
508 DLOG(" current_mut {} : {}", current_mut, current_mut->type());
509
510 comb->set("comb");
511
    // Innermost body: call comb with (accumulator, input elements) and continue with the
    // innermost loop's yield (or write_back if there are no reduction loops).
512 // TODO: make non-scalar or completely scalar?
513 current_mut->app(true, comb, {w.tuple({element_acc, w.tuple(input_elements)}), cont});
514 DLOG("final call {} : {}", call, call->type());
515 return call;
516 } catch (const std::exception& e) {
517 ELOG("error during lowering map_reduce: {}", e.what());
518 return nullptr;
519 }
520}
521
// Lowers a `tensor::map_reduce_aff` application into a nest of counting loops wrapped in a function
// "mapRedAff". Read and write coordinates are produced by `affine::map` applications.
//
// Returns the application of the direct-style wrapper of "mapRedAff" to the rewritten inputs, or nullptr
// when nis/Ro/Rr or any input rank is not a literal. A std::exception raised while building the loops
// is re-raised through error(...).
522const Def* LowerMapReduce::lower_map_reduce_aff(const App* app) {
523 // meta arguments:
524 // * nis = in-count (nat)
525 // * To = out-type (*), Ro = #output loops = result rank, Rr = #reduction loops
526 // * So = result shape (Ro*nat)
527 // * Sr = the full loop bounds (Ro+Rr)*nat: the leading Ro are the output-loop bounds, the trailing Rr the
528 // reductions
529 // * Tis/Ris/Sis = input types/ranks/shapes
530 // arguments:
531 // * f = combination function (CPS), init = accumulator init
532 // * acc_out = affine map from the (Ro+Rr) loop vector to the Ro write coordinates in the result «So» (the reduction
533 // part is not in scope at write-back, so acc_out must depend only on the leading Ro output indices)
534 // * accs = per-input affine map from the (Ro+Rr) loop vector to the input's read coordinates
535 // * is = input tensors
536 auto& w = new_world();
537 auto c = rewrite(app->callee())->as<App>();
538 auto inputs = rewrite(app->arg());
539 auto type = rewrite(app->type());
540
    // The curried callee carries seven argument groups; four of them are tuples that are split below.
541 auto [nis, meta, shapes, TisRisSis, comb_init, acc_out, accs] = c->uncurry_args<7>();
542 auto [To, Ro, Rr] = meta->projs<3>();
543 auto [So, Sr] = shapes->projs<2>();
544 auto [Tis, Ris, Sis] = TisRisSis->projs<3>();
545 auto [comb, init] = comb_init->projs<2>();
546
547 auto nis_l = Lit::isa<u64>(nis);
548 auto ro_l = Lit::isa<u64>(Ro), rr_l = Lit::isa<u64>(Rr);
549 if (!nis_l || !ro_l || !rr_l) {
550 WLOG("{} doesn't have lowering-time known rank counts (nis/Ro/Rr)", app);
551 return nullptr;
552 }
553 auto nis_nat = *nis_l;
554 auto ro = *ro_l, rr = *rr_l;
555 auto nloops = ro + rr; // length of the full loop vector (= length of Sr)
556 auto n = w.lit_nat(nloops); // passed as the affine maps' domain length
557
558 // ranks of each input must be literal so that we know how many `extract`s to emit
559 Vector<u64> ris_nat(nis_nat);
560 for (u64 i = 0; i < nis_nat; ++i) {
561 auto l = Lit::isa<u64>(Ris->proj(nis_nat, i));
562 if (!l) {
563 WLOG("input {} of {} has a non-literal rank", i, app);
564 return nullptr;
565 }
566 ris_nat[i] = *l;
567 }
568
569 // Builds `%affine.map @(m, n) @(sin, sout) f idxs`. The emitted `%affine.map` is lowered to %core arithmetic by the
570 // subsequent %affine.lower_index_phase.
    // Note: the lambda parameter `n` shadows the local `n` declared above.
571 auto affine_map = [&](const Def* f, const Def* m, const Def* n, const Def* sin, const Def* sout, const Def* idxs) {
572 auto a = w.app(w.annex<affine::map>(), w.tuple({m, n}));
573 a = w.app(a, w.tuple({sin, sout}));
574 a = w.app(a, f);
575 return w.app(a, idxs);
576 };
    // Reads matrix[coords#0]...[coords#(r-1)] via r successive extracts.
577 auto nested_extract = [&](const Def* matrix, const Def* coords, u64 r) {
578 auto cur = matrix;
579 for (u64 k = 0; k < r; ++k)
580 cur = w.extract(cur, coords->proj(r, k));
581 return cur;
582 };
    // Writes elem at matrix[coords#0]...[coords#(r-1)]: extract the sub-matrices along the path,
    // then insert from the innermost level outwards. With r == 0 the element itself is the result.
583 auto nested_insert = [&](const Def* matrix, const Def* coords, u64 r, const Def* elem) -> const Def* {
584 if (r == 0) return elem;
585 DefVec subs(r);
586 subs[0] = matrix;
587 for (u64 k = 0; k + 1 < r; ++k)
588 subs[k + 1] = w.extract(subs[k], coords->proj(r, k));
589 auto cur = elem;
590 for (auto k = static_cast<s64>(r) - 1; k >= 0; --k)
591 cur = w.insert(subs[k], coords->proj(r, k), cur);
592 return cur;
593 };
594
595 try {
596 auto fun = w.mut_fun(inputs->type(), type)->set("mapRedAff");
597 auto ds_fun = direct::op_cps2ds_dep(fun)->set("dsFun");
598 auto call = w.app(ds_fun, inputs)->set("call");
599
    // Inside `fun`, the inputs are accessed through its first variable rather than the outer `inputs`.
600 auto new_inputs = fun->var(0)->set("is");
601
602 // Outer (parallel) loops over the leading Ro bounds of `Sr`, collecting the output iteration indices.
603 auto cont = fun->var(1);
604 auto init_mat = w.bot(cont->type()->as<Pi>()->dom());
605 auto acc = init_mat;
606 auto current_mut = fun;
607 DefVec out_iters;
608 out_iters.reserve(ro);
609 for (u64 i = 0; i < ro; ++i) {
610 auto dim = Sr->proj(nloops, i);
611 auto bound = w.call<core::bitcast>(w.type_i64(), dim);
612 auto [body, for_call] = counting_for(bound, acc, cont, w.sym("forOut_" + std::to_string(i)));
613 auto [iter, new_acc, yield] = body->vars<3>();
    // The next (inner) loop exits into this loop's yield and carries this loop's accumulator.
614 cont = yield;
615 out_iters.push_back(w.call(core::conv::u, dim, iter));
616 acc = new_acc;
617 current_mut->set(true, for_call);
618 current_mut = body;
619 }
620 auto wb_matrix = acc;
621
622 // Write-back: narrow the accumulated element into the result at the affine write coordinates `acc_out`.
623 // acc_out takes the full (Ro+Rr) loop vector, but the reduction loops have already been folded away here, so we
624 // pass 0 for those slots; acc_out must depend only on the leading Ro output indices.
625 auto write_back = w.mut_con(To)->set("writeBack");
626 auto element_final = write_back->var(0);
627 DefVec wb_iters = out_iters;
628 for (u64 j = 0; j < rr; ++j)
629 wb_iters.push_back(w.call(core::conv::u, Sr->proj(nloops, ro + j), w.lit(w.type_i64(), 0)));
630 auto write_coords = affine_map(acc_out, Ro, n, Sr, So, w.tuple(wb_iters)); // «Ro; Idx (So#k)»
631 write_back->app(true, cont, nested_insert(wb_matrix, write_coords, ro, element_final));
632
633 // Inner (reduction) loops over the trailing Rr bounds of `Sr`, collecting the reduction iteration indices.
634 acc = init;
635 cont = write_back;
636 DefVec red_iters;
637 red_iters.reserve(rr);
638 for (u64 j = 0; j < rr; ++j) {
639 auto dim = Sr->proj(nloops, ro + j);
640 auto bound = w.call<core::bitcast>(w.type_i64(), dim);
641 auto [body, for_call] = counting_for(bound, acc, cont, w.sym("forIn_" + std::to_string(j)));
642 auto [iter, new_acc, yield] = body->vars<3>();
643 cont = yield;
644 red_iters.push_back(w.call(core::conv::u, dim, iter));
645 acc = new_acc;
646 current_mut->set(true, for_call);
647 current_mut = body;
648 }
649 auto element_acc = acc;
650
651 // The full loop iteration vector `(o…, r…)`; its moduli are exactly `Sr`.
652 DefVec iters_v = out_iters;
653 iters_v.insert(iters_v.end(), red_iters.begin(), red_iters.end());
654 auto iters = w.tuple(iters_v);
655
656 // Read one element from each input at its affine read coordinates.
657 DefVec input_elements(nis_nat);
658 for (u64 i = 0; i < nis_nat; ++i) {
659 auto input_matrix = new_inputs->proj(nis_nat, i);
660 auto coords
661 = affine_map(accs->proj(nis_nat, i), Ris->proj(nis_nat, i), n, Sr, Sis->proj(nis_nat, i), iters);
662 input_elements[i] = nested_extract(input_matrix, coords, ris_nat[i]);
663 }
664
    // Innermost body: call comb with (accumulator, input elements) and continue with the
    // innermost loop's yield (or write_back if there are no reduction loops).
665 comb->set("comb");
666 current_mut->app(true, comb, {w.tuple({element_acc, w.tuple(input_elements)}), cont});
667 return call;
668 } catch (const std::exception& e) {
    // NOTE(review): unlike lower_map_reduce (ELOG + return nullptr), this handler re-raises via
    // error(...), so there is no return statement on this path — confirm that this is intended.
669 error("error during lowering map_reduce_aff: {}", e.what());
670 }
671}
672
674 if (auto get = Axm::isa<tensor::get>(app)) {
675 if (auto res = lower_get(get)) return res;
676 } else if (auto set = Axm::isa<tensor::set>(app)) {
677 if (auto res = lower_set(set)) return res;
678 } else if (auto bc = Axm::isa<tensor::broadcast>(app)) {
679 if (auto res = lower_broadcast(bc)) return res;
680 } else if (auto mr = Axm::isa<tensor::map_reduce>(app)) {
681 if (auto res = lower_map_reduce(mr)) return res;
682 } else if (auto mra = Axm::isa<tensor::map_reduce_aff>(app)) {
683 if (auto res = lower_map_reduce_aff(mra)) return res;
684 }
685 return RWPhase::rewrite_imm_App(app);
686}
687
688} // namespace mim::plug::tensor::phase
const Def * callee() const
Definition lam.h:276
static auto uncurry_args(const Def *def)
Definition lam.h:329
const Def * arg() const
Definition lam.h:285
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
const Def * proj(nat_t a, nat_t i) const
Similar to World::extract while assuming an arity of a, but also works on Sigmas and Arrays.
Definition def.cpp:593
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
World & world() const noexcept
Definition def.cpp:444
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:425
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:386
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
A function.
Definition lam.h:110
static std::optional< T > isa(const Def *def)
Definition def.h:838
A dependent function type.
Definition lam.h:14
const Def * dom() const
Definition lam.h:35
World & new_world()
Create new Defs into this.
Definition phase.h:243
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:56
This is a thin wrapper for absl::InlinedVector<T, N, A> which is a drop-in replacement for std::vecto...
Definition vector.h:18
const Def * rewrite_imm_App(const App *) final
#define ILOG(...)
Definition log.h:90
#define ELOG(...)
Definition log.h:88
#define WLOG(...)
Definition log.h:89
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:94
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:16
static std::tuple< const Def *, const Def *, absl::flat_hash_map< u64, const Def * >, Lam * > create_outer_loop(Lam *fun, const Vector< u64 > &out_indices, const absl::flat_hash_map< u64, const Def * > &dims)
static std::pair< Lam *, const Def * > counting_for(const Def *bound, const Def *acc, const Def *exit, Sym name)
static std::tuple< Vector< u64 >, Vector< u64 >, absl::flat_hash_map< u64, const Def * >, Vector< u64 > > extract_indices(const u64 n_nat, const u64 nis_nat, const Def *S, const Def *Ris, const Def *Sis, const Def *subs)
View< const Def * > Defs
Definition def.h:78
Vector< const Def * > DefVec
Definition def.h:79
int64_t s64
Definition types.h:27
void error(std::format_string< Args... > fmt, Args &&... args)
Wraps std::format to throw T with a formatted message.
Definition dbg.h:17
uint64_t u64
Definition types.h:27
@ Pi
Definition def.h:109
@ App
Definition def.h:109
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >