|
25 | 25 | @testset "Allreduce!" begin
|
26 | 26 | devs = CUDA.devices()
|
27 | 27 | comms = NCCL.Communicators(devs)
|
28 |
| - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
29 |
| - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
30 |
| - N = 512 |
31 |
| - for (ii, dev) in enumerate(devs) |
32 |
| - CUDA.device!(ii - 1) |
33 |
| - sendbuf[ii] = CuArray(fill(Float64(ii), N)) |
34 |
| - recvbuf[ii] = CUDA.zeros(Float64, N) |
35 |
| - end |
36 |
| - NCCL.group() do |
37 |
| - for ii in 1:length(devs) |
38 |
| - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) |
| 28 | + |
| 29 | + @testset "sum" begin |
| 30 | + recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
| 31 | + sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
| 32 | + N = 512 |
| 33 | + for (ii, dev) in enumerate(devs) |
| 34 | + CUDA.device!(ii - 1) |
| 35 | + sendbuf[ii] = CuArray(fill(Float64(ii), N)) |
| 36 | + recvbuf[ii] = CUDA.zeros(Float64, N) |
| 37 | + end |
| 38 | + NCCL.group() do |
| 39 | + for ii in 1:length(devs) |
| 40 | + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) |
| 41 | + end |
| 42 | + end |
| 43 | + answer = sum(1:length(devs)) |
| 44 | + for (ii, dev) in enumerate(devs) |
| 45 | + device!(ii - 1) |
| 46 | + crecv = collect(recvbuf[ii]) |
| 47 | + @test all(crecv .== answer) |
39 | 48 | end
|
40 | 49 | end
|
41 |
| - answer = sum(1:length(devs)) |
42 |
| - for (ii, dev) in enumerate(devs) |
43 |
| - device!(ii - 1) |
44 |
| - crecv = collect(recvbuf[ii]) |
45 |
| - @test all(crecv .== answer) |
| 50 | + |
| 51 | + @testset "NCCL.avg" begin |
| 52 | + recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
| 53 | + sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) |
| 54 | + N = 512 |
| 55 | + for (ii, dev) in enumerate(devs) |
| 56 | + CUDA.device!(ii - 1) |
| 57 | + sendbuf[ii] = CuArray(fill(Float64(ii), N)) |
| 58 | + recvbuf[ii] = CUDA.zeros(Float64, N) |
| 59 | + end |
| 60 | + NCCL.group() do |
| 61 | + for ii in 1:length(devs) |
| 62 | + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) |
| 63 | + end |
| 64 | + end |
| 65 | + answer = sum(1:length(devs)) / length(devs) |
| 66 | + for (ii, dev) in enumerate(devs) |
| 67 | + device!(ii - 1) |
| 68 | + crecv = collect(recvbuf[ii]) |
| 69 | + @test all(crecv .≈ answer) |
| 70 | + end |
46 | 71 | end
|
47 | 72 | end
|
48 | 73 |
|
|
0 commit comments