Skip to content

Commit

Permalink
more on mixed precision
Browse files Browse the repository at this point in the history
  • Loading branch information
Rabab53 committed Sep 12, 2024
1 parent 751736e commit 3201258
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 7 deletions.
1 change: 1 addition & 0 deletions example/mixed_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This example shows how to compute kernel matrix and infer the precision per tile
to compute the distance matrix using the Euclidean distance,
and then it calls GammaExponentialKernel on each resulting distance
"""
using Revise
using Dagger
using LinearAlgebra
using KernelFunctions
Expand Down
51 changes: 51 additions & 0 deletions src/array/adapt_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,54 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T}

return collect(DMP)
end


"""
    tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance)

Return the lowest floating-point type (`Float16`, `Float32` or `Float64`) whose
threshold admits the scaled norm of tile `A`.

The square root of the summed `LinearAlgebra.norm_sqr` of the entries of `A` is scaled
by `scalar_factor / global_norm` and compared against `tolerance / eps(P)` for
`P = Float16`, then `P = Float32`. The first precision whose threshold is strictly
greater than the scaled norm is returned; `Float64` is the fallback.

Despite the name, no conversion of `A` is performed here; only the type is returned.

# Arguments
- `A`: the tile, any collection of numbers accepted by `mapreduce`.
- `MP`: currently unused; kept so the signature stays stable for callers.
- `global_norm`: divisor of the scaled tile norm (presumably the norm of the whole
  matrix the tile belongs to; confirm against callers).
- `scalar_factor`: multiplier applied to the tile norm.
- `tolerance`: accuracy threshold used to build the per-precision limits.

# Returns
`Float16`, `Float32` or `Float64`.
"""
function tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance)
    tile_norm = sqrt(mapreduce(LinearAlgebra.norm_sqr, +, A))

    # Computed once and reused for every precision decision below.
    scaled_norm = tile_norm * scalar_factor / global_norm

    # TODO: fp8 (E4M3 and E5M2) support is planned for the near future; the
    # intended check is `scaled_norm < tolerance / 0.0625`, returning `Float8`.
    if scaled_norm < tolerance / eps(Float16)
        return Float16
    elseif scaled_norm < tolerance / eps(Float32)
        return Float32
    else
        return Float64
    end
end


"""
    adapt_precision_and_convert(A::DArray{T,2}, tolerance::T) where {T}

Build a per-tile precision map for the distributed matrix `A` and return it as a
regular matrix with one entry per tile of `A`.

The map starts out filled with `T` and is partitioned into 1x1 blocks so that each
entry can be handed to its own task. One `tile_precision` task is spawned per tile,
receiving the tile (`InOut`), the matching map entry (`Out`), the global norm of `A`,
`max(mt, nt)` as the scaling factor, and `tolerance`.

The tasks are spawned inside a `Dagger.spawn_datadeps` region: the `InOut`/`Out`
annotations are only interpreted there, and the region waits for all spawned tasks to
finish, so the map is complete before it is collected.

NOTE(review): `tile_precision` is defined outside this function; it is presumably
expected to write the selected type into the `Out` chunk. Confirm against its
definition. `tile_precision_and_convert` is not called here.

# Arguments
- `A`: the tiled matrix whose chunks are inspected.
- `tolerance`: accuracy threshold forwarded to every tile task.

# Returns
The collected precision map (`collect(DMP)`).
"""
function adapt_precision_and_convert(A::DArray{T,2}, tolerance::T) where {T}
    Ac = parent(A).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    # One entry per tile, defaulting to the element type of `A`.
    MP = fill(T, mt, nt)
    DMP = view(MP, Blocks(1, 1))
    MPc = DMP.chunks

    scalar_factor = max(mt, nt)

    # The datadeps region gives In/Out/InOut their meaning and blocks until every
    # spawned task has completed, so `collect` below sees the final values.
    Dagger.spawn_datadeps() do
        for m in 1:mt, n in 1:nt
            Dagger.@spawn tile_precision(
                InOut(Ac[m, n]),
                Out(MPc[m, n]),
                global_norm,
                scalar_factor,
                tolerance,
            )
        end
    end

    return collect(DMP)
end
28 changes: 21 additions & 7 deletions src/array/mixchol.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""
    mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)

Run `BLAS.trsm!` in `StoragePrecision` and leave the result in `B`.

If `B` is already a `Matrix{StoragePrecision}`, `BLAS.trsm!` is called directly on `A`
and `B`. Otherwise `B` is converted to `Matrix{StoragePrecision}` (and `A` as well,
unless it already has that type), `alpha` is converted to `StoragePrecision`, the solve
runs on the converted copies, and the result is copied back into `B`.

# Arguments
- `side`, `uplo`, `trans`, `diag`, `alpha`: forwarded to `BLAS.trsm!`.
- `A`: the triangular operand.
- `B`: the right-hand side, overwritten with the result.
- `StoragePrecision`: element type in which the solve is carried out.

# Returns
`B`.
"""
@inline function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
    T = StoragePrecision

    if !(B isa Matrix{T})
        # Work on copies in the storage precision, then write the result back so the
        # caller's `B` keeps its own element type.
        Acopy = A isa Matrix{T} ? A : convert(Matrix{T}, A)
        Bcopy = convert(Matrix{T}, B)
        BLAS.trsm!(side, uplo, trans, diag, T(alpha), Acopy, Bcopy)
        copyto!(B, Bcopy)
        return B
    end

    # `B` already has the storage precision; `A` and `alpha` are passed through as-is.
    BLAS.trsm!(side, uplo, trans, diag, alpha, A, B)
    return B
end
function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
@inline function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
T = StoragePrecision
m, n = size(C)
if typeof(C) != Matrix{T}
if typeof(A) != Matrix{T}
Acopy = convert(Matrix{T}, A)
Expand All @@ -27,11 +31,15 @@ function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
end
Ccopy = convert(Matrix{T}, C)
BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy)
copyto!(C, Ccopy)
return C
end
BLAS.gemm!(transa, transb, alpha, A, B, beta, C)
return C
end
function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
@inline function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
T = StoragePrecision
m, n = size(C)
if typeof(C) != Matrix{T}
if typeof(A) != Matrix{T}
Acopy = convert(Matrix{T}, A)
Expand All @@ -40,10 +48,13 @@ function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
end
Ccopy = convert(Matrix{T}, C)
BLAS.syrk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
copyto!(C, Ccopy)
return C
end
BLAS.syrk!(uplo, trans, alpha, A, beta, C)
return C
end
function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
@inline function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
T = StoragePrecision
if typeof(C) != Matrix{T}
if typeof(A) != Matrix{T}
Expand All @@ -53,10 +64,13 @@ function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
end
Ccopy = convert(Matrix{T}, C)
BLAS.herk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
copyto!(C, Ccopy)
return C
end
BLAS.herk!(uplo, trans, alpha, A, beta, C)
return C
end
function MixedPrecisionChol!(A::DArray{T,2}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T
function MixedPrecisionChol!(A::DMatrix{T}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T
LinearAlgebra.checksquare(A)

zone = one(T)
Expand Down Expand Up @@ -124,7 +138,7 @@ function MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP::Matrix
if iscomplex
Dagger.@spawn mixedherk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
else
Dagger.@spawn mixedherk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
Dagger.@spawn mixedsyrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
end
for n in range(m+1, nt)
Dagger.@spawn mixedgemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
Expand Down

0 comments on commit 3201258

Please sign in to comment.