
MurrellGroup/Einops.jl



Einops.jl


Einops.jl brings the readable and concise tensor operations of einops to Julia, reliably expanding to existing primitives like reshape, permutedims, and repeat.

Einops vs Base primitives

Einops uses patterns with explicitly named dimensions. These can be written with the `einops` string macro: for example, `einops"a b -> (b a)"` expands to `(:a, :b) --> ((:b, :a),)`. Here `-->` is a custom operator that stores its left and right operands as type parameters of a special pattern type, so generated functions can see the whole pattern at compile time and compose clean expressions from it.
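The type-parameter idea can be illustrated with a small, self-contained sketch. This is not Einops.jl's implementation: `Pattern`, `pattern`, and `permute_sketch` are hypothetical names, a plain constructor stands in for the `-->` operator, and only pure permutations (no groups, no `...`) are handled. Because the dimension names are type parameters, the `@generated` body works on the pattern itself and returns either a single `permutedims` call or, for an identity pattern, just `x`.

```julia
# Hypothetical pattern type: dimension names live in the type parameters,
# not in fields, so a generated function can read them.
struct Pattern{L,R} end

# Hypothetical constructor standing in for Einops.jl's `-->` operator.
pattern(L::Tuple{Vararg{Symbol}}, R::Tuple{Vararg{Symbol}}) = Pattern{L,R}()

# Pure permutations only, to keep the sketch short. The body runs once per new
# (array type, pattern) combination and returns the expression that becomes
# the compiled method body.
@generated function permute_sketch(x::AbstractArray{<:Any,N}, ::Pattern{L,R}) where {N,L,R}
    if length(L) != N || length(R) != N
        return :(throw(DimensionMismatch("pattern rank does not match ndims(x)")))
    end
    # Position of each right-hand name on the left-hand side
    perm = Tuple(findfirst(==(s), L) for s in R)
    if any(isnothing, perm) || !isperm(perm)
        return :(throw(ArgumentError("right-hand side must permute the left-hand side")))
    end
    # Identity permutation: emit no call at all, so the no-op is skipped.
    perm == ntuple(identity, N) && return :(x)
    return :(permutedims(x, $perm))
end
```

For example, `permute_sketch(x, pattern((:a, :b, :c), (:b, :a, :c)))` runs `permutedims(x, (2, 1, 3))`, while `pattern((:a, :b, :c), (:a, :b, :c))` returns `x` itself without copying.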

The table below shows identical transformations written first with Einops and then with hand-rolled Julia primitives. Einops collapses chains of `reshape`, `permutedims`, `dropdims`, or `repeat` calls into a single declarative statement, while still expanding to those primitives under the hood and skipping calls that would be no-ops.
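The grouping convention in these patterns follows Julia's column-major layout, as the first and third rows of the table imply: inside a group such as `(a b)`, the first name varies fastest. The following Base-only sketch (no packages needed) spells out both patterns as explicit loops and checks them against the listed primitives; the array sizes are arbitrary example values.

```julia
# Base-only check of the column-major grouping convention
x = reshape(collect(1:24), 2, 3, 4)   # dimensions a = 2, b = 3, c = 4
A, B, C = size(x)

# "a b c -> (a b) c": inside (a b), `a` varies fastest
flat = similar(x, A * B, C)
for c in 1:C, b in 1:B, a in 1:A
    flat[a + (b - 1) * A, c] = x[a, b, c]
end
@assert flat == reshape(x, :, size(x, 3))

# "a b -> (b a)" on a matrix: inside (b a), `b` varies fastest
m = reshape(collect(1:6), 2, 3)       # dimensions a = 2, b = 3
v = similar(m, length(m))
for a in 1:2, b in 1:3
    v[b + (a - 1) * 3] = m[a, b]
end
@assert v == vec(permutedims(m))
```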

| Description | Einops | Base primitives |
|---|---|---|
| Flatten the first two dimensions | `rearrange(x, einops"a b c -> (a b) c")` | `reshape(x, :, size(x, 3))` |
| Permute the first two dimensions | `rearrange(x, einops"a b c -> b a c")` | `permutedims(x, (2, 1, 3))` |
| Permute and flatten | `rearrange(x, einops"a b -> (b a)")` | `vec(permutedims(x))` |
| Remove a leading singleton dimension | `rearrange(x, einops"1 ... -> ...")` | `dropdims(x, dims=1)` |
| Funky repeat | `repeat(x, einops"... -> 2 ... 3")` | `repeat(reshape(x, 1, size(x)...), 2, ntuple(Returns(1), ndims(x))..., 3)` |
| Multi-Head Attention | `rearrange(q, einops"(d h) l b -> d l (h b)"; d=head_dim)` | `reshape(permutedims(reshape(q, head_dim, :, size(q)[2:3]...), (1, 3, 2, 4)), head_dim, size(q, 2), :)` |
| Grouped-Query Attention | `repeat(k, einops"(d h) l b -> d l (r h b)"; d=head_dim, r=repeats)` | `reshape(repeat(permutedims(reshape(k, head_dim, :, size(k)[2:3]...), (1, 3, 2, 4)), inner=(1, 1, repeats, 1)), head_dim, size(k, 2), :)` |
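The two attention rows can be read as explicit loops over named indices. The sketch below does this in plain Julia, without Einops.jl, assuming the column-major convention that the first name in a group varies fastest, and compares each loop with the Base expression from the table. The sizes (`head_dim`, `heads`, `seqlen`, `batch`, `repeats`) are hypothetical example values chosen only for the check.

```julia
# Hypothetical example sizes, chosen only for the check
head_dim, heads, seqlen, batch, repeats = 4, 3, 5, 2, 2

# Multi-Head Attention: "(d h) l b -> d l (h b)"
q = rand(head_dim * heads, seqlen, batch)
mha = similar(q, head_dim, seqlen, heads * batch)
for b in 1:batch, l in 1:seqlen, h in 1:heads, d in 1:head_dim
    # (d h) is split with d fastest; (h b) is merged with h fastest
    mha[d, l, h + (b - 1) * heads] = q[d + (h - 1) * head_dim, l, b]
end
mha_base = reshape(
    permutedims(reshape(q, head_dim, :, size(q)[2:3]...), (1, 3, 2, 4)),
    head_dim, size(q, 2), :)
@assert mha == mha_base

# Grouped-Query Attention: "(d h) l b -> d l (r h b)"
k = rand(head_dim * heads, seqlen, batch)
gqa = similar(k, head_dim, seqlen, repeats * heads * batch)
for b in 1:batch, l in 1:seqlen, h in 1:heads, r in 1:repeats, d in 1:head_dim
    # r is the fastest index of (r h b), so each head appears `repeats` times in a row
    gqa[d, l, r + (h - 1) * repeats + (b - 1) * repeats * heads] =
        k[d + (h - 1) * head_dim, l, b]
end
gqa_base = reshape(
    repeat(permutedims(reshape(k, head_dim, :, size(k)[2:3]...), (1, 3, 2, 4)),
           inner=(1, 1, repeats, 1)),
    head_dim, size(k, 2), :)
@assert gqa == gqa_base
```

Both checks are exact equality comparisons, which is safe here because every operation only moves or copies elements.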

Contributing

Contributions are welcome! Please feel free to open an issue to report a bug or start a discussion.