Skip to content

Commit

Permalink
refactor: rearrange the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 11, 2024
1 parent 7c715d4 commit 2e891ac
Showing 1 changed file with 36 additions and 45 deletions.
81 changes: 36 additions & 45 deletions test/wrapped_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,18 @@ function bypass_permutedims(x)
return view(x, 2:3, 1:2, :)
end

add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1))

@testset "PermutedDimsArray" begin
x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)
y_ra = @jit(bypass_permutedims(x_ra))
@test @allowscalar(Array(y_ra)) bypass_permutedims(x)

x = rand(4, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(add_perm_dims(x_ra)) add_perm_dims(x)
end

function writeto_reshaped_array!(x)
Expand All @@ -108,76 +115,60 @@ function writeto_reshaped_array!(x)
return z1
end

@testset "writeto_reshaped_array!" begin
x = ConcreteRArray(rand(3, 2))
y = @jit writeto_reshaped_array!(x)
@test all(isone, Array(y))
end

function write_to_transposed_array!(x)
z1 = similar(x)
z2 = transpose(z1)
@. z2 = 1.0
return z1
end

@testset "write_to_transposed_array!" begin
x = ConcreteRArray(rand(3, 2))
y = @jit write_to_transposed_array!(x)
@test all(isone, Array(y))
end

function write_to_adjoint_array!(x)
z1 = similar(x)
z2 = adjoint(z1)
@. z2 = 1.0
return z1
end

@testset "write_to_adjoint_array!" begin
x = ConcreteRArray(rand(3, 2))
y = @jit write_to_adjoint_array!(x)
@test all(isone, Array(y))
end

add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1))

@testset "add_perm_dims" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(add_perm_dims(x_ra)) add_perm_dims(x)
end

function write_to_permuted_dims_array!(x)
z1 = similar(x)
z2 = PermutedDimsArray(z1, (2, 1))
@. z2 = 1.0
return z1
end

@testset "write_to_permuted_dims_array!" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(write_to_permuted_dims_array!(x_ra)) write_to_permuted_dims_array!(x)
end

function write_to_diagonal_array!(x)
z = Diagonal(x)
@. z = 1.0
return z
end

@testset "write_to_diagonal_array!" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)
y_ra = copy(x_ra)

y = @jit(write_to_diagonal_array!(x_ra))
y_res = @allowscalar Array(y)
@test x_ra y_ra
@test all(isone, diag(y_res))
y_res[diagind(y_res)] .= 0
@test all(iszero, y_res)
@testset "Preserve Aliasing with Parent" begin
@testset "$(aType)" for (aType, fn) in [
("ReshapedArray", writeto_reshaped_array!),
("Transpose", write_to_transposed_array!),
("Adjoint", write_to_adjoint_array!),
]
x = ConcreteRArray(rand(3, 2))
y = @jit fn(x)
@test all(isone, Array(y))
end

@testset "PermutedDimsArray" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)
@test @jit(write_to_permuted_dims_array!(x_ra)) write_to_permuted_dims_array!(x)
end

@testset "Diagonal" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)
y_ra = copy(x_ra)

y = @jit(write_to_diagonal_array!(x_ra))
y_res = @allowscalar Array(y)
@test x_ra y_ra
@test all(isone, diag(y_res))
y_res[diagind(y_res)] .= 0
@test all(iszero, y_res)
end
end

0 comments on commit 2e891ac

Please sign in to comment.