6using namespace std::literals;
12 partial_pullback[lit] = pb;
17 assert(augmented.count(var));
18 auto aug_var = augmented[var];
19 assert(partial_pullback.count(aug_var));
27 if (augmented.count(lam)) {
32 DLOG(
"already augmented {} : {} to {} : {}", lam, lam->
type(), augmented[lam], augmented[lam]->
type());
33 return augmented[lam];
46 DLOG(
"found an open continuation {} : {}", lam, lam->
type());
47 auto cont_dom = lam->
type()->
dom();
50 DLOG(
"augmented domain {}", aug_dom);
51 DLOG(
"pb type is {}", pb_ty);
52 auto aug_lam =
world().
mut_con({aug_dom, pb_ty})->set(
"aug_"s + lam->
sym().str());
53 auto aug_var = aug_lam->var((
nat_t)0);
54 augmented[lam->
var()] = aug_var;
55 augmented[lam] = aug_lam;
56 derived[lam] = aug_lam;
57 auto pb = aug_lam->var(1);
58 partial_pullback[aug_var] = pb;
62 aug_lam->set(lam->
filter(), new_body);
65 partial_pullback[aug_lam] = lam_pb;
66 DLOG(
"augmented {} : {}", lam, lam->
type());
67 DLOG(
"to {} : {}", aug_lam, aug_lam->type());
68 DLOG(
"ppb for lam cont: {}", lam_pb);
72 DLOG(
"found a closed function call {} : {}", lam, lam->
type());
76 DLOG(
"augmented function is {} : {}", aug_lam, aug_lam->type());
89 DLOG(
"aug tuple: {} : {}", aug_tuple, aug_tuple->type());
90 if (shadow_pullback.count(aug_tuple)) {
91 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
92 DLOG(
"Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
100 assert(partial_pullback.count(aug_tuple));
101 auto tuple_pb = partial_pullback[aug_tuple];
104 DLOG(
"Pullback: {} : {}", pb_fun, pb_fun->type());
105 auto pb_tangent = pb_fun->var(0uz)->set(
"s");
106 auto tuple_tan =
world().
insert(
world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->
set(
"tup_s");
107 pb_fun->app(
true, tuple_pb, {tuple_tan, pb_fun->var(1) });
112 partial_pullback[aug_ext] = pb;
119 auto aug_ops = tup->
projs([&](
const Def* op) ->
const Def* {
return augment(op, f, f_diff); });
122 auto pbs =
DefVec(
Defs(aug_ops), [&](
const Def* op) {
return partial_pullback[op]; });
123 DLOG(
"tuple pbs {}", fe::Join(pbs));
126 shadow_pullback[aug_tup] = shadow_pb;
135 DLOG(
"Augmented tuple: {} : {}", aug_tup, aug_tup->type());
136 DLOG(
"Tuple Pullback: {} : {}", pb, pb->type());
137 DLOG(
"shadow pb: {} : {}", shadow_pb, shadow_pb->type());
139 auto pb_tangent = pb->var(0uz)->set(
"tup_s");
142 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
144 pb->app(
true, pb->var(1),
147 partial_pullback[aug_tup] = pb;
153 auto arity = pack->
arity();
154 auto body = pack->
body();
156 auto aug_arity =
augment_(arity, f, f_diff);
157 auto aug_body =
augment(body, f, f_diff);
159 auto aug_pack =
world().
pack(aug_arity, aug_body);
161 assert(partial_pullback[aug_body] &&
"pack pullback should exists");
163 auto body_pb = partial_pullback[aug_body];
164 auto pb_pack =
world().
pack(aug_arity, body_pb);
165 shadow_pullback[aug_pack] = pb_pack;
167 DLOG(
"shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
172 DLOG(
"pb of pack: {} : {}", pb, pb_type);
182 DLOG(
"app pb of pack: {} : {}", app_pb, app_pb->type());
185 DLOG(
"sumup: {} : {}", sumup, sumup->type());
187 pb->app(
true, pb->var(1),
world().app(sumup, app_pb));
189 partial_pullback[aug_pack] = pb;
195 auto callee = app->
callee();
196 auto arg = app->
arg();
198 auto aug_arg =
augment(arg, f, f_diff);
199 auto aug_callee =
augment(callee, f, f_diff);
201 DLOG(
"augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
202 DLOG(
"augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
206 DLOG(
"wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
210 if (app->
type()->isa<
Pi>()) {
211 DLOG(
"Nested application callee: {} : {}", aug_callee, aug_callee->type());
212 DLOG(
"Nested application arg: {} : {}", aug_arg, aug_arg->type());
213 auto aug_app =
world().
app(aug_callee, aug_arg);
214 DLOG(
"Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
228 DLOG(
"continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
230 auto arg_pb = partial_pullback[aug_arg];
231 auto aug_app =
world().
app(aug_callee, {aug_arg, arg_pb});
232 DLOG(
"Augmented application: {} : {}", aug_app, aug_app->type());
238 auto aug_app =
world().
app(aug_callee, aug_arg);
239 DLOG(
"Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
241 DLOG(
"ds function: {} : {}", aug_app, aug_app->type());
243 auto [aug_res, fun_pb] = aug_app->projs<2>();
246 auto arg_pb = partial_pullback[aug_arg];
250 DLOG(
"function pullback: {} : {}", fun_pb, fun_pb->type());
251 DLOG(
"argument pullback: {} : {}", arg_pb, arg_pb->type());
253 DLOG(
"result pullback: {} : {}", res_pb, res_pb->type());
254 partial_pullback[aug_res] = res_pb;
270 auto g_deriv = aug_callee;
271 DLOG(
"g: {} : {}", g, g->type());
272 DLOG(
"g': {} : {}", g_deriv, g_deriv->type());
274 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
275 DLOG(
"real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
276 DLOG(
"aug_cont: {} : {}", aug_cont, aug_cont->type());
277 auto e_pb = partial_pullback[real_aug_args];
278 DLOG(
"e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
281 auto ret_g_deriv_ty = g_deriv->type()->as<
Pi>()->dom(1);
282 DLOG(
"ret_g_deriv_ty: {} ", ret_g_deriv_ty);
283 auto c1_ty = ret_g_deriv_ty->as<
Pi>();
284 DLOG(
"c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
287 auto r_pb = c1->
var(1);
288 c1->app(
true, aug_cont, {res,
compose_cn(e_pb, r_pb)});
290 auto aug_app =
world().
app(aug_callee, {real_aug_args, c1});
291 DLOG(
"aug_app: {} : {}", aug_app, aug_app->type());
296 assert(
false &&
"should not be reached");
305 DLOG(
"Augment def {} : {}", def, def->
type());
308 if (
auto app = def->isa<
App>()) {
309 auto callee = app->callee();
310 auto arg = app->arg();
311 DLOG(
"Augment application: app {} with {}", callee, arg);
313 }
else if (
auto ext = def->isa<
Extract>()) {
314 auto tuple = ext->tuple();
315 auto index = ext->index();
318 }
else if (
auto var = def->isa<
Var>()) {
319 DLOG(
"Augment variable: {}", var);
322 DLOG(
"Augment mut lambda: {}", lam);
324 }
else if (
auto lam = def->isa<
Lam>()) {
325 ELOG(
"Augment lambda: {}", lam);
326 assert(
false &&
"can not handle non-mutable lambdas");
327 }
else if (
auto lit = def->isa<
Lit>()) {
328 DLOG(
"Augment literal: {}", def);
330 }
else if (
auto tup = def->isa<
Tuple>()) {
331 DLOG(
"Augment tuple: {}", def);
333 }
else if (
auto pack = def->isa<
Pack>()) {
335 auto arity = pack->arity();
336 auto body = pack->body();
337 DLOG(
"Augment pack: {} : {} with {}", arity, arity->type(), body);
339 }
else if (
auto ax = def->isa<
Axm>()) {
341 DLOG(
"Augment axm: {} : {}", ax, ax->type());
342 DLOG(
"axm curry: {}", ax->curry());
343 DLOG(
"axm flags: {}", ax->flags());
344 auto diff_name = ax->sym().str();
347 diff_name =
"internal_diff_" + diff_name;
348 DLOG(
"axm name: {}", ax->sym());
349 DLOG(
"axm function name: {}", diff_name);
353 ELOG(
"derivation not found: {}", diff_name);
355 ELOG(
"expected: {} : {}", diff_name, expected_type);
356 assert(
false &&
"unhandled axm");
365 ELOG(
"did not expect to augment: {} : {}", def, def->
type());
367 assert(
false &&
"augment not implemented on this def");
const Def * callee() const
Def * set(size_t i, const Def *)
Successively set from left to right.
std::string_view node_name() const
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
const Def * var(nat_t a, nat_t i) noexcept
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
const Def * filter() const
Lam * set(Filter filter, const Def *body)
static const Lam * isa_basicblock(const Def *d)
A (possibly paramterized) Tuple.
const Def * arity() const final
Pack * set(const Def *body)
A dependent function type.
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
static const Pi * isa_basicblock(const Def *d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Data constructor for a Sigma.
A variable introduced by a binder (mutable).
const Def * insert(const Def *d, const Def *i, const Def *val)
const Def * pack(const Def *arity, const Def *body)
const Def * app(const Def *callee, const Def *arg)
const Def * tuple(Defs ops)
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
const Def * extract(const Def *d, const Def *i)
Pack * mut_pack(const Def *type)
Sym sym(std::string_view)
const Def * call(const Def *callee, T &&arg, Args &&... args)
const Externals & externals() const
Lam * mut_con(const Def *dom)
Lam * mut_lam(const Pi *pi)
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
#define DLOG(...)
Vaporizes to nothingness in Debug build.
The automatic differentiation Plugin
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
const Def * op_cps2ds_dep(const Def *k)
Vector< const Def * > DefVec
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
const Def * compose_cn(const Def *f, const Def *g)
The high level view is: