This package is unique in that it can differentiate vector-valued expressions in Einstein notation. However, if you only need gradients of scalar-valued functions (which is typical in machine learning), please use XGrad.jl instead. XGrad.jl is a re-thought and stabilized version of this package, adding many useful features in place of the (rarely used) derivatives of vector-valued functions. If you nevertheless want to continue using XDiff.jl, please pin Espresso.jl to version `v3.0.0`, which is the last version supporting Einstein notation.
XDiff.jl is an expression differentiation package that supports a fully symbolic approach to finding tensor derivatives. Unlike automatic differentiation packages, XDiff.jl can output not only ready-to-use derivative functions, but also their symbolic expressions, suitable for further optimization and code generation.
`xdiff` takes an expression and a set of "example values" and returns another expression that calculates the value together with derivatives of the output variable w.r.t. each argument. Example values are anything similar to the expected data, i.e. with the same data type and size.
In the example below we want `w` and `x` to be vectors of size (3,), while `b` is a scalar.
```julia
# expressions after a semicolon are "example values" - something that looks like expected data
xdiff(:(y = sum(w .* x) + b); w=rand(3), x=rand(3), b=rand())
# quote
#     dy!dx = @get_or_create(mem, :dy!dx, zeros(Float64, (3,)))
#     dy!dw = @get_or_create(mem, :dy!dw, zeros(Float64, (3,)))
#     y = @get_or_create(mem, :y, zero(Float64))
#     tmp658 = @get_or_create(mem, :tmp658, zero(Float64))
#     dy!dtmp658 = @get_or_create(mem, :dy!dtmp658, zero(Float64))
#     tmp658 = sum(w .* x)
#     y = tmp658 .+ b
#     dy!dtmp658 = 1.0
#     dy!dw .= x
#     dy!dx .= w
#     tmp677 = (y, dy!dw, dy!dx, dy!dtmp658)
# end
```
By default, `xdiff` generates highly optimized code that uses a set of buffers stored in a dictionary `mem`.
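As a rough sketch of how such an expression might be evaluated (this is an assumption, not part of the original docs; it presumes `@get_or_create` from Espresso.jl is in scope, e.g. after `using XDiff`, and that the inputs and `mem` are defined at the top level):

```julia
# A minimal sketch, assuming the returned expression is evaluated at the top
# level with a `mem` dictionary holding the buffers created by @get_or_create.
dex = xdiff(:(y = sum(w .* x) + b); w=rand(3), x=rand(3), b=rand())
mem = Dict{Symbol,Any}()              # buffer storage reused between evaluations
w, x, b = rand(3), rand(3), rand()    # actual data
eval(dex)                             # returns a tuple (y, dy!dw, dy!dx, ...)
```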
You may also generate slower, but more readable expressions using `VectorCodeGen`:
```julia
ctx = Dict(:codegen => VectorCodeGen())
xdiff(:(y = sum(w .* x) + b); ctx=ctx, w=rand(3), x=rand(3), b=rand())
# quote
#     tmp691 = w' * x
#     y = tmp691 + b
#     dy!dtmp691 = 1.0
#     dy!db = 1.0
#     dy!dw = x
#     dy!dx = w
#     tmp698 = (y, dy!dw, dy!dx, dy!db)
# end
```
or in Einstein indexing notation using `EinCodeGen`:
```julia
ctx = Dict(:codegen => EinCodeGen())
xdiff(:(y = sum(w .* x) + b); ctx=ctx, w=rand(3), x=rand(3), b=rand())
# quote
#     tmp700 = w[i] .* x[i]
#     y = tmp700 + b
#     dy!dtmp700 = 1.0
#     dy!db = 1.0
#     dy!dw[j] = dy!dtmp700 .* x[j]
#     dy!dx[j] = dy!dtmp700 .* w[j]
#     tmp707 = (y, dy!dw, dy!dx, dy!db)
# end
```
`xdiff` also provides a convenient interface for generating function derivatives:
```julia
# evaluate using `include("file_with_function.jl")`
f(w, x, b) = sum(w .* x) .+ b

df = xdiff(f; w=rand(3), x=rand(3), b=rand())
df(rand(3), rand(3), rand())
# (0.8922305671741435, [0.936149, 0.80665, 0.189789], [0.735201, 0.000282879, 0.605989], 1.0)
```
Note that `xdiff` will try to extract the function body as it was written, but this doesn't always work smoothly. One common case when the function body isn't available is when the function is defined in the REPL, so for a better experience load functions using `include("<filename>")` or `using <module>`.
- loops are not supported
- conditional branching is not supported
Loops and conditional operators may introduce discontinuity points, potentially resulting in very complex and heavy piecewise expressions, and thus are not supported. However, many such expressions may be rewritten into analytical form. For example, many loops may be rewritten into an aggregation function like `sum()` (already supported), and many conditions may be expressed as multiplication like `f(x) * (x > 0) + g(x) * (x <= 0)` (planned); see the sketch below. Please submit an issue if you are interested in support for a specific feature.
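As a hand-written illustration of such rewrites (the variable names are hypothetical, and differentiation of the conditional form is, as noted, only planned):

```julia
# A loop like
#     s = 0.0
#     for i in 1:length(x)
#         s += w[i] * x[i]
#     end
# can be rewritten as an aggregation that is already supported:
loop_rewritten = :(s = sum(w .* x))

# A branch like `x > 0 ? f(x) : g(x)` can be expressed via multiplication by
# indicator terms (differentiation of this form is planned, not yet supported):
branch_rewritten = :(y = f(x) * (x > 0) + g(x) * (x <= 0))
```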
At a high level, scalar expressions are differentiated as follows:

- The expression is parsed into an `ExGraph` - a set of primitive expressions, mostly assignments and single function calls.
- The resulting `ExGraph` is evaluated using the example values to determine the types and shapes of all variables (forward pass).
- Similar to reverse-mode automatic differentiation, derivatives are propagated backward from the output to the input variables. Unlike AD, however, derivatives aren't represented as values, but as symbolic expressions (see the sketch below).
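The following hand-worked sketch (not the package's actual internals) illustrates the last step for `y = sum(w .* x) + b`: derivatives are built up as expressions rather than numeric values.

```julia
# primitive form of y = sum(w .* x) + b:
#     tmp1 = w .* x
#     tmp2 = sum(tmp1)
#     y    = tmp2 + b
# backward pass over the primitives, producing derivative *expressions*:
derivs = Dict{Symbol,Any}(:y => 1.0)
derivs[:tmp2] = derivs[:y]                  # d(tmp2 + b)/dtmp2 = 1
derivs[:b]    = derivs[:y]                  # d(tmp2 + b)/db    = 1
derivs[:tmp1] = derivs[:tmp2]               # each element of tmp1 contributes 1 to sum
derivs[:w]    = :($(derivs[:tmp1]) .* x)    # d(w .* x)/dw = x, so dy/dw is :(1.0 .* x)
derivs[:x]    = :($(derivs[:tmp1]) .* w)    # dy/dx is :(1.0 .* w)
```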
Tensor expressions go through a very similar pipeline, but operate in Einstein notation:

- The tensor expression is transformed into Einstein notation (illustrated below).
- The expression in Einstein notation is parsed into an `EinGraph` (an indexed variant of `ExGraph`).
- The resulting `EinGraph` is evaluated.
- Partial derivatives are computed using tensor or element-wise rules for each element of each tensor, then propagated from the output to the input variables.
- Optionally, derivative expressions are converted back to vectorized notation.
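As a hand-written illustration of the first step (these expressions are written out here to show the notation, not produced by the package):

```julia
# vectorized form
vec_ex = :(y = W * x)
# the same expression in Einstein notation: the repeated index j implies
# summation over j, i.e. y[i] = sum over j of W[i, j] * x[j]
ein_ex = :(y[i] = W[i, j] * x[j])
```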