A convenient way of applying functions to tensors. einfunc is incredibly similar to einsum, with one key difference: it lets you apply your own custom function instead of multiplication. einfunc also lets you choose how reductions occur within the operation.
einfunc is a simple interface built on PyTorch's torchdim. I highly recommend checking out torchdim, as it is by far the best way to write readable tensor operations in PyTorch. einfunc is just a convenient way of tapping into torchdim through a function similar to einsum.
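A loop-level sketch of that difference, assuming einfunc follows einsum's indexing rules: for the pattern 'i j, j k -> i k', einsum multiplies x[i][j] by y[j][k] and sums over j, while einfunc applies your function in place of the multiplication before reducing. The helper einsum_like_loops is hypothetical and uses plain nested lists so it runs without PyTorch. It always reduces with a sum.

```python
import operator


def einsum_like_loops(x, y, func=operator.mul):
    """Hypothetical loop version of 'i j, j k -> i k' with a sum over j.

    With func=operator.mul this is an ordinary matrix product (what einsum
    computes); any other two-argument func takes the place of the multiplication.
    """
    num_i, num_j, num_k = len(x), len(y), len(y[0])
    return [
        [sum(func(x[i][j], y[j][k]) for j in range(num_j)) for k in range(num_k)]
        for i in range(num_i)
    ]


x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
print(einsum_like_loops(x, y))                           # [[19, 22], [43, 50]]
print(einsum_like_loops(x, y, lambda a, b: abs(a - b)))  # [[9, 11], [5, 7]]
```

With einfunc itself, the two cases should correspond to einfunc(x, y, 'i j, j k -> i k', lambda a, b : a * b) and einfunc(x, y, 'i j, j k -> i k', lambda a, b : torch.abs(a - b)), both using the default 'sum' reduction.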
Table of Contents
- Installation
- API
- Why you shouldn't use einfunc
- Why you should use einfunc
- Additional Examples
- Planned Work
- Acknowledgements
Installation
einfunc requires torch >= 2.0 and Python >= 3.8.
pip install einfunc
API
Using einfunc is similar to using einsum, but you also pass a function and a reduction mode. Take this math equation, for example.
We can use einfunc to represent this equation in 2 lines of code:
inner_exp = einfunc(x, y, 'b, b k -> k', lambda a, b : a ** 2 - torch.exp(b), reduce='prod')
final_exp = einfunc(z, inner_exp, 'k, k -> ', lambda a, b : torch.log(a) / b, reduce='mean')
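Reading the two calls back into notation, assuming einsum-style indexing (b is reduced with a product in the first call, k is reduced with a mean in the second, and K is the size of the k dimension), they compute:

```latex
\mathrm{inner}_k = \prod_{b} \left( x_b^{2} - e^{y_{bk}} \right),
\qquad
\mathrm{final} = \frac{1}{K} \sum_{k=1}^{K} \frac{\log z_k}{\mathrm{inner}_k}
= \frac{1}{K} \sum_{k=1}^{K} \frac{\log z_k}{\prod_{b} \left( x_b^{2} - e^{y_{bk}} \right)}
```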
While lambda functions are simple, any function works as long as it takes the correct number of inputs (one per tensor passed to einfunc). For example, in the following expression:
inner_exp = einfunc(x, y, 'i, i j -> j', lambda a, b : a ** 2 - torch.exp(b), reduce='prod')
x maps to a and y maps to b. In other words, the function's arguments are bound in the same order that the tensors are passed to einfunc.
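To illustrate that any callable works, the loop version below swaps the lambda for a named function and spells out the positional binding for 'i, i j -> j' with reduce='prod'. inner_term and inner_exp_loops are hypothetical names; the code uses plain lists and math.exp so it runs without PyTorch. The README's own examples call torch.exp and torch.log, so a function passed to einfunc itself should use torch operations instead.

```python
import math


def inner_term(a, b):
    """Named replacement for the lambda: a comes from x, b comes from y."""
    return a ** 2 - math.exp(b)


def inner_exp_loops(x, y, func):
    """Hypothetical loop version of einfunc(x, y, 'i, i j -> j', func, reduce='prod').

    x is a list of length I and y is an I x J nested list. func always receives
    the element of x first and the element of y second, mirroring argument order.
    """
    num_i, num_j = len(x), len(y[0])
    return [math.prod(func(x[i], y[i][j]) for i in range(num_i)) for j in range(num_j)]


x = [2.0, 3.0]
y = [[0.0, 0.0], [0.0, 0.0]]
print(inner_exp_loops(x, y, inner_term))  # [24.0, 24.0]
```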
Currently, einfunc supports 5 different types of reduction.
- Mean
result = einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce='mean')
- Sum
result = einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce='sum')
- Prod
result = einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce='prod')
- Max
result = einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce='max')
- Min
result = einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce='min')
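Under an einsum-style reading, i is the only index missing from the output in these examples, so result[j][k] is the chosen reduction of x[i][j] - y[k][i] over i. The hypothetical reduce_pattern helper below writes that out with plain loops for all five modes; it illustrates the semantics and is not einfunc's implementation.

```python
import math
import statistics

# Hypothetical mapping from the reduce= names above to Python reducers.
REDUCTIONS = {
    "mean": statistics.fmean,
    "sum": sum,
    "prod": math.prod,
    "max": max,
    "min": min,
}


def reduce_pattern(x, y, reduce):
    """Loop version of einfunc(x, y, 'i j, k i -> j k', lambda a, b : a - b, reduce=reduce).

    x has shape [I][J] and y has shape [K][I]; the shared index i is reduced,
    leaving an output of shape [J][K].
    """
    reducer = REDUCTIONS[reduce]
    num_i, num_j, num_k = len(x), len(x[0]), len(y)
    return [
        [reducer(x[i][j] - y[k][i] for i in range(num_i)) for k in range(num_k)]
        for j in range(num_j)
    ]


x = [[1, 2], [3, 4]]
y = [[0, 1]]
for mode in REDUCTIONS:
    print(mode, reduce_pattern(x, y, mode))
```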
One thing to note is that if reduce is not passed, then 'sum' is assumed by einfunc.
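Putting the pieces of this section together, the sketch below spells out the described semantics with plain Python loops: indices that appear left of -> but not on the right are reduced, the function receives one element per input tensor in argument order, and reduce defaults to 'sum'. It assumes einsum-style rules; naive_einfunc, shape_of, lookup and REDUCTIONS are hypothetical names, nested lists stand in for tensors, and this is not how einfunc or torchdim are implemented.

```python
import itertools
import math
import statistics

# Hypothetical reducer table mirroring the reduce= options in this README.
REDUCTIONS = {"sum": sum, "prod": math.prod, "mean": statistics.fmean, "max": max, "min": min}


def shape_of(tensor):
    """Return the shape of a rectangular nested list (a scalar has shape [])."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return shape


def lookup(tensor, index):
    """Index a nested list with a sequence of integers."""
    for i in index:
        tensor = tensor[i]
    return tensor


def naive_einfunc(*tensors, pattern, func, reduce="sum"):
    """Loop-based sketch of what an einfunc call computes (hypothetical helper)."""
    lhs, rhs = pattern.split("->")
    input_specs = [spec.split() for spec in lhs.split(",")]
    output_spec = rhs.split()

    # Record the size of every named index and check that repeated names agree.
    sizes = {}
    for spec, tensor in zip(input_specs, tensors):
        for name, size in zip(spec, shape_of(tensor)):
            if sizes.setdefault(name, size) != size:
                raise ValueError(f"size mismatch for index {name!r}")

    # Indices that appear only on the left-hand side are reduced away.
    reduced = [name for name in sizes if name not in output_spec]
    reducer = REDUCTIONS[reduce]

    def cell(out_index):
        # Apply func at every combination of reduced indices, then reduce.
        values = []
        for red_index in itertools.product(*(range(sizes[n]) for n in reduced)):
            env = {**dict(zip(output_spec, out_index)), **dict(zip(reduced, red_index))}
            args = [lookup(t, [env[n] for n in spec]) for spec, t in zip(input_specs, tensors)]
            values.append(func(*args))
        return reducer(values)

    def build(prefix):
        # Recursively build the nested-list output, one output index at a time.
        depth = len(prefix)
        if depth == len(output_spec):
            return cell(prefix)
        return [build(prefix + (i,)) for i in range(sizes[output_spec[depth]])]

    return build(())


# No reduce passed, so 'sum' is used: a vector-matrix product.
print(naive_einfunc([1, 2, 3], [[1, 0], [0, 1], [1, 1]], pattern="i, i j -> j", func=lambda a, b: a * b))
```

An empty right-hand side such as 'k, k -> ' reduces every index and returns a single number, matching the shape of the second call in the equation example.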
Why you shouldn't use einfunc
einfunc is just a convenient way of interfacing with PyTorch and torchdim. This adds some overhead compared to vanilla PyTorch operations and direct torchdim operations. For a simple operation it will be faster to use vanilla PyTorch, and for something more complex you are better off using torchdim directly.
Why you should use einfunc
It's convenient and, IMO, slightly more readable than torchdim. Understanding exactly what is happening in an operation can be hard, and einfunc makes it a lot simpler by boiling each operation down to a single expression that uses Einstein notation to indicate indexing.
Additional Examples
Coming Soon :)
Planned Work
Currently, einfunc does not support parentheses or ellipses. I will be working on implementing them as soon as I can.