This pass is the heart of automatic differentiation (AD).
#include <mim/plug/autodiff/pass/eval.h>
Public Member Functions
  Eval (World &world, flags_t annex)
  const Def * rewrite (const Def *) override
      Detects autodiff calls.
  const Def * derive (const Def *)
      Acts on toplevel autodiff on closed terms.
  const Def * derive_ (const Def *)
      In addition to the derivation, the pullback is registered and the maps are initialized.
  const Def * augment (const Def *, Lam *, Lam *)
      Applies to (open) expressions in a functional context.
  const Def * augment_ (const Def *, Lam *, Lam *)
      Rewrites the given definition in a lambda environment.
  const Def * augment_var (const Var *, Lam *, Lam *)
      Helper functions for augment.
  const Def * augment_lam (Lam *, Lam *, Lam *)
  const Def * augment_extract (const Extract *, Lam *, Lam *)
  const Def * augment_app (const App *, Lam *, Lam *)
  const Def * augment_lit (const Lit *, Lam *, Lam *)
  const Def * augment_tuple (const Tuple *, Lam *, Lam *)
  const Def * augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
Public Member Functions inherited from mim::RWPass< Eval, Lam >
  RWPass (World &world, std::string name)
  bool inspect () const override
      Should the PassMan even consider this pass?
  Lam * curr_mut () const
Public Member Functions inherited from mim::Pass
  Pass (World &world, std::string name)
  Pass (World &world, flags_t annex)
  virtual void init (PassMan *)
  PassMan & man ()
  const PassMan & man () const
  size_t index () const
  virtual const Def * rewrite (const Var *var)
  virtual const Def * rewrite (const Proxy *proxy)
  virtual undo_t analyze (const Def *)
  virtual undo_t analyze (const Var *)
  virtual undo_t analyze (const Proxy *)
  virtual bool fixed_point () const
  virtual void enter ()
      Invoked just before Pass::rewrite processes PassMan::curr_mut's body.
  virtual void prepare ()
      Invoked once before entering the main rewrite loop.
  const Proxy * proxy (const Def *type, Defs ops, u32 tag=0)
  const Proxy * isa_proxy (const Def *def, u32 tag=0)
      Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
  const Proxy * as_proxy (const Def *def, u32 tag=0)
Public Member Functions inherited from mim::Stage
  Stage (World &world, std::string name)
  Stage (World &world, flags_t annex)
  virtual ~Stage ()=default
  virtual std::unique_ptr< Stage > recreate ()
      Creates a new instance; needed by a fixed-point PhaseMan.
  virtual void apply (const App *)
      Invoked if your Stage has additional args.
  virtual void apply (Stage &)
      Ditto, but invoked by Stage::recreate.
  virtual bool redirects () const
      If true, Stage::create uses take_resolved().
  virtual std::unique_ptr< Stage > take_resolved ()
      The Stage to use instead; nullptr means elide.
  World & world ()
  Driver & driver ()
  Log & log () const
  std::string_view name () const
  flags_t annex () const
Additional Inherited Members
Static Public Member Functions inherited from mim::Stage
  static std::unique_ptr< Stage > create (const Flags2Stages &stages, const Def *def)
  template<class A, class P>
  static void hook (Flags2Stages &stages)
Protected Attributes inherited from mim::Stage
  std::string name_
Definition at line 12 of file eval.h.
mim::plug::autodiff::Eval::Eval (World &world, flags_t annex)
References mim::Stage::annex(), mim::RWPass< Eval, Lam >::RWPass(), and mim::Stage::world().
const Def * mim::plug::autodiff::Eval::augment (const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; apart from that, it only adds pullbacks. To do so, some calls (e.g., to axms) are replaced by their derivatives. The transformation can thus be seen as augmenting the term with a dual computation that generates the derivatives.
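The augmentation described above can be illustrated with a minimal C++ sketch. It assumes a toy scalar model rather than MimIR's Def/Lam API: an expression over a single input x is represented by the hypothetical struct Aug, which pairs the unchanged primal value with a pullback from the expression's cotangent to the cotangent of x. The helpers var, lit, add and mul are likewise hypothetical.

```cpp
#include <cassert>
#include <functional>

// Hypothetical toy model (not MimIR's API): an augmented scalar expression
// carries its primal value plus a pullback mapping the expression's
// cotangent back to the cotangent of the single input variable x.
struct Aug {
    double val;
    std::function<double(double)> pb;
};

// The input variable itself: its pullback is the identity.
Aug var(double x) { return {x, [](double s) { return s; }}; }

// A literal does not depend on x, so its pullback is zero.
Aug lit(double c) { return {c, [](double) { return 0.0; }}; }

// Sum: the cotangent flows unchanged into both operands and is accumulated.
Aug add(Aug a, Aug b) {
    return {a.val + b.val, [a, b](double s) { return a.pb(s) + b.pb(s); }};
}

// Product rule: each operand receives the cotangent scaled by the other value.
Aug mul(Aug a, Aug b) {
    return {a.val * b.val, [a, b](double s) { return a.pb(s * b.val) + b.pb(s * a.val); }};
}
```

For y = x * x + 3 at x = 2, y.val is still 7 while y.pb(1.0) yields dy/dx = 4: the primal computation is left as is and only pullbacks are added.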
Definition at line 10 of file eval.cpp.
References augment_().
Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().
const Def * mim::plug::autodiff::Eval::augment_ (const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Definition at line 300 of file autodiff_rewrite_inner.cpp.
References augment_app(), augment_extract(), augment_lam(), augment_lit(), augment_pack(), augment_tuple(), augment_var(), mim::plug::autodiff::autodiff_type_fun(), DLOG, ELOG, mim::World::externals(), mim::find_and_replace(), mim::Pass::index(), mim::Def::isa_mut(), mim::Def::node_name(), mim::World::sym(), mim::Def::type(), and mim::Stage::world().
Referenced by augment(), and augment_pack().
const Def * mim::plug::autodiff::Eval::augment_app (const App *, Lam *, Lam *)
Definition at line 194 of file autodiff_rewrite_inner.cpp.
References mim::World::app(), mim::App::arg(), augment(), mim::App::callee(), mim::compose_cn(), mim::World::debug_dump(), DLOG, mim::Pi::isa_basicblock(), mim::Pi::isa_cn(), mim::World::mut_lam(), mim::plug::direct::op_cps2ds_dep(), mim::Lam::set(), mim::Def::type(), mim::Def::var(), and mim::Stage::world().
Referenced by augment_().
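augment_app references mim::compose_cn, which is consistent with the chain rule: the pullback of an application runs the callee's local pullback and then the argument's pullback. The following is a minimal sketch in a toy scalar model, not MimIR's CPS implementation; Aug, Pullback, compose and app_sin are hypothetical names, and sin stands in for any callee with a known derivative.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Hypothetical toy model (not MimIR's API).
using Pullback = std::function<double(double)>;

struct Aug {
    double val;
    Pullback pb;  // cotangent of this value -> cotangent of the input
};

// Composes two pullbacks: apply `outer` first, then feed its result into `inner`.
Pullback compose(Pullback inner, Pullback outer) {
    return [inner, outer](double s) { return inner(outer(s)); };
}

// Augmented application of sin to an augmented argument: the local pullback
// of sin at a is s -> s * cos(a), chained with the argument's pullback.
Aug app_sin(const Aug& arg) {
    double a = arg.val;
    Pullback local = [a](double s) { return s * std::cos(a); };
    return {std::sin(a), compose(arg.pb, local)};
}
```

Nesting app_sin twice around the identity-seeded input x = 0 gives the value 0 and a derivative of cos(0) * cos(0) = 1.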
const Def * mim::plug::autodiff::Eval::augment_extract (const Extract *ext, Lam *f, Lam *f_diff)
Definition at line 80 of file autodiff_rewrite_inner.cpp.
References augment(), DLOG, mim::World::extract(), mim::Extract::index(), mim::Pass::index(), mim::World::insert(), mim::World::mut_lam(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::Extract::tuple(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
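augment_extract references mim::World::insert. In reverse mode, the pullback of a projection t#i commonly places the incoming cotangent at index i of an otherwise zero tuple and hands that tuple to the tuple's own pullback. The sketch below assumes this reading and uses a toy scalar model; AugTuple, Aug and aug_extract are hypothetical, not MimIR's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy model (not MimIR's API): an augmented tuple of scalars whose
// pullback maps a cotangent tuple back to the cotangent of one input x.
struct AugTuple {
    std::vector<double> val;
    std::function<double(const std::vector<double>&)> pb;
};

struct Aug {
    double val;
    std::function<double(double)> pb;
};

// Augmented extract t#i: the value is the i-th component; the pullback inserts
// the incoming cotangent at position i of a zero tuple and calls the tuple's pullback.
Aug aug_extract(const AugTuple& t, std::size_t i) {
    std::size_t n = t.val.size();
    auto tuple_pb = t.pb;
    return {t.val.at(i), [tuple_pb, n, i](double s) {
                std::vector<double> ct(n, 0.0);
                ct[i] = s;
                return tuple_pb(ct);
            }};
}
```

For the tuple (x, 3x) at x = 2, extracting index 1 yields the value 6 and a derivative of 3.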
const Def * mim::plug::autodiff::Eval::augment_lam (Lam *, Lam *, Lam *)
Definition at line 23 of file autodiff_rewrite_inner.cpp.
References augment(), mim::plug::autodiff::autodiff_type_fun(), mim::Lam::body(), mim::World::call(), DLOG, mim::Pi::dom(), mim::Lam::filter(), mim::Lam::isa_basicblock(), mim::World::mut_con(), mim::plug::autodiff::pullback_type(), mim::Def::sym(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::Stage::world(), and mim::plug::autodiff::zero_pullback().
Referenced by augment_().
const Def * mim::plug::autodiff::Eval::augment_lit (const Lit *, Lam *, Lam *)
Definition at line 10 of file autodiff_rewrite_inner.cpp.
References mim::plug::autodiff::zero_pullback().
Referenced by augment_().
const Def * mim::plug::autodiff::Eval::augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
Definition at line 152 of file autodiff_rewrite_inner.cpp.
References mim::Stage::annex(), mim::World::app(), mim::Pack::arity(), augment(), augment_(), mim::Seq::body(), DLOG, mim::World::mut_lam(), mim::World::mut_pack(), mim::plug::direct::op_cps2ds_dep(), mim::World::pack(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::Pack::set(), mim::plug::autodiff::tangent_type_fun(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
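augment_pack references mim::Pack::arity and mim::Seq::body. Because a pack repeats a single body arity times, one way to pull a cotangent back through it is to apply the body's pullback to every cotangent component and sum the results. The sketch below assumes that approach in a toy scalar model; Aug, AugTuple and aug_pack are hypothetical, and MimIR's actual rewrite builds its pullback differently (e.g., via mim::World::mut_pack).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy model (not MimIR's API).
struct Aug {
    double val;
    std::function<double(double)> pb;
};

struct AugTuple {
    std::vector<double> val;
    std::function<double(const std::vector<double>&)> pb;
};

// Augmented pack <arity; body>: every component shares the body's value,
// and each component's cotangent is pulled back through the same body pullback.
AugTuple aug_pack(std::size_t arity, const Aug& body) {
    return {std::vector<double>(arity, body.val), [arity, body](const std::vector<double>& ct) {
                double sum = 0.0;
                for (std::size_t i = 0; i < arity; ++i) sum += body.pb(ct.at(i));
                return sum;
            }};
}
```

With body x * x at x = 3 (pullback s -> 6s) and arity 3, a cotangent of all ones pulls back to 18.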
const Def * mim::plug::autodiff::Eval::augment_tuple (const Tuple *, Lam *, Lam *)
Definition at line 117 of file autodiff_rewrite_inner.cpp.
References augment(), DLOG, mim::World::mut_lam(), mim::plug::autodiff::op_sum(), mim::Def::projs(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::plug::autodiff::tangent_type_fun(), mim::World::tuple(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
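augment_tuple references mim::Def::projs and mim::plug::autodiff::op_sum. A standard way to pull a cotangent back through a tuple of expressions that share an input is to project the cotangent, apply each component's pullback to its projection, and sum the results. A minimal sketch in a toy scalar model follows; Aug, AugTuple and aug_tuple are hypothetical, not MimIR's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy model (not MimIR's API).
struct Aug {
    double val;
    std::function<double(double)> pb;
};

struct AugTuple {
    std::vector<double> val;
    std::function<double(const std::vector<double>&)> pb;
};

// Augmented tuple (e_0, ..., e_{n-1}): collect the primal values; the pullback
// feeds component i of the cotangent to e_i's pullback and sums the contributions.
AugTuple aug_tuple(const std::vector<Aug>& elems) {
    std::vector<double> vals;
    vals.reserve(elems.size());
    for (const auto& e : elems) vals.push_back(e.val);
    return {vals, [elems](const std::vector<double>& ct) {
                double sum = 0.0;
                for (std::size_t i = 0; i < elems.size(); ++i) sum += elems[i].pb(ct.at(i));
                return sum;
            }};
}
```

For the tuple (x, x * x) at x = 2, the cotangent (1, 1) pulls back to 1 + 4 = 5.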
const Def * mim::plug::autodiff::Eval::augment_var (const Var *, Lam *, Lam *)
Helper functions for augment.
Definition at line 16 of file autodiff_rewrite_inner.cpp.
Referenced by augment_().
const Def * mim::plug::autodiff::Eval::derive_ (const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
Definition at line 7 of file autodiff_rewrite_toplevel.cpp.
References mim::Def::as_mut(), augment(), mim::plug::autodiff::autodiff_type_fun_pi(), DLOG, mim::plug::autodiff::id_pullback(), mim::World::mut_lam(), mim::Def::set(), mim::Lam::set(), mim::World::tuple(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::Stage::world(), and mim::plug::autodiff::zero_pullback().
Referenced by derive().
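derive_ references mim::plug::autodiff::id_pullback and mim::World::tuple, which fits the usual toplevel step of reverse-mode AD: seed the function's argument with the identity pullback, run the augmented body, and return the primal result paired with its pullback. The sketch below assumes that reading in a toy scalar model; Aug, AugFn and derive are hypothetical names and do not mirror Eval::derive's actual signature.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical toy model (not MimIR's API).
struct Aug {
    double val;
    std::function<double(double)> pb;
};

using AugFn = std::function<Aug(Aug)>;

// Toplevel derivation: the derived function seeds its argument with the
// identity pullback, runs the augmented body, and returns (value, pullback).
std::function<std::pair<double, std::function<double(double)>>(double)> derive(AugFn body) {
    return [body](double x) {
        Aug arg{x, [](double s) { return s; }};  // identity pullback for the input
        Aug res = body(arg);
        return std::make_pair(res.val, res.pb);
    };
}
```

Deriving an augmented cube function and calling the result at x = 2 returns the pair (8, pb) with pb(1.0) equal to 12.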