
feat: correct handling of wrapped arrays functionalities #342

Merged: 18 commits into main, Dec 11, 2024
Conversation

@avik-pal (Collaborator) commented Dec 8, 2024

fixes #339

The main trick here is to preserve the ReshapedArray type. For the time being, the fix is applied to the following wrapper types:

  1. ReshapedArray
  2. Adjoint
  3. Transpose
  4. Diagonal
  5. PermutedDimsArray

Adds the following LinearAlgebra functionality:

  1. diag
  2. diagm

Updated NNlibExt to handle wrappers correctly
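The wrapper types listed above share a common property in plain Julia: each is a lazy view that keeps its parent alive rather than copying data, which is exactly the structure the tracing has to preserve. A minimal sketch of that behavior (plain Julia, no Reactant involved), including the `diag`/`diagm` semantics the PR mirrors:

```julia
# Plain-Julia illustration of the wrapper types this PR preserves:
# each one is a lazy view that shares memory with its parent, so
# tracing must keep the wrapper instead of materializing the parent.
using LinearAlgebra

A  = [1 2; 3 4]
At = transpose(A)            # Transpose wrapper; shares memory with A
r  = reshape(At, 4)          # Base.ReshapedArray wrapping the Transpose

@assert At isa Transpose && parent(At) === A
@assert r isa Base.ReshapedArray && parent(r) === At
@assert collect(r) == [1, 2, 3, 4]   # column-major walk of At == [1 3; 2 4]

# diag/diagm on wrappers follow the usual LinearAlgebra semantics,
# which the traced versions added in this PR are meant to match:
@assert diag(At) == [1, 4]
@assert diagm([1, 4]) == [1 0; 0 4]
```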

@avik-pal force-pushed the ap/test_failures branch 2 times, most recently from b2d9a9f to 4a3c633, December 8, 2024 16:12
@avik-pal force-pushed the ap/upsampling branch 6 times, most recently from b8cd2d1 to f923616, December 9, 2024 04:41
@avik-pal changed the title from "fix: preserve parent array tracking for reshape" to "feat: correct handling of wrapped arrays functionalities" on Dec 9, 2024
@avik-pal marked this pull request as ready for review December 9, 2024 10:16
@avik-pal linked an issue on Dec 9, 2024 that may be closed by this pull request
@avik-pal requested review from mofeing and wsmoses December 9, 2024 10:29
```julia
res = reduce_window(
    Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
).mlir_data
```

A reviewer (Collaborator) commented:
Maybe this can already be replaced by Ops.maximum?

@avik-pal (Collaborator, Author) replied:
The first argument here is used directly with block arguments. Going through Ops here would add an unnecessary indirection.

@mofeing (Collaborator) commented Dec 9, 2024

It fixes my code!

Using Tenet#master with this PR:

```julia
julia> using Tenet, Reactant, Adapt

julia> a = Tensor([1 0; 1 0], [:j, :i])
2×2 Tensor{Int64, 2, Matrix{Int64}}:
 1  0
 1  0

julia> b = Tensor([0 0; 1 1], [:i, :j])
2×2 Tensor{Int64, 2, Matrix{Int64}}:
 0  0
 1  1

julia> a + b # this is correct
2×2 Tensor{Int64, 2, Matrix{Int64}}:
 1  1
 1  1

julia> parent(a) + parent(b) # this is wrong
2×2 Matrix{Int64}:
 1  0
 2  1
```

Before this PR:

```julia
julia> ar = adapt(ConcreteRArray, a)
2×2 Tensor{Int64, 2, ConcreteRArray{Int64, 2}}:
 1  0
 1  0

julia> br = adapt(ConcreteRArray, b)
2×2 Tensor{Int64, 2, ConcreteRArray{Int64, 2}}:
 0  0
 1  1

julia> @jit ar + br # it's taking the parent array, not transposed
2×2 Tensor{Int64, 2, ConcreteRArray{Int64, 2}}:
 1  0
 2  1
```

After this PR:

```julia
julia> @jit ar + br
2×2 Tensor{Int64, 2, ConcreteRArray{Int64, 2}}:
 1  1
 1  1
```

Base automatically changed from ap/test_failures to main December 10, 2024 06:31
@avik-pal requested a review from mofeing December 10, 2024 07:37
@mofeing (Collaborator) left a comment:
Everything's fine, but would you mind adding some more tests for Ops.reshape? We probably missed the bug you're talking about.
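As a rough illustration of what such tests could cover, here is a plain-Julia sketch of the reference semantics a traced reshape should reproduce (round-tripping shapes and reshaping through a wrapper); the Reactant test harness around it is assumed, not shown:

```julia
# Plain-Julia reference semantics for reshape tests (hypothetical sketch;
# a traced Ops.reshape would be checked against these results).
x = collect(reshape(1:24, 2, 3, 4))

# Round-tripping through other shapes must preserve column-major order.
@assert vec(reshape(x, 4, 6)) == vec(x)
@assert reshape(reshape(x, 6, 4), 2, 3, 4) == x

# Reshaping a wrapper must keep the parent chain intact.
t = reshape(PermutedDimsArray(x, (3, 1, 2)), 24)
@assert t isa Base.ReshapedArray
@assert parent(parent(t)) === x
```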

Comment on lines +119 to +123:

```julia
res = MLIR.IR.result(
    MLIR.Dialects.stablehlo.dynamic_gather(
        get_mlir_data(y), idxs, slice_sizes; dimension_numbers
    ),
    1,
)
```
A reviewer (Collaborator) commented:
Ahhh, I didn't think about using dynamic_gather in this way. I was worried we would need a dynamic_slice for each diagonal element. Nice!
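The idea being praised here can be sketched in plain Julia: compute all the (i, i) start indices for the diagonal up front, then fetch them in one batched lookup instead of issuing one slice per element. This is an illustration only; the actual lowering uses stablehlo.dynamic_gather, and `diag_via_gather` is a hypothetical helper name:

```julia
# Hypothetical plain-Julia sketch of the gather-based diagonal extraction:
# build all (i, i) start indices once, then do a single batched fetch,
# rather than one dynamic_slice per diagonal element.
function diag_via_gather(A::AbstractMatrix)
    n = min(size(A)...)
    starts = [(i, i) for i in 1:n]         # one start index per gathered slice
    return [A[i, j] for (i, j) in starts]  # batched gather over all starts
end

@assert diag_via_gather([1 2; 3 4]) == [1, 4]
@assert diag_via_gather([1 2 3; 4 5 6]) == [1, 5]
```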

@avik-pal (Collaborator, Author) commented:
The only additional CI failure will be resolved by #362

@avik-pal merged commit 814e9c0 into main Dec 11, 2024 (18 of 35 checks passed)
@avik-pal deleted the ap/upsampling branch December 11, 2024 03:46
Development

Successfully merging this pull request may close these issues.

Incorrect traced code for upsampling functions
2 participants