Reshape of a view breaks batched_mul on GPU (scalar indexing happens) #466
I think it's taking the generic path here: NNlib.jl/src/batched/batchedmul.jl line 219, at commit 7f6ea50, because

```julia
julia> summary(v6)
"4×3×8 reshape(view(::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, :, 1, :, :), 4, 3, 8) with eltype Float32"

julia> NNlib.is_strided(v6)  # NNlib's test for whether to use BLAS
false

julia> strides(v6)  # maybe this is OK?
(1, 8, 24)

julia> strides(copy(v6))
(1, 4, 12)
```

That test is intended to match what CUBLAS accepts here: https://github.com/FluxML/NNlibCUDA.jl/blob/master/src/batchedmul.jl#L3-L4, and in this case it doesn't seem to accept it:

```julia
julia> CUDA.CUBLAS.gemm_strided_batched!('N', 'N', true, v4, A3, false, v4 ⊠ A3) |> summary
"4×3×8 CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}"

julia> CUDA.CUBLAS.gemm_strided_batched!('N', 'N', true, v6, A3, false, v4 ⊠ A3) |> summary
ERROR: conversion to pointer not defined for Base.ReshapedArray{Float32, 3, SubArray{Float32, 3, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] unsafe_convert(#unused#::Type{CuPtr{Float32}}, a::Base.ReshapedArray{Float32, 3, SubArray{Float32, 3, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}})
   @ CUDA ~/.julia/packages/CUDA/Ey3w2/src/pointer.jl:64
```

It might be possible to widen this... a very quick hack seems to run:

```julia
julia> Base.unsafe_convert(::Type{CuPtr{Float32}}, a::typeof(v6)) = Base.unsafe_convert(CuPtr{Float32}, parent(parent(v6)))

julia> CUDA.CUBLAS.gemm_strided_batched!('N', 'N', true, v6, A3, false, v4 ⊠ A3) |> summary
"4×3×8 CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}"
```

But I have not checked correctness, and in reality the method would need to work only for reshapes which are in fact sufficiently simple. After updating CUDA to allow this, the test used by NNlib could be widened to match.
I think CUDA.jl is a little gun-shy about accommodating too many wrappers after JuliaGPU/CUDA.jl#453. Worth a shot though.
This multi-wrapper case is a little easier, in that creating a pointer must either succeed or fail, and CUDA certainly owns the `CuPtr` conversion. The worse case for multiple wrappers is where Base has a generic fallback which appears to work, but only via slow scalar indexing.
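To make that concrete, here is a sketch, purely for illustration, of a pointer conversion restricted to "sufficiently simple" reshapes; the helper name `strided_cuptr` and the offset computation are assumptions, not CUDA.jl API:

```julia
using CUDA

# Hedged sketch: only take a device pointer from a reshape-of-a-view when the
# wrapper still has a regular strided layout, and point at the view's first
# element rather than the start of the buffer. (The quick hack above skips the
# offset, which happens to be zero for view(x, :, 1, :, :) but not in general.)
function strided_cuptr(a::Base.ReshapedArray{T}) where {T}
    strides(a)                        # throws if this reshape has no regular layout
    v = parent(a)
    v isa SubArray || error("this sketch only handles reshapes of views")
    buf = parent(v)
    buf isa CuArray || error("this sketch expects a CuArray underneath")
    # linear index of the view's first element inside the parent buffer
    i1 = LinearIndices(buf)[map(first, parentindices(v))...]
    return pointer(buf, i1)           # assuming pointer(::CuArray, i) offsets like Base.pointer(::Array, i)
end
```

Restricting it this way keeps the succeed-or-fail-loudly property: anything the check cannot prove simple still errors instead of quietly taking a slow path.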
Useful Info:
NNlib v0.8.16
I actually need this to implement very efficient multi-head attention. I am sharing a minimal example to reproduce the issue below:
```julia
using Flux # to get batched_mul
```
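A self-contained sketch of this kind of reproduction; the parent array `x`, its sizes, and the `1` index are assumptions pieced together from the summaries in the comment above, so treat it as illustrative rather than the original script:

```julia
using Flux, CUDA                     # Flux pulls in batched_mul, as in the line above
CUDA.allowscalar(false)              # turn accidental scalar indexing into an error

x  = CUDA.rand(Float32, 4, 2, 3, 8)  # assumed parent shape
A3 = CUDA.rand(Float32, 3, 3, 8)

v6 = reshape(view(x, :, 1, :, :), 4, 3, 8)  # reshape of a view, as in the issue title
C  = batched_mul(v6, A3)             # is_strided(v6) is false, so this hits the generic
                                     # fallback and scalar-indexes on the GPU
```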
Error for reference: