MimIR
MimIR is my Intermediate Representation
The tensor Plugin

See also
mim::plug::tensor

Dependencies

plugin tuple;
plugin refly;
plugin vec;
import affine;

A tensor plugin

Types

Represents an algebraic ring: a carrier type T together with a zero element _0, an addition add and a multiplication mul.

let %tensor.Ring = [
T: *,
_0: T,
add: [T, T] → T,
mul: [T, T] → T,
];

let nat_ring = (Nat, 0, %core.nat.add, %core.nat.mul);
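To make the role of the ring concrete, here is a small Python model (not MimIR code). It uses a hypothetical `Ring` record with the same three operations, a counterpart of `nat_ring` in which Python ints stand in for Nat, and a `ring_dot` helper that folds `add (acc, mul (x, y))` in the same way as `dot_general_fun` below. The `bool_ring` instance is an illustrative assumption; (or, and) fits the record even though it is only a semiring.

```python
import operator
from dataclasses import dataclass
from functools import reduce
from typing import Callable, Generic, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class Ring(Generic[T]):
    """Hypothetical Python mirror of %tensor.Ring: zero element, add, mul."""
    zero: T
    add: Callable[[T, T], T]
    mul: Callable[[T, T], T]


# Counterpart of nat_ring (Python ints stand in for Nat).
nat_ring = Ring(0, operator.add, operator.mul)
# Hypothetical extra instance: Booleans with "or" as add and "and" as mul.
bool_ring = Ring(False, operator.or_, operator.and_)


def ring_dot(r: Ring[T], xs: Sequence[T], ys: Sequence[T]) -> T:
    """Fold acc := add(acc, mul(x, y)) over two vectors, starting at r.zero."""
    return reduce(lambda acc, xy: r.add(acc, r.mul(*xy)), zip(xs, ys), r.zero)
```

Swapping the ring changes what "dot product" means without touching the fold itself, which is why the tensor operations below take `R: %tensor.Ring` as a parameter.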

Operations

tensor.get

Extracts the element at position index from arr. For a rank-3 array, %tensor.get (arr, (a, b, c)) = arr#a#b#c.

axm %tensor.get: {T: *, r: Nat, s: «r; Nat»} → [arr: «s; T», index: «i: r; Idx (s#i)»] → T, normalize_get;

tensor.set

Returns arr with the element at position index replaced by x. For a rank-3 array, %tensor.set (arr, (a, b, c), x) = Insert arr a (Insert (arr#a) b (Insert (arr#a#b) c x)).

axm %tensor.set: {T: *, r: Nat, s: «r; Nat»} → [arr: «s; T», index: «i: r; Idx (s#i)», x: T] → «s; T», normalize_set;
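The following Python sketch models the semantics of tensor.get and tensor.set on nested lists; the helper names are hypothetical. `tensor_set` rebuilds only the levels along the index path, mirroring the nested Insert chain, and leaves the original array untouched.

```python
from typing import Any, Sequence


def tensor_get(arr: Any, index: Sequence[int]) -> Any:
    """Model of %tensor.get: arr#a#b#c for index (a, b, c)."""
    for i in index:
        arr = arr[i]
    return arr


def tensor_set(arr: Any, index: Sequence[int], x: Any) -> Any:
    """Model of %tensor.set: a functional update along the index path.

    Mirrors Insert arr a (Insert (arr#a) b (Insert (arr#a#b) c x)).
    """
    if not index:
        return x  # reached the element itself
    head, *rest = index
    copy = list(arr)  # shallow copy of this level only
    copy[head] = tensor_set(arr[head], rest, x)
    return copy
```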

%tensor.map_reduce

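  • nis: the number of inputs
  • To/Ro/So and Tis/Ris/Sis: the element type/rank/shape of the output and of the inputs, respectively
  • f: the function folded over the inputs; it takes the accumulator of type To and one element of each type in Tis, and returns a To
  • init: the initial value of the accumulator
  • subs: for each input and each of its dimensions, the index variable (in Einstein notation) used to access that dimension
  • is: the inputs

Returns a tensor of shape So. Index variables 0 to Ro - 1 address the output dimensions; every other index variable is a reduction variable, and each output element is obtained by folding f, starting from init, over all values of the reduction variables.

nis, Ro, Ris and subs must be literals before entering the tensor::Lower phase for the lowering to succeed.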

axm %tensor.map_reduce: {nis: Nat}
→ {To: *, Ro: Nat}
→ [So: «Ro; Nat»]
→ {Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i:nis; «Ris#i; Nat»»}
→ [f: Fn [To, «i: nis; Tis#i»] → To, init: To]
→ [subs: «i: nis; « Ris#i; Nat»»]
→ [is: «i: nis; «Sis#i; Tis#i» »]
→ «So; To», normalize_map_reduce;
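As a rough reference for what map_reduce computes, the Python sketch below evaluates it directly on nested lists. The function name, the argument order (So, Sis, f, init, subs, is) and the use of Python ints as index variables are assumptions for illustration. The model infers the extent of each reduction variable from the input dimension it subscripts and rejects inconsistent extents; it is not how tensor::Lower generates loops.

```python
from itertools import product
from typing import Any, Sequence


def _get(arr: Any, idx: Sequence[int]) -> Any:
    for i in idx:
        arr = arr[i]
    return arr


def map_reduce_ref(so, sis, f, init, subs, inputs):
    """Hypothetical reference model of %tensor.map_reduce on nested lists.

    Index variables 0 .. len(so) - 1 address the output; every other
    variable appearing in subs is a reduction variable folded away by f.
    """
    ro = len(so)
    bounds = dict(enumerate(so))  # extent of every index variable
    for sub, shape in zip(subs, sis):
        for k, n in zip(sub, shape):
            if bounds.setdefault(k, n) != n:
                raise ValueError(f"index {k} has extents {bounds[k]} and {n}")
    red_vars = sorted(k for k in bounds if k >= ro)

    def element(out_idx):
        acc = init
        for red_idx in product(*(range(bounds[k]) for k in red_vars)):
            env = dict(enumerate(out_idx))
            env.update(zip(red_vars, red_idx))
            xs = tuple(_get(a, [env[k] for k in sub]) for a, sub in zip(inputs, subs))
            acc = f(acc, xs)
        return acc

    def build(prefix):
        if len(prefix) == ro:
            return element(prefix)
        return [build(prefix + (i,)) for i in range(so[len(prefix)])]

    return build(())
```

With subs = ((0, 2), (2, 1)) and So of rank 2, variable 2 is a reduction variable and is summed away, which gives an ordinary matrix product.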

%tensor.map_reduce_ds

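Variant of tensor.map_reduce that takes a direct-style function f as input. The implementation wraps f into a function f_ that returns its result through a continuation and then calls tensor.map_reduce.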

axm %tensor.map_reduce_ds: {nis: Nat}
→ {To: *, Ro: Nat}
→ [So: «Ro; Nat»]
→ {Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i:nis; «Ris#i; Nat»»}
→ [f: [To, «i: nis; Tis#i»] → To, init: To]
→ [subs: «i: nis; « Ris#i; Nat»»]
→ [is: «i: nis; «Sis#i; Tis#i» »]
→ «So; To»;
lam %tensor.map_reduce_ds_impl {nis: Nat}
{To: *, Ro: Nat}
(So: «Ro; Nat»)
{Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i:nis; «Ris#i; Nat»»}
(f: [To, «i: nis; Tis#i»] → To, init: To)
(subs: «i: nis; « Ris#i; Nat»»)
(is: «i: nis; «Sis#i; Tis#i» »)
: «So; To»
= fun f_ (x: To, y: «i: nis; Tis#i»)@tt = return (f (x, y));
%tensor.map_reduce @nis @(To, Ro) So @(Tis, Ris, Sis) (f_, init) subs is;
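The wrapping done by map_reduce_ds_impl can be pictured in Python, which has no built-in continuations, by passing the continuation as an explicit `ret` callback. `to_cps` is a hypothetical name; the sketch only illustrates the shape of the transformation `fun f_ (x, y)@tt = return (f (x, y))`.

```python
from typing import Any, Callable


def to_cps(f: Callable[[Any, Any], Any]) -> Callable[[Any, Any, Callable[[Any], Any]], Any]:
    """Wrap a direct-style f(acc, xs) so it hands its result to a continuation."""

    def f_(acc: Any, xs: Any, ret: Callable[[Any], Any]) -> Any:
        # Compute in direct style, then "return" by calling the continuation.
        return ret(f(acc, xs))

    return f_
```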

%tensor.map_reduce_aff

Generalisation of tensor.map_reduce where the per-input subscripts are replaced by affine access functions.

The iteration domain and the result shape are stated explicitly:

  • Sr: the bounds of the full loop nest of length Ro + Rr. Its leading Ro dimensions are the parallel output loops; the trailing Rr dimensions are the reduction loops that are folded away.
  • So: the shape of the result tensor (rank Ro). It may differ from the leading Ro bounds of Sr, which is what lets acc_out write to a re-shaped output (e.g. a transpose).

The full loop iteration vector (o…, r…) has length Ro + Rr.

  • acc_out maps the full loop vector to the Ro write coordinates in the result «So». As the reduction is folded away before the write, it must depend only on the leading Ro output indices.
  • accs#i maps the full loop vector to the Ris#i read coordinates of input i. Transposes, slices and broadcasts are expressed here, by reading the inputs accordingly.

A literal subscript subs#i#j = k of tensor.map_reduce corresponds to component j of accs#i being the projection λ iters. iters#k; in that case acc_out is the identity on the leading Ro indices and the leading Ro bounds of Sr equal So.

nis, Ro, Rr and the Ris must be literals before entering the tensor::Lower phase for the lowering to succeed; the bounds So and Sr may be symbolic.

axm %tensor.map_reduce_aff: {nis: Nat}
→ {To: *, Ro Rr: Nat}
→ [So: «Ro; Nat», Sr: «%core.nat.add (Ro, Rr); Nat»]
→ {Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i: nis; «Ris#i; Nat»»}
→ [f: Fn [To, «i: nis; Tis#i»] → To, init: To]
→ [acc_out: [«%core.nat.add (Ro, Rr); %affine.index»] → «Ro; %affine.index»]
→ [accs: «i: nis; ([«%core.nat.add (Ro, Rr); %affine.index»] → «Ris#i; %affine.index»)»]
→ [is: «i: nis; «Sis#i; Tis#i»»]
→ «So; To», normalize_map_reduce_aff;
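The following Python sketch spells out the loop nest described above as a reference model (hypothetical names, nested lists, Python ints for indices, Ro taken as the length of So). `projection` builds the access map corresponding to a literal subscript list of tensor.map_reduce. Result cells that acc_out never writes keep `init`; that is an assumption, since the text above does not specify them.

```python
from itertools import product


def _get(arr, idx):
    for i in idx:
        arr = arr[i]
    return arr


def _set(arr, idx, x):
    for i in idx[:-1]:
        arr = arr[i]
    arr[idx[-1]] = x


def _full(shape, value):
    if not shape:
        return value
    return [_full(shape[1:], value) for _ in range(shape[0])]


def projection(sub):
    """Access map equivalent to a literal subscript list: iters -> iters#k."""
    return lambda iters: tuple(iters[k] for k in sub)


def map_reduce_aff_ref(so, sr, f, init, acc_out, accs, inputs):
    """Hypothetical model of %tensor.map_reduce_aff on nested lists."""
    ro = len(so)
    rr = len(sr) - ro
    out = _full(tuple(so), init)
    for o in product(*(range(n) for n in sr[:ro])):  # parallel output loops
        acc = init
        for r in product(*(range(n) for n in sr[ro:])):  # reduction loops
            iters = o + r
            acc = f(acc, tuple(_get(a, acc_i(iters)) for a, acc_i in zip(inputs, accs)))
        # acc_out may only depend on the leading ro entries, so pad with zeros.
        coords = tuple(acc_out(o + (0,) * rr))
        if coords:
            _set(out, coords, acc)
        else:
            out = acc  # rank-0 result
    return out
```

A transpose writes through acc_out (for example `lambda it: (it[1], it[0])` with So = (3, 2) and Sr = (2, 3)), while a matrix product reads through projection access maps and folds the single reduction loop.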

%tensor.dot_product

  • R: the ring in which the dot product is performed
  • r1/r2: the ranks of the left/right input tensors
  • nc/nb: the number of contracting/batching dimensions
  • c1/c2, b1/b2: the contracting/batching dimensions of the left/right input
  • s1/s2: the shape of the left/right input
  • a/b: the left/right input

Returns the generalised dot product of a and b. The result shape consists of the batching dimensions, followed by the remaining (free) dimensions of a, then those of b.

fun dot_general_fun [R: %tensor.Ring] [x: R#T, [y: R#T, z: R#T]]@tt: R#T = return (R#add (x, (R#mul (y, z))));
lam dot_general_pick {n: Nat} {na nb nc: Nat} (a: «na; Idx n», b: «nb; Idx n», c: «nc; Idx n») (off_1 off_2: Nat) (i: Idx n): Nat =
let (ab, ai) = %vec.first (%core.icmp.e @n) (a, i);
let (bb, bi) = %vec.first (%core.icmp.e @n) (b, i);
let (cb, ci) = %vec.first (%core.icmp.e @n) (c, i);
let ai_nat = %core.bitcast Nat ai;
let ai_out = %core.nat.add (off_1, ai_nat);
let bi_nat = %core.bitcast Nat bi;
let ci_nat = %core.bitcast Nat ci;
let ci_out = %core.nat.add (off_2, ci_nat);
((ai_out, ci_out)#cb, bi_nat)#bb;
lam dot_general_shape {r1 r2: Nat} {nc nb: Nat}
(c1: «nc; Idx r1», c2: «nc; Idx r2», b1: «nb; Idx r1», b2: «nb; Idx r2»)
(s1: «r1; Nat», s2: «r2; Nat») =
let bs_check = ‹i: nb; %refly.check (%core.ncmp.e ((s1#(b1#i), s2#(b2#i))), s1#(b1#i), "batching dims don't match")›;
let cs_check = ‹i: nc; %refly.check (%core.ncmp.e ((s1#(c1#i), s2#(c2#i))), s1#(c1#i), "contracting dims don't match")›;
let bs = ‹i: nb; s1#(b1#i)›;
let bc_1 = %vec.cat (b1, c1);
let s1_res = %vec.diff (s1, bc_1);
let bc_2 = %vec.cat (b2, c2);
let s2_res = %vec.diff (s2, bc_2);
let s12_res = %vec.cat (s1_res, s2_res);
let s_out = %vec.cat (bs, s12_res);
s_out;
axm %tensor.dot_product: [R: %tensor.Ring]
→ {r1 r2: Nat}
→ {nc nb: Nat}
→ [c1: «nc; Idx r1», c2: «nc; Idx r2», b1: «nb; Idx r1», b2: «nb; Idx r2»]
→ {s1: «r1; Nat», s2: «r2; Nat»}
→ [a: «s1; R#T», b: «s2; R#T»]
→ «dot_general_shape (c1, c2, b1, b2) (s1, s2); R#T»;
lam %tensor.dot_product_impl
(R: %tensor.Ring) {r1 r2: Nat} {nc nb: Nat}
(c1: «nc; Idx r1», c2: «nc; Idx r2», b1: «nb; Idx r1», b2: «nb; Idx r2»)
{s1: «r1; Nat», s2: «r2; Nat»} (a: «s1; R#T», b: «s2; R#T»)
: let s_out = dot_general_shape (c1, c2, b1, b2) (s1, s2);
«s_out; R#T»
= let s_out = dot_general_shape (c1, c2, b1, b2) (s1, s2);
let bc_1 = %vec.cat (b1, c1);
let s1_res = %vec.diff (s1, bc_1);
let n_s1_res = %vec.len s1_res;
let bc_2 = %vec.cat (b2, c2);
let f = dot_general_fun R;
let r_out = %vec.len s_out;
let a1 = %vec.diff (‹i: r1; i›, bc_1);
let a2 = %vec.diff (‹i: r2; i›, bc_2);
let ein_1 = ‹i: r1; dot_general_pick (a1, b1, c1) (nb, r_out) i›;
let off_ein2 = %core.nat.add (nb, n_s1_res);
let ein_2 = ‹i: r2; dot_general_pick (a2, b2, c2) (off_ein2, r_out) i›;
%tensor.map_reduce @2 s_out (f, R#_0) (ein_1, ein_2) (a, b);
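The index bookkeeping of dot_product_impl is easier to follow with concrete numbers. The Python sketch below mirrors dot_general_shape and dot_general_pick, assuming that `%vec.diff (v, ps)` drops the entries of v at positions ps (which is how it is used above). `dot_general_subs` is a hypothetical helper that returns the output shape and the two subscript lists handed to tensor.map_reduce.

```python
def vec_diff(v, positions):
    """Assumed behaviour of %vec.diff: drop the entries of v at `positions`."""
    drop = set(positions)
    return [x for j, x in enumerate(v) if j not in drop]


def dot_general_shape(c1, c2, b1, b2, s1, s2):
    """Batching dims, then free dims of the left input, then of the right."""
    for x, y in zip(b1, b2):
        if s1[x] != s2[y]:
            raise ValueError("batching dims don't match")
    for x, y in zip(c1, c2):
        if s1[x] != s2[y]:
            raise ValueError("contracting dims don't match")
    bs = [s1[x] for x in b1]
    return bs + vec_diff(s1, list(b1) + list(c1)) + vec_diff(s2, list(b2) + list(c2))


def dot_general_pick(a, b, c, off_1, off_2, i):
    """Einstein index for dimension i of one input (mirrors dot_general_pick)."""
    if i in b:
        return b.index(i)  # batching dims are the leading output dims
    if i in c:
        return off_2 + c.index(i)  # contracting dims: reduction variables >= r_out
    return off_1 + a.index(i)  # free dims follow the batching dims


def dot_general_subs(c1, c2, b1, b2, s1, s2):
    s_out = dot_general_shape(c1, c2, b1, b2, s1, s2)
    r_out, nb = len(s_out), len(b1)
    a1 = vec_diff(range(len(s1)), list(b1) + list(c1))  # free dims of a
    a2 = vec_diff(range(len(s2)), list(b2) + list(c2))  # free dims of b
    n_s1_res = len(a1)
    ein_1 = [dot_general_pick(a1, list(b1), list(c1), nb, r_out, i) for i in range(len(s1))]
    ein_2 = [dot_general_pick(a2, list(b2), list(c2), nb + n_s1_res, r_out, i) for i in range(len(s2))]
    return s_out, ein_1, ein_2
```

For a plain matrix product (c1 = (1,), c2 = (0,), no batching dimensions) this yields the subscripts (0, 2) and (2, 1): variable 2 = r_out is the contracted one, exactly the reduction variable that tensor.map_reduce folds away.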

%tensor.product_2d

Computes the matrix product of two 2-dimensional tensors over the ring R, contracting dimension 1 of t1 with dimension 0 of t2.

axm %tensor.product_2d: [R: %tensor.Ring]
→ {m k l: Nat}
→ [t1: «m, k; R#T», t2: «k, l; R#T»]
→ «m, l; R#T»;
lam %tensor.product_2d_impl (R: %tensor.Ring) {m k l: Nat} (t1: «m, k; R#T», t2: «k, l; R#T»): «m, l; R#T»
= %tensor.dot_product_impl R @(2, 2) (1_2, 0_2, (), ()) @((m, k), (k, l)) (t1, t2);

%tensor.transpose

Permutes the dimensions of input according to permutation: input dimension j becomes output dimension permutation#j.

fun transpose_fun {T:*} [x: T, y: T]@tt: T = return y;
lam transpose_shape {r: Nat} (s: «r; Nat», permutation: «r; Idx r»): «r; Nat» =
let shape_permutation = ‹i: r; (%vec.first (%core.icmp.e @r) (permutation, i))#tt›;
‹i: r; s#(shape_permutation#i)›;
axm %tensor.transpose: {T: *, r: Nat, s: «r; Nat»}
→ [input: «s; T», permutation: «r; Idx r»]
→ «transpose_shape (s, permutation); T»;
lam %tensor.transpose_impl {T: *, r: Nat, s: «r; Nat»}
(input: «s;T», permutation: «r; Idx r»)
: «transpose_shape (s, permutation); T»
= let permutation_nat = ‹i: r; %core.bitcast Nat permutation#i›;
let out_s = transpose_shape (s, permutation);
%tensor.map_reduce @1 @(T, r) out_s @(T, r, s) (transpose_fun @T, ⊥: T) permutation_nat input;
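Here is a Python sketch of transpose_shape and of the element mapping that tensor.transpose_impl sets up by passing permutation as the subscript list of map_reduce; the names are hypothetical. Under this reading, permutation#j is the output position of input dimension j, which appears to be the inverse of the axes argument of NumPy's `numpy.transpose`; for two dimensions both conventions coincide.

```python
def transpose_shape(s, permutation):
    """out_s#i = s#j, where j is the position with permutation#j == i."""
    return [s[permutation.index(i)] for i in range(len(s))]


def transpose_ref(arr, s, permutation):
    """Hypothetical model of %tensor.transpose on nested lists."""
    out_s = transpose_shape(s, permutation)

    def build(prefix):
        if len(prefix) == len(out_s):
            # Input dimension j is read with output index variable permutation[j].
            x = arr
            for j in range(len(s)):
                x = x[prefix[permutation[j]]]
            return x
        return [build(prefix + (k,)) for k in range(out_s[len(prefix)])]

    return build(())
```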

%tensor.transpose_2d

Transposes a 2-dimensional tensor by swapping its two dimensions.

axm %tensor.transpose_2d: {T: *}
→ {s: «2; Nat»}
→ [input: «s; T»]
→ «s#tt, s#ff; T»;
lam %tensor.transpose_2d_impl {T: *} {s: «2; Nat»} (input: «s; T»): «s#tt, s#ff; T» =
%tensor.transpose_impl @(T, 2, s) (input, (tt, ff));

%tensor.broadcast

Expands input to the shape s_out. For every dimension i, s_in#i must equal either s_out#i or 1; dimensions of size 1 are repeated to fit s_out.

axm %tensor.broadcast: {T: *, r: Nat}
→ [s_in: «r; Nat», s_out: «r; Nat», input: «s_in; T»]
→ «s_out; T», normalize_broadcast;
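A minimal Python model of tensor.broadcast on nested lists follows (hypothetical name). It checks the extent rule stated above and repeats size-1 dimensions.

```python
def broadcast_ref(arr, s_in, s_out):
    """Hypothetical model of %tensor.broadcast: size-1 dims are repeated."""
    if len(s_in) != len(s_out):
        raise ValueError("ranks differ")
    for a, b in zip(s_in, s_out):
        if a not in (1, b):
            raise ValueError(f"cannot broadcast extent {a} to {b}")

    def build(x, dim):
        if dim == len(s_out):
            return x
        if s_in[dim] == 1:
            # Reuse the single slice for every position of this dimension.
            return [build(x[0], dim + 1) for _ in range(s_out[dim])]
        return [build(x[k], dim + 1) for k in range(s_out[dim])]

    return build(arr, 0)
```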

%tensor.broadcast_in_dim

Transposes and expands the dimensions of input to fit s_out. Input dimension i is mapped to output dimension index#i; the remaining output dimensions are filled by broadcasting. Todo: We could probably just implement this in terms of tensor.broadcast and tensor.transpose directly?

axm %tensor.broadcast_in_dim: {T: *, r_in r_out: Nat}
→ [s_in: «r_in; Nat», s_out: «r_out; Nat», input: «s_in; T», index: «r_in; Idx r_out»]
→ «s_out; T», normalize_broadcast_in_dim;
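A Python model of tensor.broadcast_in_dim is sketched below under two assumptions not stated above: as in tensor.broadcast, a mapped input dimension of extent 1 may be expanded as well, and index maps no two input dimensions to the same output dimension. The function name is hypothetical.

```python
def broadcast_in_dim_ref(arr, s_in, s_out, index):
    """Hypothetical model of %tensor.broadcast_in_dim on nested lists.

    Input dimension i feeds output dimension index[i]; output dimensions
    not listed in index are filled by broadcasting.
    """
    if len(set(index)) != len(index):
        raise ValueError("index maps two input dims to the same output dim")
    for i, o in enumerate(index):
        if s_in[i] not in (1, s_out[o]):
            raise ValueError(f"extent {s_in[i]} cannot fill output extent {s_out[o]}")

    def element(out_idx):
        x = arr
        for i, o in enumerate(index):
            # Size-1 input dims are read at 0 (assumed broadcasting rule).
            x = x[0 if s_in[i] == 1 else out_idx[o]]
        return x

    def build(prefix):
        if len(prefix) == len(s_out):
            return element(prefix)
        return [build(prefix + (k,)) for k in range(s_out[len(prefix)])]

    return build(())
```

Mapping a vector to output dimension 1 repeats it as rows, mapping it to dimension 0 repeats it as columns, and a non-monotone index such as (1, 0) transposes.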

%tensor.map

Maps a function elementwise over a collection of ni tensors that share the same shape s.

axm %tensor.map: {T: *, ni: Nat, Is: «ni; *»}
→ [app: «i: ni; Is#i» → T]
→ {r: Nat, s: «r; Nat»}
→ [is: «i: ni; «s; Is#i»»]
→ «s; T»;
lam %tensor.map_impl {T: *, ni: Nat, Is: «ni; *»}
(app: «i: ni; Is#i» → T)
{r: Nat, s: «r; Nat»}
(is: «i: ni; «s; Is#i» »)
: «s; T» =
fun app_mr [x: T, y: «i: ni; Is#i»]@tt = return (app y);
%tensor.map_reduce @ni @(T, r) s @(Is, ‹ni; r›, ‹ni; s›) (app_mr, ⊥: T) ‹ni; ‹i:r; %core.bitcast Nat i›› is;
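In Python terms, tensor.map amounts to the sketch below (hypothetical name, nested lists). app receives the tuple of co-located elements, just as map_impl passes it through map_reduce with identity subscripts and no reduction variables, so the fold runs exactly once per output element.

```python
def map_ref(app, inputs, s):
    """Hypothetical model of %tensor.map: app gets the tuple of co-located elements."""

    def build(xs, dim):
        if dim == len(s):
            return app(xs)
        # Step into position k of every input simultaneously.
        return [build(tuple(x[k] for x in xs), dim + 1) for k in range(s[dim])]

    return build(tuple(inputs), 0)
```

tensor.unary and tensor.binary are the ni = 1 and ni = 2 instances of this pattern.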

%tensor.unary

Maps a unary function over a tensor.

axm %tensor.unary: {Ti To: *}
→ [app: Ti → To]
→ {r: Nat, s: «r; Nat»}
→ [i: «s; Ti»]
→ «s; To»;
lam %tensor.unary_impl {Ti To : *} [app: Ti → To] {r: Nat, s: «r; Nat»} (i: «s; Ti»): «s; To» =
%tensor.map_impl @(To, 1, Ti) app @(r, s) i;

%tensor.binary

Maps a binary function over a pair of tensors.

axm %tensor.binary: {Ti1 Ti2 To: *}
→ [app: [Ti1, Ti2] → To]
→ {r: Nat, s: «r; Nat»}
→ [is: [«s; Ti1», «s; Ti2»]]
→ «s; To»;
lam %tensor.binary_impl {Ti1 Ti2 To: *} (app: [Ti1, Ti2] → To) {r: Nat, s: «r; Nat»} (is: [«s; Ti1», «s; Ti2»]): «s; To» =
%tensor.map_impl @(To, 2, (Ti1, Ti2)) app @(r, s) is;

%tensor.select

Maps %core.select elementwise over a Boolean condition tensor and two value tensors of the same shape.

axm %tensor.select: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [is: [«s; Bool», «s; T», «s; T»]]
→ «s; T»;
lam %tensor.select_impl {T: *} {r: Nat, s: «r; Nat»} (is: [«s; Bool», «s; T», «s; T»]): «s; T» =
%tensor.map_impl @(T, 3, (Bool, T, T)) (%core.select @T) @(r, s) is;

Phases

axm %tensor.lower_tensor: %compile.Phase;
axm %tensor.lower_map_reduce: %compile.Phase;
axm %tensor.fuse_tensor: %compile.Phase;