Skip to content

Commit 0ee0774

Browse files
authored
Handle NCCL.avg correctly (#54)
1 parent 3a49978 commit 0ee0774

File tree

3 files changed

+44
-17
lines changed

3 files changed

+44
-17
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NCCL"
22
uuid = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
3-
version = "0.1.0"
3+
version = "0.1.1"
44

55
[deps]
66
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"

src/base.jl

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ ncclRedOp_t(::typeof(+)) = ncclSum
3333
ncclRedOp_t(::typeof(*)) = ncclProd
3434
ncclRedOp_t(::typeof(max)) = ncclMax
3535
ncclRedOp_t(::typeof(min)) = ncclMin
36+
# Handle the case where the user directly passes in an `ncclRedOp_t` (e.g. `NCCL.avg`)
37+
ncclRedOp_t(x::ncclRedOp_t) = x
3638

3739
"""
3840
NCCL.avg

test/runtests.jl

+41-16
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,49 @@ end
2525
@testset "Allreduce!" begin
2626
devs = CUDA.devices()
2727
comms = NCCL.Communicators(devs)
28-
recvbuf = Vector{CuVector{Float64}}(undef, length(devs))
29-
sendbuf = Vector{CuVector{Float64}}(undef, length(devs))
30-
N = 512
31-
for (ii, dev) in enumerate(devs)
32-
CUDA.device!(ii - 1)
33-
sendbuf[ii] = CuArray(fill(Float64(ii), N))
34-
recvbuf[ii] = CUDA.zeros(Float64, N)
35-
end
36-
NCCL.group() do
37-
for ii in 1:length(devs)
38-
NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii])
28+
29+
@testset "sum" begin
30+
recvbuf = Vector{CuVector{Float64}}(undef, length(devs))
31+
sendbuf = Vector{CuVector{Float64}}(undef, length(devs))
32+
N = 512
33+
for (ii, dev) in enumerate(devs)
34+
CUDA.device!(ii - 1)
35+
sendbuf[ii] = CuArray(fill(Float64(ii), N))
36+
recvbuf[ii] = CUDA.zeros(Float64, N)
37+
end
38+
NCCL.group() do
39+
for ii in 1:length(devs)
40+
NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii])
41+
end
42+
end
43+
answer = sum(1:length(devs))
44+
for (ii, dev) in enumerate(devs)
45+
device!(ii - 1)
46+
crecv = collect(recvbuf[ii])
47+
@test all(crecv .== answer)
3948
end
4049
end
41-
answer = sum(1:length(devs))
42-
for (ii, dev) in enumerate(devs)
43-
device!(ii - 1)
44-
crecv = collect(recvbuf[ii])
45-
@test all(crecv .== answer)
50+
51+
@testset "NCCL.avg" begin
52+
recvbuf = Vector{CuVector{Float64}}(undef, length(devs))
53+
sendbuf = Vector{CuVector{Float64}}(undef, length(devs))
54+
N = 512
55+
for (ii, dev) in enumerate(devs)
56+
CUDA.device!(ii - 1)
57+
sendbuf[ii] = CuArray(fill(Float64(ii), N))
58+
recvbuf[ii] = CUDA.zeros(Float64, N)
59+
end
60+
NCCL.group() do
61+
for ii in 1:length(devs)
62+
NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii])
63+
end
64+
end
65+
answer = sum(1:length(devs)) / length(devs)
66+
for (ii, dev) in enumerate(devs)
67+
device!(ii - 1)
68+
crecv = collect(recvbuf[ii])
69+
@test all(crecv .≈ answer)
70+
end
4671
end
4772
end
4873

0 commit comments

Comments
 (0)