Caching of `x`, `y = f(x)`, and `log|det(J)|`
In this release, we add caching of intermediate values for `Bijector`s.
This means you can often reduce computation by calculating `log|det(J)|` at the same time as `y = f(x)`. It's also useful for performing variational inference on `Bijector`s that don't have an explicit inverse. The mechanism by which this is achieved is a subclass of `torch.Tensor` called `BijectiveTensor` that bundles together `(x, y, context, bijector, log_det_J)`.
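For illustration, here is a minimal sketch of the caching idea, not the library's actual implementation: a `torch.Tensor` subclass that remembers the input, bijector, and log-determinant that produced it, so the inverse and log-determinant lookups become cache hits. The `ExpBijector` class and the `from_forward` helper are hypothetical names invented for this example.

```python
import torch


class BijectiveTensor(torch.Tensor):
    """Sketch: a Tensor that remembers how it was produced."""

    @staticmethod
    def from_forward(y, x, bijector, log_det_J, context=None):
        # Re-wrap the output tensor and attach the values that were
        # computed anyway during the forward pass.
        out = y.as_subclass(BijectiveTensor)
        out._x, out._bijector = x, bijector
        out._log_det_J, out._context = log_det_J, context
        return out


class ExpBijector:
    """Hypothetical bijector y = exp(x), so log|det(J)| = sum(x)."""

    def forward(self, x):
        y = torch.exp(x)
        log_det_J = x.sum(-1)  # computed for free alongside y
        return BijectiveTensor.from_forward(y, x, self, log_det_J)

    def inverse(self, y):
        # Cache hit: recover x exactly, without evaluating log(y).
        if isinstance(y, BijectiveTensor) and y._bijector is self:
            return y._x
        return torch.log(y)

    def log_abs_det_jacobian(self, x, y):
        # Cache hit: reuse the log-determinant computed with y = f(x).
        if isinstance(y, BijectiveTensor) and y._bijector is self:
            return y._log_det_J
        return x.sum(-1)


bijector = ExpBijector()
x = torch.randn(3)
y = bijector.forward(x)                    # y caches (x, log_det_J)
ldj = bijector.log_abs_det_jacobian(x, y)  # no recomputation
x_back = bijector.inverse(y)               # exact x, no inverse pass
```

Because the cache travels with the output tensor itself, downstream code doesn't need to change: anything that receives `y` can check for the cached values and fall back to recomputation otherwise.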
Special shout-out to @vmoens for coming up with this neat solution and taking the implementation lead! Looking forward to your future contributions 🥳