8std::optional<u64> dispatch_float_width(
nat_t width,
F&&
f) {
11 case i: return f.template operator()<i>();
19std::optional<u64> dispatch_int_width(
nat_t width,
F&&
f) {
22 case i: return f.template operator()<i>();
29template<nat_t w,
class F>
30std::optional<u64> fold_float_unary_bits(
u64 a,
F&&
f) {
36template<nat_t w,
class F>
37std::optional<u64> fold_float_binary_bits(
u64 a,
u64 b,
F&&
f) {
45constexpr long double signed_min() {
49 return static_cast<long double>(std::numeric_limits<w2s<w>>::min());
53constexpr long double signed_max() {
57 return static_cast<long double>(std::numeric_limits<w2s<w>>::max());
61constexpr long double unsigned_max() {
62 return static_cast<long double>(std::numeric_limits<w2u<w>>::max());
66std::optional<u64> encode_signed(
long double x) {
67 if constexpr (
w == 1) {
68 if (x == -1.0L)
return 1_u64;
69 if (x == 0.0L)
return 0_u64;
77std::optional<u64> encode_unsigned(
long double x) {
82long double decode_signed(
u64 a) {
90long double decode_unsigned(
u64 a) {
94template<nat_t w,
class F>
95std::optional<u64> fold_float_to_signed_bits(
F x) {
96 if (!std::isfinite(x))
return {};
98 auto truncated = std::trunc(
static_cast<long double>(x));
99 if (truncated < signed_min<w>() || truncated > signed_max<w>())
return {};
100 return encode_signed<w>(truncated);
103template<nat_t w,
class F>
104std::optional<u64> fold_float_to_unsigned_bits(
F x) {
105 if (!std::isfinite(x))
return {};
107 auto truncated = std::trunc(
static_cast<long double>(x));
108 if (truncated < 0.0L || truncated > unsigned_max<w>())
return {};
109 return encode_unsigned<w>(truncated);
113template<
class Id, Id
id, nat_t w>
114std::optional<u64> fold_unary_lit(
u64 a) {
115 if constexpr (std::is_same_v<Id, tri>) {
116 if constexpr (
false) {}
117 else if constexpr (
id ==
tri:: sin )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: sin (x); });
118 else if constexpr (
id ==
tri:: cos )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: cos (x); });
119 else if constexpr (
id ==
tri:: tan )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: tan (x); });
120 else if constexpr (
id ==
tri:: sinh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: sinh(x); });
121 else if constexpr (
id ==
tri:: cosh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: cosh(x); });
122 else if constexpr (
id ==
tri:: tanh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std:: tanh(x); });
123 else if constexpr (
id ==
tri::asin )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::asin (x); });
124 else if constexpr (
id ==
tri::acos )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::acos (x); });
125 else if constexpr (
id ==
tri::atan )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::atan (x); });
126 else if constexpr (
id ==
tri::asinh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::asinh(x); });
127 else if constexpr (
id ==
tri::acosh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::acosh(x); });
128 else if constexpr (
id ==
tri::atanh)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::atanh(x); });
129 else fe::unreachable();
130 }
else if constexpr (std::is_same_v<Id, rt>) {
131 if constexpr (
false) {}
132 else if constexpr (
id ==
rt::sq)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::sqrt(x); });
133 else if constexpr (
id ==
rt::cb)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::cbrt(x); });
134 else static_assert(
false,
"missing sub tag");
135 }
else if constexpr (std::is_same_v<Id, exp>) {
136 if constexpr (
false) {}
137 else if constexpr (
id ==
exp::exp)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::exp(x); });
138 else if constexpr (
id ==
exp::exp2)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::exp2(x); });
139 else if constexpr (
id ==
exp::exp10)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::pow(
decltype(x)(10), x); });
140 else if constexpr (
id ==
exp::log)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::log(x); });
141 else if constexpr (
id ==
exp::log2)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::log2(x); });
142 else if constexpr (
id ==
exp::log10)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::log10(x); });
143 else fe::unreachable();
144 }
else if constexpr (std::is_same_v<Id, er>) {
145 if constexpr (
false) {}
146 else if constexpr (
id ==
er::f )
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::erf (x); });
147 else if constexpr (
id ==
er::fc)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::erfc(x); });
148 else static_assert(
false,
"missing sub tag");
149 }
else if constexpr (std::is_same_v<Id, gamma>) {
150 if constexpr (
false) {}
151 else if constexpr (
id ==
gamma::t)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::tgamma(x); });
152 else if constexpr (
id ==
gamma::l)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::lgamma(x); });
153 else static_assert(
false,
"missing sub tag");
154 }
else if constexpr (std::is_same_v<Id, round>) {
155 if constexpr (
false) {}
156 else if constexpr (
id ==
round::f)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::floor(x); });
157 else if constexpr (
id ==
round::c)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::ceil(x); });
158 else if constexpr (
id ==
round::r)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::round(x); });
159 else if constexpr (
id ==
round::t)
return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::trunc(x); });
160 else static_assert(
false,
"missing sub tag");
162 static_assert(
false,
"missing tag");
167template<
class Id, nat_t w>
168std::optional<u64> fold_unary_lit(
u64 a) {
169 if constexpr (std::is_same_v<Id, abs>)
170 return fold_float_unary_bits<w>(
a, [](
auto x) {
return std::abs(x); });
172 static_assert(
false,
"missing tag");
176const Def*
fold(World& world,
const Def* type,
const Def*
a) {
177 if (
a->isa<
Bot>())
return world.bot(type);
180 if (
auto width =
isa_f(
a->type()))
181 if (
auto res = dispatch_float_width(*width, [&]<
nat_t w>() {
return fold_unary_lit<Id, w>(*la); }))
182 return world.lit(type, *res);
187template<
class Id, Id
id, nat_t w>
188std::optional<u64> fold_binary_lit(
u64 a,
u64 b) {
193 if constexpr (std::is_same_v<Id, arith>) {
195 if constexpr (
false) {}
201 else static_assert(
false,
"missing sub tag");
203 }
else if constexpr (std::is_same_v<Id, math::extrema>) {
206 if (x == T(-0.0) && y == T(+0.0)) {
208 }
else if (x == T(+0.0) && y == T(-0.0)) {
215 else if (std::isnan(y))
220 static_assert(
false,
"missing sub tag");
224 }
else if constexpr (std::is_same_v<Id, pow>) {
226 }
else if constexpr (std::is_same_v<Id, cmp>) {
227 using std::isunordered;
229 result |= ((
id &
cmp::u) !=
cmp::f) && isunordered(x, y);
235 static_assert(
false,
"missing tag");
239template<
class Id, Id
id>
240const Def*
fold(World& world,
const Def* type,
const Def*
a) {
241 if (
a->isa<
Bot>())
return world.bot(type);
244 if (
auto width =
isa_f(
a->type()))
245 if (
auto res = dispatch_float_width(*width, [&]<
nat_t w>() {
return fold_unary_lit<Id, id, w>(*la); }))
246 return world.lit(type, *res);
252template<
class Id, Id
id>
253const Def*
fold(World& world,
const Def* type,
const Def*&
a,
const Def*& b) {
254 if (
a->isa<
Bot>() || b->isa<
Bot>())
return world.bot(type);
258 if (
auto width =
isa_f(
a->type()))
260 = dispatch_float_width(*width, [&]<
nat_t w>() {
return fold_binary_lit<Id, id, w>(*la, *lb); }))
261 return world.lit(type, *res);
279const Def* reassociate(Id
id, World& world, [[maybe_unused]]
const App* ab,
const Def*
a,
const Def* b) {
284 auto la =
a->isa<
Lit>();
285 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{
nullptr,
nullptr};
286 auto [z,
w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{
nullptr,
nullptr};
292 auto check_mode = [&](
const App* app) {
294 if (!app_m || !fe::has_flag(
static_cast<Mode>(*app_m),
Mode::reassoc))
return false;
299 if (!check_mode(ab))
return nullptr;
300 if (lx && !check_mode(xy->decurry()))
return nullptr;
301 if (lz && !check_mode(zw->decurry()))
return nullptr;
303 auto make_op = [&](
const Def*
a,
const Def* b) {
return world.call(
id,
mode,
Defs{
a, b}); };
305 if (la && lz)
return make_op(make_op(
a, z), w);
306 if (lx && lz)
return make_op(make_op(x, z), make_op(y, w));
307 if (lz)
return make_op(z, make_op(
a, w));
308 if (lx)
return make_op(x, make_op(y, b));
312template<conv
id, nat_t sw, nat_t dw>
313std::optional<u64> fold_conv_lit(
u64 a) {
317 if constexpr (
false) {}
323 else static_assert(
false,
"missing sub tag");
327template<conv
id, nat_t sw>
328std::optional<u64> fold_conv_dst(
nat_t dw,
u64 a) {
330 return dispatch_float_width(dw, [&]<
nat_t d>() {
return fold_conv_lit<id, sw, d>(
a); });
332 return dispatch_int_width(dw, [&]<
nat_t d>() {
return fold_conv_lit<id, sw, d>(
a); });
338 return dispatch_int_width(sw, [&]<
nat_t s>() {
return fold_conv_dst<id, s>(dw,
a); });
340 return dispatch_float_width(sw, [&]<
nat_t s>() {
return fold_conv_dst<id, s>(dw,
a); });
347 auto& world = type->world();
348 auto callee =
c->as<
App>();
349 auto [
a, b] = arg->
projs<2>();
350 auto mode = callee->arg();
352 auto w =
isa_f(
a->type());
354 if (
auto result = fold<arith, id>(world, type,
a, b))
return result;
359 auto zero =
lit_f(world, *w, 0.0);
360 auto one =
lit_f(world, *w, 1.0);
361 auto two =
lit_f(world, *w, 2.0);
363 if (
auto la =
a->isa<
Lit>()) {
364 if (zero && la == zero) {
374 if (one && la == one) {
385 if (
auto lb = b->isa<
Lit>()) {
386 if (zero && lb == zero) {
391 default: fe::unreachable();
400 case arith::sub:
if (zero)
return zero;
break;
402 case arith::div:
if (one )
return one ;
break;
409 if (
auto res = reassociate<arith>(
id, world, callee,
a, b))
return res;
411 return world.raw_app(type, callee, {
a, b});
416 auto& world = type->world();
417 auto callee =
c->as<
App>();
418 auto [
a, b] = arg->
projs<2>();
419 auto m = callee->arg();
423 if (
auto lit = fold<extrema, id>(world, type,
a, b))
return lit;
433 return world.raw_app(type,
c, {
a, b});
438 auto& world = type->world();
439 if (
auto lit = fold<tri, id>(world, type, arg))
return lit;
444 auto& world = type->world();
445 auto [
a, b] = arg->
projs<2>();
446 if (
auto lit = fold<
pow,
pow(0)>(world, type,
a, b))
return lit;
452 auto& world = type->world();
453 if (
auto lit = fold<rt, id>(world, type, arg))
return lit;
459 auto& world = type->world();
460 if (
auto lit = fold<exp, id>(world, type, arg))
return lit;
466 auto& world = type->world();
467 if (
auto lit = fold<er, id>(world, type, arg))
return lit;
473 auto& world = type->world();
474 if (
auto lit = fold<gamma, id>(world, type, arg))
return lit;
480 auto& world = type->world();
481 auto callee =
c->as<
App>();
482 auto [
a, b] = arg->
projs<2>();
484 if (
auto result = fold<cmp, id>(world, type,
a, b))
return result;
485 if (
id ==
cmp::f)
return world.lit_ff();
486 if (
id ==
cmp::t)
return world.lit_tt();
488 return world.raw_app(type, callee, {
a, b});
493 auto& world = dst_t->
world();
494 auto s_t = x->
type()->as<
App>();
495 auto d_t = dst_t->as<
App>();
501 if (s_t == d_t)
return x;
502 if (x->isa<
Bot>())
return world.bot(d_t);
511 if (
auto res = fold_conv<id>(*sw, *dw, *
l))
return world.lit(d_t, *res);
517 auto& world = type->world();
518 if (
auto lit = fold<abs>(world, type, arg))
return lit;
524 auto& world = type->world();
525 if (
auto lit = fold<round, id>(world, type, arg))
return lit;
static auto isa(const Def *def)
World & world() const noexcept
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
static bool greater(const Def *a, const Def *b)
static constexpr nat_t size2bitwidth(nat_t n)
static std::optional< T > isa(const Def *def)
#define MIM_math_NORMALIZER_IMPL
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Lit * lit_f(World &w, R val)
const Def * normalize_er(const Def *type, const Def *, const Def *arg)
const Def * normalize_cmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_gamma(const Def *type, const Def *, const Def *arg)
Mode
Allowed optimizations for a specific operation.
@ reassoc
Allow reassociation transformations for floating-point operations.
@ bot
Alias for Mode::fast.
const Def * mode(World &w, VMode m)
mim::plug::math::VMode -> const Def*.
const Def * normalize_arith(const Def *type, const Def *c, const Def *arg)
const Def * normalize_round(const Def *type, const Def *, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
const Def * normalize_tri(const Def *type, const Def *, const Def *arg)
const Def * normalize_exp(const Def *type, const Def *, const Def *arg)
const Def * normalize_rt(const Def *type, const Def *, const Def *arg)
const Def * normalize_pow(const Def *type, const Def *, const Def *arg)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
typename detail::w2f_< w >::type w2f
constexpr bool is_commutative(Id)
typename detail::w2s_< w >::type w2s
constexpr bool is_associative(Id id)
typename detail::w2u_< w >::type w2u
constexpr D bitcast_resize(const S &src) noexcept
A bitcast from src of type S to D, supporting different sizes.
#define MIM_1_8_16_32_64(X)