MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_highlevel.cpp
Go to the documentation of this file.
2
3#include <mim/lam.h>
4
6
8
9namespace mim::plug::matrix {
10
11namespace {
12
13// clang-format off
14absl::flat_hash_map<flags_t, flags_t> axm_to_impl_map = {
18};
19// clang-format on
20
21std::optional<const Def*> internal_function_of_axm(const Axm* axm, const Def* meta_args, const Def* args) {
22 auto& world = axm->world();
23 if (auto it = axm_to_impl_map.find(axm->flags()); it != axm_to_impl_map.end()) {
24 const Def* spec_fun = world.implicit_app(world.annexes().flags2entry().at(it->second).def, meta_args);
25 auto ds_fun = direct::op_cps2ds_dep(spec_fun);
26 return world.app(ds_fun, args);
27 }
28 return std::nullopt;
29}
30
31} // namespace
32
34 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
35 auto new_def = rewrite_(def);
36 rewritten[def] = new_def;
37 return rewritten[def];
38}
39
41 if (auto mat_ax = Axm::isa<matrix::prod>(def)) {
42 auto [mem, M, N] = mat_ax->args<3>();
43 auto [m, k, l, w] = mat_ax->decurry()->args<4>();
44 auto w_lit = Lit::isa(w);
45
46 auto ext_fun = world().externals()[world().sym("extern_matrix_prod")];
47 if (ext_fun && (w_lit && *w_lit == 64)) {
48 auto ds_fun = direct::op_cps2ds_dep(ext_fun);
49 auto fun_app = world().app(ds_fun, {mem, m, k, l, M, N});
50 return fun_app;
51 }
52 }
53
54 if (auto outer_app = def->isa<App>()) {
55 if (auto inner_app = outer_app->callee()->isa<App>()) {
56 if (auto axm = inner_app->callee()->isa<Axm>()) {
57 if (auto internal_function = internal_function_of_axm(axm, inner_app->arg(), outer_app->arg())) {
58 DLOG("lower matrix axm {} in {} : {}", *axm->sym(), def, def->type());
59 DLOG("lower matrix axm using: {} : {}", *internal_function, (*internal_function)->type());
60 return *internal_function;
61 }
62 }
63 }
64 }
65
66 return def;
67}
68
69} // namespace mim::plug::matrix
Definition axm.h:9
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
static std::optional< T > isa(const Def *def)
Definition def.h:838
World & world()
Definition pass.h:77
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
Sym sym(std::string_view)
Definition world.cpp:105
const Externals & externals() const
Definition world.h:282
const Def * rewrite(const Def *) override
Custom rewrite function: memoized version of `rewrite_`.
#define DLOG(...)
Debug-level log; vaporizes to nothingness in non-Debug (release) builds.
Definition log.h:94
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:16
The matrix Plugin
Definition matrix.h:7
The mem Plugin
Definition mem.h:11
u64 flags_t
Definition types.h:39
@ Axm
Definition def.h:109
static constexpr flags_t Base
Definition plugin.h:148