10#include "absl/container/btree_set.h"
11#include "absl/container/flat_hash_set.h"
15const Def* Lower::lower_via_impl(
const App* app,
const Def* impl_annex) {
21 const Def*
head = app;
22 while (
auto h =
head->isa<
App>()) {
26 std::reverse(args.begin(), args.end());
28 auto impl = impl_annex;
30 impl =
w.app(impl, a);
38const Def* Lower::lower_broadcast_in_dim(
const App* app) {
43 auto [s_in, s_out, input,
index] = arg->projs<4>();
44 auto callee =
c->as<
App>();
45 auto [T, r_in, r_out] = callee->args<3>();
46 DLOG(
"lower_broadcast_in_dim");
47 DLOG(
" s_out = {} : {}", s_out, s_out->type());
48 DLOG(
" input = {} : {}", input, input->type());
49 DLOG(
" index = {} : {}", index,
index->type());
50 DLOG(
" T = {} : {}", T, T->type());
51 DLOG(
" r_in = {} : {}", r_in, r_in->type());
52 DLOG(
" r_out = {} : {}", r_out, r_out->type());
53 DLOG(
" s_in = {} : {}", s_in, s_in->type());
56 if (!r_in_lit)
return nullptr;
58 if (!r_out_lit)
return nullptr;
59 auto r_out_nat = *r_out_lit;
60 auto r_in_nat = *r_in_lit;
62 auto s_tr_vec =
DefVec(r_out_nat, [&](
size_t i) {
63 if (i < r_in_nat)
return s_in->proj(r_in_nat, i);
64 return w.lit_nat_1()->as<Def>();
66 auto s_tr =
w.tuple(s_tr_vec);
68 absl::btree_set<u64> set_perm;
69 absl::flat_hash_map<u64, u64> map_perm;
70 for (
u64 i = 0; i < r_out_nat; ++i)
72 for (
u64 i = 0; i < r_in_nat; ++i) {
75 if (!idx_lit)
return nullptr;
76 u64 idx_nat = *idx_lit;
78 map_perm[idx_nat] = i;
80 set_perm.erase(idx_nat);
82 for (
u64 j = r_in_nat;
auto i : set_perm) {
86 auto permutation_vec =
DefVec(r_out_nat, [&](
size_t i) {
return w.lit_idx(r_out_nat, map_perm[i]); });
87 auto permutation =
w.tuple(permutation_vec);
92 tr =
w.app(tr, {T, r_out, s_tr});
93 tr =
w.app(tr, {input, permutation});
95 auto s_bc_vec =
DefVec(r_out_nat, [&](
size_t i) {
return s_tr->proj(r_out_nat, map_perm.at(i)); });
96 auto s_bc =
w.tuple(s_bc_vec);
99 bc =
w.app(bc, {T, r_out});
100 bc =
w.app(bc, {s_bc, s_out, tr});
111 if (
auto res = lower_broadcast_in_dim(bid))
return res;
131 return RWPhase::rewrite_imm_App(app);
static auto isa(const Def *def)
static std::optional< T > isa(const Def *def)
World & new_world()
Create new Defs into this.
virtual const Def * rewrite(const Def *)
const Def * rewrite_imm_App(const App *) final
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Vector< const Def * > DefVec