MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_rewrite_inner.cpp
Go to the documentation of this file.
2
5
6using namespace std::literals;
7
8namespace mim::plug::autodiff {
9
10const Def* Eval::augment_lit(const Lit* lit, Lam* f, Lam*) {
11 auto pb = zero_pullback(lit->type(), f->dom(2, 0));
12 partial_pullback[lit] = pb;
13 return lit;
14}
15
16const Def* Eval::augment_var(const Var* var, Lam*, Lam*) {
17 assert(augmented.count(var));
18 auto aug_var = augmented[var];
19 assert(partial_pullback.count(aug_var));
20 return var;
21}
22
23const Def* Eval::augment_lam(Lam* lam, Lam* f, Lam* f_diff) {
24 // TODO: we need partial pullbacks for tuples (higher-order / ret-cont application)
25 // also for higher-order args, ret_cont (at another point)
26 // the pullback is not important but formally required by tuple rule
27 if (augmented.count(lam)) {
28 // We already know the function:
29 // * recursion
30 // * higher order arguments
31 // * new encounter of previous function
32 DLOG("already augmented {} : {} to {} : {}", lam, lam->type(), augmented[lam], augmented[lam]->type());
33 return augmented[lam];
34 }
35 // TODO: better fix (another pass as analysis?)
36 // TODO: handle open functions
37 if (Lam::isa_basicblock(lam) || lam->sym().view().contains("ret") || lam->sym().view().contains("_cont")) {
38 // A open continuation behaves the same as return:
39 // ```
40 // cont: Cn[X]
41 // cont': Cn[X,Cn[X,A]]
42 // ```
43 // There is dependency on the closed function context.
44 // (All derivatives are with respect to the arguments of a closed function.)
45
46 DLOG("found an open continuation {} : {}", lam, lam->type());
47 auto cont_dom = lam->type()->dom(); // not only 0 but all
48 auto pb_ty = pullback_type(cont_dom, f->dom(2, 0));
49 auto aug_dom = autodiff_type_fun(cont_dom);
50 DLOG("augmented domain {}", aug_dom);
51 DLOG("pb type is {}", pb_ty);
52 auto aug_lam = world().mut_con({aug_dom, pb_ty})->set("aug_"s + lam->sym().str());
53 auto aug_var = aug_lam->var((nat_t)0);
54 augmented[lam->var()] = aug_var;
55 augmented[lam] = aug_lam; // TODO: only one of these two
56 derived[lam] = aug_lam;
57 auto pb = aug_lam->var(1);
58 partial_pullback[aug_var] = pb;
59 // We are still in same closed function.
60 auto new_body = augment(lam->body(), f, f_diff);
61 // TODO we also need to rewrite the filter
62 aug_lam->set(lam->filter(), new_body);
63
64 auto lam_pb = zero_pullback(lam->type(), f->dom(2, 0));
65 partial_pullback[aug_lam] = lam_pb;
66 DLOG("augmented {} : {}", lam, lam->type());
67 DLOG("to {} : {}", aug_lam, aug_lam->type());
68 DLOG("ppb for lam cont: {}", lam_pb);
69
70 return aug_lam;
71 }
72 DLOG("found a closed function call {} : {}", lam, lam->type());
73 // Some general function in the program needs to be differentiated.
74 auto aug_lam = world().call<ad>(lam);
75 // TODO: directly more association here? => partly inline op_autodiff
76 DLOG("augmented function is {} : {}", aug_lam, aug_lam->type());
77 return aug_lam;
78}
79
80const Def* Eval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff) {
81 auto tuple = ext->tuple();
82 auto index = ext->index();
83
84 auto aug_tuple = augment(tuple, f, f_diff);
85 auto aug_index = augment(index, f, f_diff);
86
87 const Def* pb;
88 DLOG("tuple was: {} : {}", tuple, tuple->type());
89 DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
90 if (shadow_pullback.count(aug_tuple)) {
91 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
92 DLOG("Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
93 pb = world().extract(shadow_tuple_pb, aug_index);
94 } else {
95 // ```
96 // e:T, b:B
97 // b = e#i
98 // b* = \lambda (s:B). e* (insert s at i in (zero T))
99 // ```
100 assert(partial_pullback.count(aug_tuple));
101 auto tuple_pb = partial_pullback[aug_tuple];
102 auto pb_ty = pullback_type(ext->type(), f->dom(2, 0));
103 auto pb_fun = world().mut_lam(pb_ty)->set("extract_pb");
104 DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
105 auto pb_tangent = pb_fun->var(0uz)->set("s");
106 auto tuple_tan = world().insert(world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
107 pb_fun->app(true, tuple_pb, {tuple_tan, pb_fun->var(1) /* ret_var but make sure to select correct one */});
108 pb = pb_fun;
109 }
110
111 auto aug_ext = world().extract(aug_tuple, aug_index);
112 partial_pullback[aug_ext] = pb;
113
114 return aug_ext;
115}
116
117const Def* Eval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
118 // TODO: should use ops instead?
119 auto aug_ops = tup->projs([&](const Def* op) -> const Def* { return augment(op, f, f_diff); });
120 auto aug_tup = world().tuple(aug_ops);
121
122 auto pbs = DefVec(Defs(aug_ops), [&](const Def* op) { return partial_pullback[op]; });
123 DLOG("tuple pbs {}", fe::Join(pbs));
124 // shadow pb = tuple of pbs
125 auto shadow_pb = world().tuple(pbs);
126 shadow_pullback[aug_tup] = shadow_pb;
127
128 // ```
129 // \lambda (s:[E0,...,Em]).
130 // sum (m,A)
131 // ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
132 // ```
133 auto pb_ty = pullback_type(tup->type(), f->dom(2, 0));
134 auto pb = world().mut_lam(pb_ty)->set("tup_pb");
135 DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
136 DLOG("Tuple Pullback: {} : {}", pb, pb->type());
137 DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());
138
139 auto pb_tangent = pb->var(0uz)->set("tup_s");
140
141 auto tangents = DefVec(pbs.size(), [&](nat_t i) {
142 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
143 });
144 pb->app(true, pb->var(1),
145 // summed up tangents
146 op_sum(tangent_type_fun(f->dom(2, 0)), tangents));
147 partial_pullback[aug_tup] = pb;
148
149 return aug_tup;
150}
151
152const Def* Eval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
153 auto arity = pack->arity(); // TODO: arity vs shape
154 auto body = pack->body();
155
156 auto aug_arity = augment_(arity, f, f_diff);
157 auto aug_body = augment(body, f, f_diff);
158
159 auto aug_pack = world().pack(aug_arity, aug_body);
160
161 assert(partial_pullback[aug_body] && "pack pullback should exists");
162 // TODO: or use scale axm
163 auto body_pb = partial_pullback[aug_body];
164 auto pb_pack = world().pack(aug_arity, body_pb);
165 shadow_pullback[aug_pack] = pb_pack;
166
167 DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
168
169 auto pb_type = pullback_type(pack->type(), f->dom(2, 0));
170 auto pb = world().mut_lam(pb_type)->set("pack_pb");
171
172 DLOG("pb of pack: {} : {}", pb, pb_type);
173
174 auto f_arg_ty_diff = tangent_type_fun(f->dom(2, 0));
175 auto app_pb = world().mut_pack(world().arr(aug_arity, f_arg_ty_diff));
176
177 // TODO: special case for const width (special tuple)
178
179 // <i:n, cps2ds body_pb (s#i)>
180 app_pb->set(world().app(direct::op_cps2ds_dep(body_pb), world().extract(pb->var((nat_t)0), app_pb->var())));
181
182 DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
183
184 auto sumup = world().app(world().annex<sum>(), {aug_arity, f_arg_ty_diff});
185 DLOG("sumup: {} : {}", sumup, sumup->type());
186
187 pb->app(true, pb->var(1), world().app(sumup, app_pb));
188
189 partial_pullback[aug_pack] = pb;
190
191 return aug_pack;
192}
193
194const Def* Eval::augment_app(const App* app, Lam* f, Lam* f_diff) {
195 auto callee = app->callee();
196 auto arg = app->arg();
197
198 auto aug_arg = augment(arg, f, f_diff);
199 auto aug_callee = augment(callee, f, f_diff);
200
201 DLOG("augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
202 DLOG("augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
203 // TODO: move down to if(!is_cont(callee))
204 if (!Pi::isa_cn(callee->type()) && Pi::isa_cn(aug_callee->type())) {
205 aug_callee = direct::op_cps2ds_dep(aug_callee);
206 DLOG("wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
207 }
208
209 // nested (inner application)
210 if (app->type()->isa<Pi>()) {
211 DLOG("Nested application callee: {} : {}", aug_callee, aug_callee->type());
212 DLOG("Nested application arg: {} : {}", aug_arg, aug_arg->type());
213 auto aug_app = world().app(aug_callee, aug_arg);
214 DLOG("Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
215 // We do not add a pullback as the pullback is bundled in the cps call or returned by the ds call
216 return aug_app;
217 }
218
219 // continuation (ret, if, ...)
220 if (Pi::isa_basicblock(callee->type())) {
221 // TODO: check if function (not operator)
222 // The original function is an open function (return cont / continuation) of type `Cn[E]`
223 // The augmented function `aug_callee` looks like a function but is not really a function has the type `Cn[E,
224 // Cn[E, Cn[A]]]`
225
226 // ret(e) => ret'(e, e*)
227
228 DLOG("continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
229
230 auto arg_pb = partial_pullback[aug_arg];
231 auto aug_app = world().app(aug_callee, {aug_arg, arg_pb});
232 DLOG("Augmented application: {} : {}", aug_app, aug_app->type());
233 return aug_app;
234 }
235
236 // ds function
237 if (!Pi::isa_cn(callee->type())) {
238 auto aug_app = world().app(aug_callee, aug_arg);
239 DLOG("Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
240
241 DLOG("ds function: {} : {}", aug_app, aug_app->type());
242 // The calle is ds function (e.g. operator (or its partial application))
243 auto [aug_res, fun_pb] = aug_app->projs<2>();
244 // We compose `fun_pb` with `argument_pb` to get the result pb
245 // TODO: combine case with cps function case
246 auto arg_pb = partial_pullback[aug_arg];
247 assert(arg_pb);
248 // `fun_pb: out_tan -> arg_tan`
249 // `arg_pb: arg_tan -> fun_tan`
250 DLOG("function pullback: {} : {}", fun_pb, fun_pb->type());
251 DLOG("argument pullback: {} : {}", arg_pb, arg_pb->type());
252 auto res_pb = compose_cn(arg_pb, fun_pb);
253 DLOG("result pullback: {} : {}", res_pb, res_pb->type());
254 partial_pullback[aug_res] = res_pb;
255 world().debug_dump();
256 return aug_res;
257 }
258
259 // TODO: dest with a function such that f args != g args
260 {
261 // normal function app
262 // ```
263 // g: cn[E, cn X]
264 // g(args,cont)
265 // g': cn[E, cn[X, cn[X, cn E]]]
266 // g'(aug_args, ____)
267 // ```
268 auto g = callee;
269 // At this point g_deriv might still be "autodiff ... g".
270 auto g_deriv = aug_callee;
271 DLOG("g: {} : {}", g, g->type());
272 DLOG("g': {} : {}", g_deriv, g_deriv->type());
273
274 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
275 DLOG("real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
276 DLOG("aug_cont: {} : {}", aug_cont, aug_cont->type());
277 auto e_pb = partial_pullback[real_aug_args];
278 DLOG("e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
279
280 // TODO: better debug names
281 auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
282 DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
283 auto c1_ty = ret_g_deriv_ty->as<Pi>();
284 DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
285 auto c1 = world().mut_lam(c1_ty)->set("c1");
286 auto res = c1->var((nat_t)0);
287 auto r_pb = c1->var(1);
288 c1->app(true, aug_cont, {res, compose_cn(e_pb, r_pb)});
289
290 auto aug_app = world().app(aug_callee, {real_aug_args, c1});
291 DLOG("aug_app: {} : {}", aug_app, aug_app->type());
292
293 // The result is * => no pb needed, no composition needed.
294 return aug_app;
295 }
296 assert(false && "should not be reached");
297}
298
299/// Rewrites the given definition in a lambda environment.
300const Def* Eval::augment_(const Def* def, Lam* f, Lam* f_diff) {
301 // We use macros above to avoid recomputation.
302 // TODO: Alternative:
303 // Use class instances to rewrite inside a function and save such values (f, f_diff, f->dom(2, 0)).
304
305 DLOG("Augment def {} : {}", def, def->type());
306
307 // Applications are continuations, operators, or full functions
308 if (auto app = def->isa<App>()) {
309 auto callee = app->callee();
310 auto arg = app->arg();
311 DLOG("Augment application: app {} with {}", callee, arg);
312 return augment_app(app, f, f_diff);
313 } else if (auto ext = def->isa<Extract>()) {
314 auto tuple = ext->tuple();
315 auto index = ext->index();
316 DLOG("Augment extract: {} #[{}]", tuple, index);
317 return augment_extract(ext, f, f_diff);
318 } else if (auto var = def->isa<Var>()) {
319 DLOG("Augment variable: {}", var);
320 return augment_var(var, f, f_diff);
321 } else if (auto lam = def->isa_mut<Lam>()) {
322 DLOG("Augment mut lambda: {}", lam);
323 return augment_lam(lam, f, f_diff);
324 } else if (auto lam = def->isa<Lam>()) {
325 ELOG("Augment lambda: {}", lam);
326 assert(false && "can not handle non-mutable lambdas");
327 } else if (auto lit = def->isa<Lit>()) {
328 DLOG("Augment literal: {}", def);
329 return augment_lit(lit, f, f_diff);
330 } else if (auto tup = def->isa<Tuple>()) {
331 DLOG("Augment tuple: {}", def);
332 return augment_tuple(tup, f, f_diff);
333 } else if (auto pack = def->isa<Pack>()) {
334 // TODO: handle mut packs (dependencies in the pack) (=> see paper about vectors)
335 auto arity = pack->arity(); // TODO: arity vs shape
336 auto body = pack->body();
337 DLOG("Augment pack: {} : {} with {}", arity, arity->type(), body);
338 return augment_pack(pack, f, f_diff);
339 } else if (auto ax = def->isa<Axm>()) {
340 // TODO: move concrete handling to own function / file / directory (file per plugin)
341 DLOG("Augment axm: {} : {}", ax, ax->type());
342 DLOG("axm curry: {}", ax->curry());
343 DLOG("axm flags: {}", ax->flags());
344 auto diff_name = ax->sym().str();
345 find_and_replace(diff_name, ".", "_");
346 find_and_replace(diff_name, "%", "");
347 diff_name = "internal_diff_" + diff_name;
348 DLOG("axm name: {}", ax->sym());
349 DLOG("axm function name: {}", diff_name);
350
351 auto diff_fun = world().externals()[world().sym(diff_name)];
352 if (!diff_fun) {
353 ELOG("derivation not found: {}", diff_name);
354 auto expected_type = autodiff_type_fun(ax->type());
355 ELOG("expected: {} : {}", diff_name, expected_type);
356 assert(false && "unhandled axm");
357 }
358 // TODO: why does this cause a depth error?
359 return diff_fun;
360 }
361
362 // TODO: handle Pi for axm app
363 // TODO: remaining (lambda, axm)
364
365 ELOG("did not expect to augment: {} : {}", def, def->type());
366 ELOG("node: {}", def->node_name());
367 assert(false && "augment not implemented on this def");
368 fe::unreachable();
369}
370
371} // namespace mim::plug::autodiff
const Def * callee() const
Definition lam.h:276
const Def * arg() const
Definition lam.h:285
Definition axm.h:9
Base class for all Defs.
Definition def.h:246
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
std::string_view node_name() const
Definition def.cpp:465
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:493
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:425
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:386
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
Sym sym() const
Definition def.h:515
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
const Def * tuple() const
Definition tuple.h:216
const Def * index() const
Definition tuple.h:217
A function.
Definition lam.h:110
const Def * filter() const
Definition lam.h:122
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
const Pi * type() const
Definition lam.h:130
static const Lam * isa_basicblock(const Def *d)
Definition lam.h:142
const Def * body() const
Definition lam.h:123
A (possibly parameterized) Tuple.
Definition tuple.h:166
const Def * arity() const final
Definition tuple.cpp:45
Pack * set(const Def *body)
Definition tuple.h:183
size_t index() const
Definition pass.h:112
A dependent function type.
Definition lam.h:14
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:47
static const Pi * isa_basicblock(const Def *d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:51
const Def * dom() const
Definition lam.h:35
const Def * body() const
Definition tuple.h:91
World & world()
Definition pass.h:77
flags_t annex() const
Definition pass.h:81
Data constructor for a Sigma.
Definition tuple.h:68
A variable introduced by a binder (mutable).
Definition def.h:714
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:452
const Def * pack(const Def *arity, const Def *body)
Definition world.h:475
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:205
const Def * tuple(Defs ops)
Definition world.cpp:297
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:590
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:362
Pack * mut_pack(const Def *type)
Definition world.h:473
Sym sym(std::string_view)
Definition world.cpp:105
const Def * call(const Def *callee, T &&arg, Args &&... args)
Definition world.h:659
const Externals & externals() const
Definition world.h:282
Lam * mut_con(const Def *dom)
Definition world.h:418
Lam * mut_lam(const Pi *pi)
Definition world.h:406
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Definition eval.cpp:10
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
#define ELOG(...)
Definition log.h:88
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:94
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:167
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:111
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:55
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:42
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:60
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:16
The tuple Plugin
View< const Def * > Defs
Definition def.h:78
u64 nat_t
Definition types.h:37
Vector< const Def * > DefVec
Definition def.h:79
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:73
const Def * compose_cn(const Def *f, const Def *g)
The high level view is:
Definition lam.cpp:62