Skip to content

Commit

Permalink
Merge pull request #64 from ayush1999/dev_gemm
Browse files Browse the repository at this point in the history
gemm! fix for julia0.7 and julia1.0
  • Loading branch information
MikeInnes authored Sep 10, 2018
2 parents d6aaa81 + a91deb6 commit 5954d19
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ for (gemm, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32))
if transA=='N'; lda=M; else; lda=K; end
if transB=='N'; ldb=K; else; ldb=N; end
ldc = M;
ccall((@blasfunc(dgemm_), libblas), Nothing,
ccall((@blasfunc($(gemm)), libblas), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}),
transA, transB, M, N, K,
alpha, A, lda, B, ldb, beta, C, ldc)
Expand Down
32 changes: 32 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
49 99 149;
59 109 159.]

@test dropdims(conv(Float32.(x), Float32.(w)), dims=(3,4)) == Float32.([
29 79 129;
39 89 139;
49 99 149;
59 109 159.])

@test dropdims(conv(x, w; stride=2), dims = (3,4)) == [
29 129;
49 149.]

@test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4)) == Float32.([
29 129;
49 149.])

@test dropdims(conv(x, w; pad=1), dims = (3,4)) == [
1.0 9.0 29.0 49.0 48.0;
4.0 29.0 79.0 129.0 115.0;
Expand All @@ -23,6 +33,15 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
10.0 40.0 70.0 100.0 80.0
]

@test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (3,4)) == Float32.([
1.0 9.0 29.0 49.0 48.0;
4.0 29.0 79.0 129.0 115.0;
7.0 39.0 89.0 139.0 122.0;
10.0 49.0 99.0 149.0 129.0;
13.0 59.0 109.0 159.0 136.0;
10.0 40.0 70.0 100.0 80.0
])

@test dropdims(conv(x, w; dilation=2), dims = (3,4)) == [
48 98;
58 108;
Expand Down Expand Up @@ -151,10 +170,16 @@ end
1150.0 1330.0 1510.0]
@test dropdims(conv(x, w), dims = (4,5)) == res

@test dropdims(conv(Float32.(x), Float32.(w)), dims = (4,5)) == Float32.(res)

@test dropdims(conv(x, w; stride=2), dims = (3,4,5)) == [
322.0 682.0;
394.0 754.0]

@test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4,5)) == Float32.([
322.0 682.0;
394.0 754.0])

res = zeros(6,5,4)
res[:, :, 1] = [
1.0 9.0 29.0 49.0 48.0;
Expand Down Expand Up @@ -186,12 +211,19 @@ end
270.0 660.0 730.0 800.0 480.0]
@test dropdims(conv(x, w; pad=1), dims = (4,5)) == res

@test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (4,5)) == Float32.(res)

@test dropdims(conv(x, w; dilation=2), dims = (3,4,5)) == [
608 788;
644 824;
680 860.
]

@test dropdims(conv(Float32.(x), Float32.(w); dilation=2), dims = (3,4,5)) == Float32.([
608 788;
644 824;
680 860.
])
# NaN tests for dilation forward pass

ys = []
Expand Down

0 comments on commit 5954d19

Please sign in to comment.