Faster Matrix{BlasFloat} * or \ VecOrMatrix{Dual} #589
Conversation
Review comment on src/dual.jl (outdated diff):
```julia
Base.:*(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
                 UpperTriangular{<:LinearAlgebra.BlasFloat},
                 StridedMatrix{<:LinearAlgebra.BlasFloat}},
        x::StridedVecOrMat{<:Dual}) =
    _map_dual_components(Base.Fix1(lmul!, m), (x, _) -> lmul!(m, x), x)
```
I think it's going to be very hard to avoid ambiguities here.

```julia
julia> rand(3,3) * Dual.(rand(3),0)
ERROR: MethodError: *(::Matrix{Float64}, ::Vector{Dual{Nothing, Float64, 1}}) is ambiguous.
Candidates:
  *(A::StridedMatrix{T}, x::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real}
    @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:49
  *(m::Union{LowerTriangular{var"#s87", S} where {var"#s87"<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:AbstractMatrix{var"#s87"}}, UpperTriangular{var"#s86", S} where {var"#s86"<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:AbstractMatrix{var"#s86"}}, StridedMatrix{<:Union{Float32, Float64, ComplexF32, ComplexF64}}}, x::StridedVecOrMat{<:Dual})
    @ Main REPL[51]:4
```
Attaching the rule to `mul!(Matrix{<:Dual}, ...)` seems less likely to trigger them. Testing with `Test.detect_ambiguities(ForwardDiff)` might be a good idea too.
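For reference, a minimal way to run that check in a test suite (assuming ForwardDiff is loaded in the test environment):

```julia
using Test, ForwardDiff

# detect_ambiguities returns (Method, Method) pairs whose signatures clash;
# an empty list means the new * / \ methods introduce no ambiguities.
ambs = Test.detect_ambiguities(ForwardDiff)
@test isempty(ambs)
```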
> Attaching the rule to `mul!(Matrix{<:Dual}, ...)` seems less likely to trigger them.

The problem is that promotion to `Dual` already happens at `\` (probably for `*` too).
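A small sketch of that promotion point (my example, not from the PR): by the time any `ldiv!`-level method runs, the matrix has already been converted.

```julia
using LinearAlgebra
using ForwardDiff: Dual

A = LowerTriangular(rand(3, 3) + 3I)
x = Dual.(rand(3), 1.0)
# The generic \ promotes A to eltype Dual (copying it) before dispatching to
# the triangular solver, so a rule attached only to mul!/ldiv! with a
# Float64 matrix argument would never be reached.
y = A \ x
```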
BTW, I think LinearAlgebra does something like this trick for some mixed real/complex cases (where the strides work out correctly). Maybe `mul!(ones(4,4).+im, rand(ComplexF64, 4,4), rand(4,4), true, false)` is one? Staying close to the signature used there is probably a way to avoid ambiguities.
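That one-liner does run on current LinearAlgebra; a self-contained expansion of it (mine):

```julia
using LinearAlgebra

C = ones(4, 4) .+ im           # complex destination
A = rand(ComplexF64, 4, 4)     # complex left operand
B = rand(4, 4)                 # real right operand, never promoted to complex
mul!(C, A, B, true, false)     # 5-arg mul!: C = A*B with mixed eltypes
C ≈ A * B
```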
`\` might be easier than `*`, as it doesn't have so many methods. For `*`, promotion to make `C` should happen correctly without this package doing anything, I think.
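A quick illustration of that point (mine, with stock ForwardDiff and none of this PR's methods loaded): the generic `*` already allocates a `Dual`-eltype result.

```julia
using ForwardDiff: Dual

A = rand(3, 3)
x = Dual.(rand(3), 1.0)
y = A * x    # generic matmul promotes the output eltype on its own
typeof(y)    # Vector{Dual{Nothing, Float64, 1}}
```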
> For `*`, promotion to make `C` should happen correctly without this package doing anything, I think.

This is `*(m::Triangular, x::Vector)`.
Now these functions should be fixed and properly tested.
@mcabbott @fredrikekre @devmotion I've refactored the code so that in-place ldiv/mul are also supported. These changes should now be covered by tests with the updated DiffTests package. There are some linalg differences between 1.0 and the current 1.x, so some of the tests are disabled on 1.0; I guess that's the most straightforward way to handle the incompatibilities.
Review comment on src/dual.jl (outdated diff):
```julia
@eval LinearAlgebra.mul!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
    _map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x), y, x)

@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) = mul!(similar(x, (size(m, 1), size(x, 2))), m, x)
```
I've lost track a bit, but since I was passing by:

- Why does this case call `_map_dual_components!` and not just reinterpret `y`?
- Why add a method to `*` here, won't it go to `mul!` anyway?
- Should there be any methods at all for 3-arg `mul!`, or can they all be on 5-arg `mul!` only, as that's eventually called?
@mcabbott Unfortunately, I also start to lose track, but AFAIR I was exploring these possibilities:

- `reinterpret()` would only work for the dual vectors (translating it into a normal matrix multiplication), but for dual matrices there's no representation as a normal matrix linalg operation.
- `*` and `\` methods are required to overload the Base Julia methods that would promote the eltypes of all vectors and matrices to `Dual`, so we need to intercept that code path early.
- `mul!()`: AFAIR there are some 3-arg methods that don't call the 5-arg methods, plus it was evolving from Julia 1.0 to 1.8, so it is hard to come up with the set of methods that is optimal for all the versions. This is a part of the PR that could potentially be polished more, but as the whole infrastructure of `mul!()`/`ldiv!()` methods in LinearAlgebra is nontrivial, I was waiting for the ForwardDiff devs' feedback and approval of the PR in principle before going forward.
Eventually we'll get there!

I do think you can reinterpret dual-mat * mat-or-vec. `reinterpret(Float64, rand(3,4).+im) * rand(4, 5)` makes a 6×5 `Matrix{Float64}` via existing methods. I think you're suggesting that there is no method which makes a 2×3×5 array, but that's just a reshape away. I think that `@less mul!((rand(3).+im), (rand(3,3).+im), rand(3), true, false)` works like this, without ever calling `reinterpret` but just passing pointers.

Agree that some of the other methods here need to catch `*` or `/` directly.

Re 1.0, the only reason not to drop it is not being sure what version of ForwardDiff to call that... I mean it shouldn't break on 1.0, but it's totally fine for 1.0 to miss these fast paths.
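The "dual on the left" case can indeed be checked directly with complex numbers as a stand-in (my expansion of the snippet above):

```julia
using LinearAlgebra

Z = rand(ComplexF64, 3, 4)        # stand-in for a 3×4 matrix of two-component duals
B = rand(4, 5)
Y = reinterpret(Float64, Z) * B   # 6×5 Matrix{Float64}, B never promoted
# Because B is real, this agrees with multiplying first and reinterpreting after:
Y ≈ reinterpret(Float64, Z * B)
```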
@mcabbott BLAS has complex matrix-matrix multiplication (`zgemm()`), so there's no need for reinterpretation there. An L×M matrix `z` of duals with N partials could be represented as an (N+1)×L×M array `Z` of reals. When multiplied by the constant K×L matrix `A`, the result `y = A×z` could be reinterpreted as an (N+1)×K×M array `Y`. My point is that K and L are the middle dimensions of the `Y` and `Z` arrays, respectively. So no combination of reshaping and transposing operations would make these dimensions the first or the last one (so that the result is compatible with multiplication by the matrix `A`). One would need `permutedims`, which involves reshuffling the array elements.
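A hedged sketch of that layout obstacle (all names and sizes are mine, for illustration only):

```julia
using ForwardDiff: Dual

# An L×M matrix of duals with N partials is, in memory, an (N+1)×L×M real
# array, so the contraction dimension L sits in the middle, where gemm
# cannot reach it by reshaping alone.
N, K, L, M = 2, 3, 4, 5
z = [Dual(rand(), rand(), rand()) for _ in 1:L, _ in 1:M]   # L×M duals, N = 2 partials
Z = reshape(collect(reinterpret(Float64, z)), N + 1, L, M)  # (N+1)×L×M reals
A = rand(K, L)
# A * z contracts over L, the middle dimension of Z; exposing L to gemm would
# require permutedims(Z, (2, 1, 3)), i.e. an actual data shuffle.
```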
I've rebased the PR and cleaned up the commit history a bit. For tests to succeed, DiffTests 0.1.3 is required (JuliaDiff/DiffTests.jl#17), which adds tests for the autodiff of `f(x)=M*x` and `f(x)=M\x`.
The calculation of the gradient from a multivariate normal prior involves left-division by a triangular matrix.

Currently, when the covariance matrix is fixed and the random vector depends on the model parameters (the typical use case), this goes through a fallback path in LinearAlgebra: the constant matrix gets promoted to the Dual type (i.e. it is copied each time the gradient is calculated), and then the generic triangular left-division implementation is called. For 100×100 and larger matrices this results in a large CPU and memory overhead.

However, when the matrix is constant, we don't need to convert it to Dual. We just have to left-divide the vector of dual values, as well as each vector of partials, by this matrix. Since these are operations on plain floating-point vectors, LAPACK's trtrs() can be used for much faster division.

This is what this PR does. To avoid excessive copying, it relies on a hack: an array of duals with N partials can be accessed as an array of floats in which each dual occupies N+1 consecutive entries.
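To make the hack concrete, here is a minimal sketch of the vector case; the helper name and shapes are mine, not the PR's actual code. A length-K vector of duals with N partials is, in memory, an (N+1)×K float matrix `X`, and computing `A \ x` elementwise over duals amounts to one right-division of `X` by `transpose(A)`, with `A` staying in Float64 throughout.

```julia
using LinearAlgebra
using ForwardDiff: Dual

# Hypothetical helper illustrating the idea; not the PR's implementation.
function tri_ldiv_sketch(A::LowerTriangular{Float64}, x::Vector{<:Dual})
    K = length(x)
    X = reshape(collect(reinterpret(Float64, x)), :, K)  # (N+1)×K floats
    Y = X / transpose(A)              # one triangular solve, A never promoted
    return collect(reinterpret(eltype(x), vec(Y)))       # view result as duals
end

A = LowerTriangular(rand(3, 3) + 3I)  # diagonally dominant, safely invertible
x = Dual.(rand(3), 1.0)
tri_ldiv_sketch(A, x) ≈ A \ x         # should hold up to floating-point error
```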