MimIR
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <type_traits>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12constexpr nat_t idx_shift_width(u64 size) {
13 if (size == 0) return 64;
14 auto width = Idx::size2bitwidth(size);
15 return width == 0 ? 1 : width;
16}
17
18constexpr std::optional<unsigned> idx_shift_amount(u64 size, u64 b) {
19 auto width = idx_shift_width(size);
20 if (b >= width) return {};
21 return static_cast<unsigned>(b);
22}
23
24constexpr u64 idx_unsigned_max(u64 size) { return size == 0 ? std::numeric_limits<u64>::max() : size - 1; }
25
26constexpr u64 idx_signed_max(u64 size) {
27 return size == 0 ? static_cast<u64>(std::numeric_limits<s64>::max()) : (size - 1) / 2;
28}
29
30constexpr u64 idx_signed_min_abs(u64 size) {
31 return size == 0 ? static_cast<u64>(std::numeric_limits<s64>::max()) + 1_u64 : size / 2;
32}
33
34constexpr u64 idx_signed_abs(s64 x) { return x >= 0 ? static_cast<u64>(x) : static_cast<u64>(-(x + 1)) + 1_u64; }
35
36constexpr s64 idx_neg(u64 abs) {
37 if (abs == static_cast<u64>(std::numeric_limits<s64>::max()) + 1_u64) return std::numeric_limits<s64>::min();
38 return -static_cast<s64>(abs);
39}
40
41constexpr u64 idx_pow2(unsigned k) { return k == 0 ? 1_u64 : Idx::bitwidth2size(static_cast<nat_t>(k)); }
42
43constexpr bool idx_sign(u64 size, u64 x) {
44 // Pre: x is already in range.
45 if (size == 0) return x > static_cast<u64>(std::numeric_limits<s64>::max()); // Idx 0 encodes 2^64.
46
47 // signed representatives in [-floor(size/2), ceil(size/2)-1]
48 return x > (size - 1) / 2;
49}
50
51constexpr s64 idx_sext(u64 size, u64 x) {
52 // Pre: x is already in range.
53 if (size == 0) return static_cast<s64>(x);
54
55 const u64 max_pos = (size - 1) / 2;
56 if (x <= max_pos) return static_cast<s64>(x);
57
58 // Negative representative is -(size - x).
59 return -static_cast<s64>(size - x);
60}
61
62constexpr u64 idx_from_signed(u64 size, s64 x) {
63 if (size == 0) return static_cast<u64>(x);
64 return x >= 0 ? static_cast<u64>(x) : size - static_cast<u64>(-x);
65}
66
67constexpr u64 idx_from_signed_mod(u64 size, s64 x) {
68 if (size == 0) return static_cast<u64>(x);
69 if (x >= 0) return static_cast<u64>(x) % size;
70
71 auto rem = idx_signed_abs(x) % size;
72 return rem == 0 ? 0 : size - rem;
73}
74
75constexpr bool idx_add_nuw(u64 size, u64 a, u64 b) {
76 if (size == 0) return a + b < a;
77 return a > size - 1 - b;
78}
79
80constexpr bool idx_sub_nuw(u64, u64 a, u64 b) { return a < b; }
81
82constexpr bool idx_mul_nuw(u64 size, u64 a, u64 b) {
83 if (a == 0 || b == 0) return false;
84
85 if (size == 0) return b > std::numeric_limits<u64>::max() / a;
86 return b > (size - 1) / a;
87}
88
89constexpr u64 idx_add(u64 size, u64 a, u64 b) {
90 if (size == 0) return a + b;
91 return (a + b) % size;
92}
93
94constexpr u64 idx_mul_pow2(u64 size, u64 a, unsigned k) {
95 while (k--)
96 a = idx_add(size, a, a);
97 return a;
98}
99
100constexpr u64 idx_sub(u64 size, u64 a, u64 b) {
101 if (size == 0) return a - b;
102 return (a >= b) ? (a - b) : (size - (b - a));
103}
104
105constexpr u64 idx_mul(u64 size, u64 a, u64 b) {
106 if (size == 0) return a * b;
107
108 // Safe double-and-add modulo size, avoids overflow.
109 u64 r = 0;
110 while (b) {
111 if (b % 2_u64 != 0) r = idx_add(size, r, a);
112 b /= 2_u64;
113 if (b) a = idx_add(size, a, a);
114 }
115 return r;
116}
117
118constexpr bool idx_add_nsw(u64 size, u64 a, u64 b) {
119 const bool sa = idx_sign(size, a);
120 const bool sb = idx_sign(size, b);
121 const u64 r = idx_add(size, a, b);
122 const bool sr = idx_sign(size, r);
123 return (sa == sb) && (sr != sa);
124}
125
126constexpr bool idx_sub_nsw(u64 size, u64 a, u64 b) {
127 const bool sa = idx_sign(size, a);
128 const bool sb = idx_sign(size, b);
129 const u64 r = idx_sub(size, a, b);
130 const bool sr = idx_sign(size, r);
131 return (sa != sb) && (sr != sa);
132}
133
134constexpr bool idx_mul_nsw(u64 size, u64 a, u64 b) {
135 const s64 x = idx_sext(size, a);
136 const s64 y = idx_sext(size, b);
137
138 if (x == 0 || y == 0) return false;
139
140 const s64 min_val = size == 0 ? std::numeric_limits<s64>::min() : -static_cast<s64>(size / 2);
141 const s64 max_val = size == 0 ? std::numeric_limits<s64>::max() : static_cast<s64>((size - 1) / 2);
142
143 if (x == -1) return y == min_val;
144 if (y == -1) return x == min_val;
145
146 if (x > 0)
147 if (y > 0)
148 return x > max_val / y;
149 else
150 return y < min_val / x;
151 else if (y > 0)
152 return x < min_val / y;
153 else
154 return x < max_val / y;
155}
156
157constexpr std::optional<u64> idx_udiv([[maybe_unused]] u64 size, u64 a, u64 b) {
158 if (b == 0) return {};
159 return a / b;
160}
161
162constexpr std::optional<u64> idx_urem([[maybe_unused]] u64 size, u64 a, u64 b) {
163 if (b == 0) return {};
164 return a % b;
165}
166
167constexpr bool idx_slt(u64 size, u64 a, u64 b) {
168 const bool sa = idx_sign(size, a);
169 const bool sb = idx_sign(size, b);
170
171 if (a == b) return false;
172 if (!sa && sb) return false;
173 if (sa && !sb) return true;
174 return a < b;
175}
176
177constexpr bool idx_sgt(u64 size, u64 a, u64 b) { return idx_slt(size, b, a); }
178
179constexpr bool idx_sdivrem_ub(u64 size, u64 a, u64 b) {
180 const s64 x = idx_sext(size, a);
181 const s64 y = idx_sext(size, b);
182
183 if (y == 0) return true;
184
185 const s64 min_val = [&] {
186 if (size == 0) return std::numeric_limits<s64>::min();
187 return -static_cast<s64>(size / 2);
188 }();
189
190 return x == min_val && y == -1;
191}
192
193constexpr u64 idx_sdiv(u64 size, u64 a, u64 b) {
194 const s64 x = idx_sext(size, a);
195 const s64 y = idx_sext(size, b);
196 return idx_from_signed(size, x / y);
197}
198
199constexpr u64 idx_srem(u64 size, u64 a, u64 b) {
200 const s64 x = idx_sext(size, a);
201 const s64 y = idx_sext(size, b);
202 return idx_from_signed(size, x % y);
203}
204
205constexpr bool idx_shl_nuw(u64 size, u64 a, unsigned k) {
206 u64 x = a;
207 u64 max = idx_unsigned_max(size);
208
209 while (k--) {
210 if (x > max / 2_u64) return true;
211 x *= 2_u64;
212 }
213
214 return false;
215}
216
217constexpr bool idx_shl_nsw(u64 size, u64 a, unsigned k) {
218 const s64 x = idx_sext(size, a);
219 if (x >= 0) {
220 u64 y = static_cast<u64>(x);
221 u64 max = idx_signed_max(size);
222 while (k--) {
223 if (y > max / 2_u64) return true;
224 y *= 2_u64;
225 }
226 } else {
227 u64 y = idx_signed_abs(x);
228 u64 min = idx_signed_min_abs(size);
229 while (k--) {
230 if (y > min / 2_u64) return true;
231 y *= 2_u64;
232 }
233 }
234
235 return false;
236}
237
238constexpr std::optional<u64> idx_shl(u64 size, u64 a, u64 b, bool nsw, bool nuw) {
239 auto k = idx_shift_amount(size, b);
240 if (!k) return {};
241
242 if (nuw && idx_shl_nuw(size, a, *k)) return {};
243 if (nsw && idx_shl_nsw(size, a, *k)) return {};
244
245 return idx_mul_pow2(size, a, *k);
246}
247
248constexpr std::optional<u64> idx_lshr(u64 size, u64 a, u64 b) {
249 auto k = idx_shift_amount(size, b);
250 if (!k) return {};
251 return a / idx_pow2(*k);
252}
253
254constexpr std::optional<u64> idx_ashr(u64 size, u64 a, u64 b) {
255 auto k = idx_shift_amount(size, b);
256 if (!k) return {};
257
258 auto divisor = idx_pow2(*k);
259 auto x = idx_sext(size, a);
260 if (x >= 0) return idx_from_signed(size, static_cast<s64>(static_cast<u64>(x) / divisor));
261
262 auto q = (idx_signed_abs(x) + divisor - 1_u64) / divisor;
263 return idx_from_signed(size, idx_neg(q));
264}
265
266template<icmp id>
267constexpr bool fold_icmp_idx(u64 size, u64 a, u64 b) {
268 const bool su = idx_sign(size, a);
269 const bool sv = idx_sign(size, b);
270
271 flags_t rel = 0;
272 // clang-format off
273 if (false) {}
274 else if (a == b) rel = icmp_mask & flags_t(icmp::xyglE); // equal
275 else if (!su && sv) rel = icmp_mask & flags_t(icmp::Xygle); // plus, minus
276 else if ( su && !sv) rel = icmp_mask & flags_t(icmp::xYgle); // minus, plus
277 else if (a > b) rel = icmp_mask & flags_t(icmp::xyGle); // greater (same sign)
278 else rel = icmp_mask & flags_t(icmp::xygLe); // less (same sign)
279 // clang-format on
280
281 return (flags_t(id) & rel) != 0;
282}
283
/// Folds operation @p id (of annex kind @p Id) on two literal operands @p a and @p b of type `Idx size`.
/// Returns the folded value, or `std::nullopt` when the operation is undefined for these operands. That covers a
/// violated nsw/nuw flag, a zero divisor or another case reported by idx_sdivrem_ub, and a shift amount not below the
/// shift width. The caller (`fold`) turns `std::nullopt` into ⊥.
/// @p nsw and @p nuw are only consulted for `wrap` operations.
template<class Id, Id id>
std::optional<u64> fold_idx(u64 size, u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
    // Pre: a, b already in range for Idx size.

    if constexpr (std::is_same_v<Id, wrap>) {
        // Arithmetic modulo size; a violated nuw/nsw guarantee yields no result.
        if constexpr (id == wrap::add) {
            if (nuw && idx_add_nuw(size, a, b)) return {};
            if (nsw && idx_add_nsw(size, a, b)) return {};
            return idx_add(size, a, b);

        } else if constexpr (id == wrap::sub) {
            if (nuw && idx_sub_nuw(size, a, b)) return {};
            if (nsw && idx_sub_nsw(size, a, b)) return {};
            return idx_sub(size, a, b);

        } else if constexpr (id == wrap::mul) {
            if (nuw && idx_mul_nuw(size, a, b)) return {};
            if (nsw && idx_mul_nsw(size, a, b)) return {};
            return idx_mul(size, a, b);

        } else if constexpr (id == wrap::shl) {
            // idx_shl also checks the shift amount and the nuw/nsw flags.
            return idx_shl(size, a, b, nsw, nuw);

        } else {
            static_assert(false, "missing wrap subtag");
        }

    } else if constexpr (std::is_same_v<Id, shr>) {
        // a: arithmetic (floor on the signed value), l: logical (unsigned division by 2^b).
        if constexpr (id == shr::a)
            return idx_ashr(size, a, b);
        else if constexpr (id == shr::l)
            return idx_lshr(size, a, b);
        else
            static_assert(false, "missing shr subtag");

    } else if constexpr (std::is_same_v<Id, div>) {
        // Unsigned variants only fail on a zero divisor; signed variants also fail on the cases idx_sdivrem_ub reports.
        if constexpr (id == div::udiv) {
            return idx_udiv(size, a, b);

        } else if constexpr (id == div::urem) {
            return idx_urem(size, a, b);

        } else if constexpr (id == div::sdiv) {
            if (idx_sdivrem_ub(size, a, b)) return {};
            return idx_sdiv(size, a, b);

        } else if constexpr (id == div::srem) {
            if (idx_sdivrem_ub(size, a, b)) return {};
            return idx_srem(size, a, b);

        } else {
            static_assert(false, "missing div subtag");
        }

    } else if constexpr (std::is_same_v<Id, icmp>) {
        // Comparison results are 0 or 1.
        return u64(fold_icmp_idx<id>(size, a, b));

    } else if constexpr (std::is_same_v<Id, extrema>) {
        // sm/sM: unsigned min/max; Sm/SM: signed min/max.
        if constexpr (id == extrema::sm)
            return std::min(a, b);

        else if constexpr (id == extrema::sM)
            return std::max(a, b);

        else if constexpr (id == extrema::Sm)
            return idx_slt(size, a, b) ? a : b;

        else if constexpr (id == extrema::SM)
            return idx_sgt(size, a, b) ? a : b;

        else
            static_assert(false, "missing extrema subtag");

    } else {
        static_assert(false, "missing tag");
    }
}
361
362template<class Id, Id id>
363const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b, const Def* mode = {}) {
364 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
365
366 if (auto la = Lit::isa(a)) {
367 if (auto lb = Lit::isa(b)) {
368 assert(a->type() == b->type());
369
370 auto size = Lit::as(Idx::isa(a->type()));
371
372 bool nsw = false, nuw = false;
373 if constexpr (std::is_same_v<Id, wrap>) {
374 auto m = mode ? static_cast<Mode>(Lit::as(mode)) : Mode::none;
375 nsw = fe::has_flag(m, Mode::nsw);
376 nuw = fe::has_flag(m, Mode::nuw);
377 }
378
379 if (size == 1) {
380 if constexpr (std::is_same_v<Id, div>) {
381 if (*lb == 0) return world.bot(type);
382 }
383 if constexpr (std::is_same_v<Id, icmp>)
384 return world.lit(type, u64(fold_icmp_idx<id>(1, 0, 0)));
385 else
386 return world.lit(type, 0);
387 }
388
389 auto res = fold_idx<Id, id>(size, *la, *lb, nsw, nuw);
390 return res ? world.lit(type, *res) : world.bot(type);
391 }
392 }
393
394 if (::mim::is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
395 return nullptr;
396}
397
398template<class Id>
399const Def* fold(World& world, const Def* type, const Def*& a) {
400 if (a->isa<Bot>()) return world.bot(type);
401
402 if (auto la = Lit::isa(a)) {
403 auto size = Lit::as(Idx::isa(a->type()));
404
405 if constexpr (std::is_same_v<Id, abs>) {
406 auto x = idx_sext(size, *la);
407 if (x >= 0) return world.lit(type, static_cast<u64>(x));
408
409 auto y = idx_signed_abs(x);
410 if ((size == 0 && x == std::numeric_limits<s64>::min())
411 || (size % 2_u64 == 0 && y == idx_signed_min_abs(size)))
412 return world.lit(type, *la);
413
414 return world.lit(type, y);
415 } else {
416 static_assert(false, "missing tag");
417 }
418 }
419
420 return nullptr;
421}
422
/// Reassociates @p a and @p b according to the following rules.
/// We use the following naming convention while literals are prefixed with an `l`:
/// ```
/// a op b
/// (x op y) op (z op w)
///
/// (1) la op (lz op w) -> (la op lz) op w
/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
/// (3) a op (lz op w) -> lz op (a op w)
/// (4) (lx op y) op b -> lx op (y op b)
/// ```
/// Rules (1) and (2) place two literals next to each other so the inner op can fold them.
/// Rules (3) and (4) move a literal to the outermost left operand.
/// Rules (2) and (3) also reorder operands, so they additionally rely on `op` being commutative.
/// NOTE(review): confirm that every @p Id for which is_associative holds is also commutative.
/// Returns `nullptr` if @p id is not associative or no rule applies. @p ab is not used.
template<class Id>
const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
    if (!is_associative(id)) return nullptr;

    // xy/zw: a resp. b as an application of the same op (or null); x, y, z, w: their operands (or null).
    auto xy = Axm::isa<Id>(id, a);
    auto zw = Axm::isa<Id>(id, b);
    auto la = a->isa<Lit>();
    auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
    auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
    auto lx = Lit::isa(x);
    auto lz = Lit::isa(z);

    // if we reassociate, we have to forget about nsw/nuw
    // NOTE(review): Mode::none is passed as the first curried argument for every Id. For bit2, other calls in this
    // file pass the size `s` in that position (e.g. `world.call(bit2::iff, s, ...)`); confirm this is intended.
    auto make_op = [&world, id](const Def* a, const Def* b) { return world.call(id, Mode::none, Defs{a, b}); };

    if (la && lz) return make_op(make_op(a, z), w);             // (1)
    if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
    if (lz) return make_op(z, make_op(a, w));                   // (3)
    if (lx) return make_op(x, make_op(y, b));                   // (4)
    return nullptr;
}
455
456template<class Id>
457const Def* merge_cmps(std::array<std::array<u64, 2>, 2> tab, const Def* a, const Def* b) {
458 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
459 static constexpr size_t num_bits = std::bit_width(Annex::num<Id>() - 1_u64);
460
461 auto& world = a->world();
462 auto a_cmp = Axm::isa<Id>(a);
463 auto b_cmp = Axm::isa<Id>(b);
464
465 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
466 // push sub bits of a_cmp and b_cmp through truth table
467 sub_t res = 0;
468 sub_t a_sub = a_cmp.sub();
469 sub_t b_sub = b_cmp.sub();
470 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
471 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
472 res >>= (7_u8 - u8(num_bits));
473
474 if constexpr (std::is_same_v<Id, math::cmp>)
475 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
476 else
477 return world.call(icmp(Annex::base<icmp>() | res), a_cmp->arg());
478 }
479
480 return nullptr;
481}
482
483} // namespace
484
485template<nat id>
486const Def* normalize_nat(const Def* type, const Def* callee, const Def* arg) {
487 auto& world = type->world();
488 auto [a, b] = arg->projs<2>();
489 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
490 auto la = Lit::isa(a);
491 auto lb = Lit::isa(b);
492
493 if (la) {
494 if (lb) {
495 switch (id) {
496 case nat::add: return world.lit_nat(*la + *lb);
497 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
498 case nat::mul: return world.lit_nat(*la * *lb);
499 }
500 }
501
502 if (*la == 0) {
503 switch (id) {
504 case nat::add: return b;
505 case nat::sub: return a; // 0 - b = 0
506 case nat::mul: return a; // 0 * b = 0
507 }
508 }
509
510 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
511 }
512
513 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
514
515 if (a == b) {
516 switch (id) {
517 case nat::add: return world.call(nat::mul, Defs{world.lit_nat(2), a}); // a + a = 2 * a
518 case nat::sub: return world.lit_nat(0); // a - a = 0
519 case nat::mul: break;
520 }
521 }
522
523 return world.raw_app(type, callee, {a, b});
524}
525
526template<ncmp id>
527const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg) {
528 auto& world = type->world();
529
530 if (id == ncmp::t) return world.lit_tt();
531 if (id == ncmp::f) return world.lit_ff();
532
533 auto [a, b] = arg->projs<2>();
534 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
535
536 if (a == b) {
537 constexpr auto eq_mask = fe::to_underlying(ncmp::e) & 0xff;
538 if ((fe::to_underlying(id) & eq_mask) != 0) return world.lit_tt();
539 if (id == ncmp::ne) return world.lit_ff();
540 }
541
542 if (auto la = Lit::isa(a)) {
543 if (auto lb = Lit::isa(b)) {
544 // clang-format off
545 switch (id) {
546 case ncmp:: e: return world.lit_bool(*la == *lb);
547 case ncmp::ne: return world.lit_bool(*la != *lb);
548 case ncmp::l : return world.lit_bool(*la < *lb);
549 case ncmp::le: return world.lit_bool(*la <= *lb);
550 case ncmp::g : return world.lit_bool(*la > *lb);
551 case ncmp::ge: return world.lit_bool(*la >= *lb);
552 default: fe::unreachable();
553 }
554 // clang-format on
555 }
556 }
557
558 return world.raw_app(type, callee, {a, b});
559}
560
561template<icmp id>
562const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg) {
563 auto& world = type->world();
564 auto callee = c->as<App>();
565 auto [a, b] = arg->projs<2>();
566
567 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
568 if (id == icmp::f) return world.lit_ff();
569 if (id == icmp::t) return world.lit_tt();
570 if (a == b) {
571 constexpr auto eq_mask = fe::to_underlying(icmp::e) & 0xff;
572 if ((fe::to_underlying(id) & eq_mask) != 0) return world.lit_tt();
573 if (id == icmp::ne) return world.lit_ff();
574 }
575
576 return world.raw_app(type, callee, {a, b});
577}
578
579template<extrema id>
580const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
581 auto& world = type->world();
582 auto callee = c->as<App>();
583 auto [a, b] = arg->projs<2>();
584 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
585 return world.raw_app(type, callee, {a, b});
586}
587
588const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
589 auto& world = type->world();
590 auto [mem, a] = arg->projs<2>();
591 auto [_, actual_type] = type->projs<2>();
592 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
593
594 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
595 return {};
596}
597
598template<bit1 id>
599const Def* normalize_bit1(const Def* type, const Def* c, const Def* a) {
600 auto& world = type->world();
601 auto callee = c->as<App>();
602 auto s = callee->decurry()->arg();
603 // TODO cope with wrap around
604
605 if constexpr (id == bit1::id) return a;
606
607 if (auto ls = Lit::isa(s)) {
608 switch (id) {
609 case bit1::f: return world.lit_idx(*ls, 0);
610 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
611 case bit1::id: fe::unreachable();
612 default: break;
613 }
614
615 assert(id == bit1::neg);
616 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
617 }
618
619 return {};
620}
621
622template<bit2 id>
623const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg) {
624 auto& world = type->world();
625 auto callee = c->as<App>();
626 auto [a, b] = arg->projs<2>();
627 auto s = callee->decurry()->arg();
628 auto ls = Lit::isa(s);
629 // TODO cope with wrap around
630
631 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
632
633 auto tab = make_truth_table(id);
634 if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
635 if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;
636
637 auto la = Lit::isa(a);
638 auto lb = Lit::isa(b);
639
640 // clang-format off
641 switch (id) {
642 case bit2:: f: return world.lit(type, 0);
643 case bit2:: t: if (ls) return world.lit(type, *ls-1_u64); break;
644 case bit2:: fst: return a;
645 case bit2:: snd: return b;
646 case bit2:: nfst: return world.call(bit1::neg, s, a);
647 case bit2:: nsnd: return world.call(bit1::neg, s, b);
648 case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
649 case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
650 default: break;
651 }
652
653 if (la && lb && ls) {
654 switch (id) {
655 case bit2::and_: return world.lit_idx (*ls, *la & *lb);
656 case bit2:: or_: return world.lit_idx (*ls, *la | *lb);
657 case bit2::xor_: return world.lit_idx (*ls, *la ^ *lb);
658 case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
659 case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
660 case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
661 case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
662 case bit2::niff: return world.lit_idx (*ls, *la & ~*lb);
663 default: fe::unreachable();
664 }
665 }
666
667 // TODO rewrite using bit2
668 auto unary = [&](bool x, bool y, const Def* a) -> const Def* {
669 if (!x && !y) return world.lit(type, 0);
670 if ( x && y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
671 if (!x && y) return a;
672 if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
673 return nullptr;
674 };
675 // clang-format on
676
677 if (is_commutative(id) && a == b) {
678 if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
679 }
680
681 if (la) {
682 if (*la == 0) {
683 if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
684 } else if (ls && *la == *ls - 1_u64) {
685 if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
686 }
687 }
688
689 if (lb) {
690 if (*lb == 0) {
691 if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
692 } else if (ls && *lb == *ls - 1_u64) {
693 if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
694 }
695 }
696
697 if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;
698
699 return world.raw_app(type, callee, {a, b});
700}
701
702const Def* normalize_idx(const Def* type, const Def* c, const Def* arg) {
703 auto& world = type->world();
704 auto callee = c->as<App>();
705 if (auto i = Lit::isa(arg)) {
706 if (auto s = Lit::isa(Idx::isa(type))) {
707 if (*i < *s) return world.lit_idx(*s, *i);
708 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
709 }
710 }
711
712 return {};
713}
714
715const Def* normalize_idx_unsafe(const Def*, const Def*, const Def* arg) {
716 auto& world = arg->world();
717 if (auto i = Lit::isa(arg)) return world.lit_idx_unsafe(*i);
718 return {};
719}
720
721template<shr id>
722const Def* normalize_shr(const Def* type, const Def* c, const Def* arg) {
723 auto& world = type->world();
724 auto callee = c->as<App>();
725 auto [a, b] = arg->projs<2>();
726 auto s = Idx::isa(a->type());
727 auto ls = Lit::isa(s);
728 auto width = ls ? std::optional<nat_t>(idx_shift_width(*ls)) : std::optional<nat_t>();
729
730 if (auto result = fold<shr, id>(world, type, a, b)) return result;
731
732 if (auto la = Lit::isa(a); la && *la == 0) {
733 switch (id) {
734 case shr::a: return a;
735 case shr::l: return a;
736 }
737 }
738
739 if (auto lb = Lit::isa(b)) {
740 if (width && *lb >= *width) return world.bot(type);
741
742 if (*lb == 0) {
743 switch (id) {
744 case shr::a: return a;
745 case shr::l: return a;
746 }
747 }
748 }
749
750 return world.raw_app(type, callee, {a, b});
751}
752
753template<wrap id>
754const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg) {
755 auto& world = type->world();
756 auto callee = c->as<App>();
757 auto [a, b] = arg->projs<2>();
758 auto mode = callee->arg();
759 auto s = Idx::isa(a->type());
760 auto ls = Lit::isa(s);
761 auto width = ls.transform(idx_shift_width);
762
763 if (auto result = fold<wrap, id>(world, type, a, b, mode)) return result;
764
765 // clang-format off
766 if (auto la = Lit::isa(a)) {
767 if (*la == 0) {
768 switch (id) {
769 case wrap::add: return b; // 0 + b -> b
770 case wrap::sub: break;
771 case wrap::mul: return a; // 0 * b -> 0
772 case wrap::shl: return a; // 0 << b -> 0
773 }
774 } else if (*la == 1) {
775 switch (id) {
776 case wrap::add: break;
777 case wrap::sub: break;
778 case wrap::mul: return b; // 1 * b -> b
779 case wrap::shl: break;
780 }
781 }
782 }
783
784 if (auto lb = Lit::isa(b)) {
785 if (*lb == 0) {
786 switch (id) {
787 case wrap::sub: return a; // a - 0 -> a
788 case wrap::shl: return a; // a >> 0 -> a
789 default: fe::unreachable();
790 // add, mul are commutative, the literal has been normalized to the left
791 }
792 }
793
794 if (auto lm = Lit::isa(mode); lm && ls && *lm == 0 && id == wrap::sub)
795 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
796 else if (id == wrap::shl && width && *lb >= *width)
797 return world.bot(type);
798 }
799
800 if (a == b) {
801 switch (id) {
802 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
803 case wrap::sub: return world.lit(type, 0); // a - a -> 0
804 case wrap::mul: break;
805 case wrap::shl: break;
806 }
807 }
808 // clang-format on
809
810 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
811
812 return world.raw_app(type, callee, {a, b});
813}
814
815template<div id>
816const Def* normalize_div(const Def* full_type, const Def*, const Def* arg) {
817 auto& world = full_type->world();
818 auto [mem, ab] = arg->projs<2>();
819 auto [a, b] = ab->projs<2>();
820 auto [_, type] = full_type->projs<2>(); // peel off actual type
821 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
822
823 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
824
825 if (auto la = Lit::isa(a)) {
826 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
827 }
828
829 if (auto lb = Lit::isa(b)) {
830 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
831
832 if (*lb == 1) {
833 switch (id) {
834 case div::sdiv: return make_res(a); // a / 1 -> a
835 case div::udiv: return make_res(a); // a / 1 -> a
836 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
837 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
838 }
839 }
840 }
841
842 if (a == b) {
843 switch (id) {
844 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
845 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
846 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
847 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
848 }
849 }
850
851 return {};
852}
853
854template<conv id>
855const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
856 auto& world = dst_t->world();
857 auto s_t = x->type()->as<App>();
858 auto d_t = dst_t->as<App>();
859 auto s = s_t->arg();
860 auto d = d_t->arg();
861 auto ls = Lit::isa(s);
862 auto ld = Lit::isa(d);
863
864 if (s_t == d_t) return x;
865 if (x->isa<Bot>()) return world.bot(d_t);
866
867 if (auto l = Lit::isa(x); l && ls && ld) {
868 if constexpr (id == conv::u) {
869 if (*ld == 0) return world.lit(d_t, *l); // I64
870 return world.lit(d_t, *l % *ld);
871 }
872
873 return world.lit(d_t, idx_from_signed_mod(*ld, idx_sext(*ls, *l)));
874 }
875
876 if (ls && ld)
877 if (auto c1 = Axm::isa(id, x)) {
878 auto x1 = c1->arg();
879 if (auto ls1 = Lit::isa(x1->type()->as<App>()->arg()))
880 if (*ls > *ls1 || *ls == 0) // the intermediate conv is widening
881 if (*ld == *ls1) return x1; // conv(conv(x)) -> x
882 }
883
884 return {};
885}
886
887const Def* normalize_bitcast(const Def* dst_t, const Def*, const Def* src) {
888 auto& world = dst_t->world();
889 auto src_t = src->type();
890
891 if (src->isa<Bot>()) return world.bot(dst_t);
892 if (src_t == dst_t) return src;
893
894 if (auto other = Axm::isa<bitcast>(src))
895 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
896
897 if (auto l = Lit::isa(src)) {
898 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
899 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
900 }
901
902 return {};
903}
904
905// TODO this currently hard-codes x86_64 ABI
906// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
907// and every occurance of these types in a later phase
908// TODO Pi and others
909template<trait id>
910const Def* normalize_trait(const Def*, const Def*, const Def* type) {
911 auto& world = type->world();
912 if (auto ptr = Axm::isa<mem::Ptr>(type)) {
913 return world.lit_nat(8);
914 } else if (type->isa<Pi>()) {
915 return world.lit_nat(8); // Gets lowered to function ptr
916 } else if (auto size = Idx::isa(type)) {
917 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
918 } else if (auto w = math::isa_f(type)) {
919 switch (*w) {
920 case 16: return world.lit_nat(2);
921 case 32: return world.lit_nat(4);
922 case 64: return world.lit_nat(8);
923 default: fe::unreachable();
924 }
925 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
926 u64 offset = 0;
927 u64 align = 1;
928 for (auto t : type->ops()) {
929 auto a = Lit::isa(core::op(trait::align, t));
930 auto s = Lit::isa(core::op(trait::size, t));
931 if (!a || !s) return {};
932
933 align = std::max(align, *a);
934 offset = pad(offset, *a) + *s;
935 }
936
937 offset = pad(offset, align);
938 u64 size = std::max(1_u64, offset);
939
940 switch (id) {
941 case trait::align: return world.lit_nat(align);
942 case trait::size: return world.lit_nat(size);
943 }
944 } else if (auto arr = type->isa_imm<Arr>()) {
945 auto align = op(trait::align, arr->body());
946 if constexpr (id == trait::align) return align;
947 auto b = op(trait::size, arr->body());
948 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->arity(), b});
949 } else if (auto join = type->isa<Join>()) {
950 if (auto sigma = convert(join)) return core::op(id, sigma);
951 }
952
953 return {};
954}
955
956template<pe id>
957const Def* normalize_pe(const Def* type, const Def*, const Def* arg) {
958 auto& world = type->world();
959
960 if constexpr (id == pe::is_closed) {
961 if (Axm::isa(pe::hlt, arg)) return world.lit_ff();
962 if (arg->is_closed()) return world.lit_tt();
963 }
964
965 return {};
966}
967
969
970} // namespace mim::plug::core
const App * decurry() const
Returns App::callee again as App.
Definition lam.h:277
const Def * arg() const
Definition lam.h:285
A (possibly parameterized) Array.
Definition tuple.h:117
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:246
World & world() const noexcept
Definition def.cpp:444
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:386
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
static bool greater(const Def *a, const Def *b)
Definition def.cpp:555
bool is_closed() const
Has no free_vars()?
Definition def.cpp:418
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:905
static constexpr nat_t bitwidth2size(nat_t n)
Definition def.h:904
static const Def * isa(const Def *def)
Checks if def is a Idx s and returns s or nullptr otherwise.
Definition def.cpp:616
static std::optional< T > isa(const Def *def)
Definition def.h:838
static T as(const Def *def)
Definition def.h:844
A dependent function type.
Definition lam.h:14
A dependent tuple type.
Definition tuple.h:20
const Lit * lit_idx_unsafe(u64 val)
Definition world.h:554
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:302
The core Plugin
Definition core.h:8
const Def * normalize_nat(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_idx_unsafe(const Def *, const Def *, const Def *arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:14
const Def * normalize_div(const Def *full_type, const Def *, const Def *arg)
const Def * normalize_pe(const Def *type, const Def *, const Def *arg)
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Def * normalize_icmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_bit1(const Def *type, const Def *c, const Def *a)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
const Def * normalize_bit2(const Def *type, const Def *c, const Def *arg)
const Def * normalize_wrap(const Def *type, const Def *c, const Def *arg)
const Def * normalize_trait(const Def *, const Def *, const Def *type)
const Def * op(trait o, const Def *type)
Definition core.h:35
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_idx(const Def *type, const Def *c, const Def *arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:52
const Def * normalize_bitcast(const Def *dst_t, const Def *, const Def *src)
const Def * normalize_ncmp(const Def *type, const Def *callee, const Def *arg)
constexpr flags_t icmp_mask
Definition core.h:10
@ nuw
No Unsigned Wrap around.
Definition core.h:18
@ none
Wrap around.
Definition core.h:16
@ nsw
No Signed Wrap around.
Definition core.h:17
const Def * normalize_shr(const Def *type, const Def *c, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:77
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:78
u64 nat_t
Definition types.h:37
u8 sub_t
Definition types.h:42
u64 flags_t
Definition types.h:39
float rem(float a, float b)
Definition types.h:87
TBound< true > Join
AKA union.
Definition lattice.h:174
constexpr bool is_commutative(Id)
Definition axm.h:153
constexpr std::uint64_t pad(std::uint64_t offset, std::uint64_t align) noexcept
Definition util.h:44
int64_t s64
Definition types.h:27
constexpr bool is_associative(Id id)
Definition axm.h:159
TExt< false > Bot
Definition lattice.h:171
uint64_t u64
Definition types.h:27
uint8_t u8
Definition types.h:27
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
@ App
Definition def.h:109
@ Lit
Definition def.h:109
static consteval size_t num()
Definition plugin.h:149
static consteval flags_t base()
Definition plugin.h:150