-
Notifications
You must be signed in to change notification settings - Fork 17
This adds a file containing Adapt functions for a number of structs. This allows Adapt.adapt to be called with structs to change the underlying data type. #195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…This allows Adapt.adapt to be called with structs to change the underlieng data type. For example Adapt.adapt(CUArray, Split_matrix_A) adapts the split matrix to store the values on an equivalent CUDA storage format while maintaing the SplitMatrix wrapper. Not all structs of the PartititionedArrays module are implemented yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR. It looks good for the moment but we need some tests.
Please add some tests. Use the FakeCuVector use discussed on the notebook. You can skip the matrix if needed.
Follow the test structure for primitives_tests.jl
function Adapt.adapt_structure(to,v::SplitMatrixBlocks) | ||
own_own = Adapt.adapt(to,v.own_own) | ||
own_ghost = Adapt.adapt(to,v.own_ghost) | ||
ghost_ghost = Adapt.adapt(to,v.ghost_ghost) | ||
ghost_own = Adapt.adapt(to,v.ghost_own) | ||
split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost) | ||
end | ||
|
||
function Adapt.adapt_structure(to,v::SplitVectorBlocks) | ||
own = Adapt.adapt(to,v.own) | ||
ghost = Adapt.adapt(to,v.ghost) | ||
split_vector_blocks(own,ghost) | ||
end | ||
|
||
function Adapt.adapt_structure(to,v::SplitVector) | ||
blocks = Adapt.adapt(to,v.blocks) | ||
perm = Adapt.adapt(to,v.permutation) | ||
split_vector(blocks,perm) | ||
end | ||
|
||
function Adapt.adapt_structure(to,v::JaggedArray) | ||
data = Adapt.adapt_structure(to,v.data) | ||
ptrs = Adapt.adapt_structure(to,v.ptrs) | ||
jagged_array(data, ptrs) | ||
end | ||
|
||
function Adapt.adapt_structure(to,v::SplitMatrix) | ||
blocks = Adapt.adapt_structure(to,v.blocks) | ||
col_per = v.col_permutation | ||
row_per = v.row_permutation | ||
split_matrix(blocks,row_per,col_per) | ||
end | ||
|
||
function Adapt.adapt_structure(to,v::PSparseMatrix) | ||
matrix_partition = Adapt.adapt_structure(to,v.matrix_partition) | ||
col_par = v.col_partition | ||
row_par = v.row_partition | ||
PSparseMatrix(matrix_partition,row_par,col_par,v.assembled) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can just use Adapt.@adapt_structure T
for these and the generic counterparts.
Not all structs of the PartititionedArrays module are implemented yet.