Enable Domain Parallelism with ShardTensor #784
Merged
…simple DDP sharding
…ieces are WIP but this has basic functionality supported for creation and forward usage.
…t of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
…ng in unbind op rules.
….ops.aten.convolution.default.
…s also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
…gnificant overhead. I'm implementing here an option to switch to peer to peer message passing, since it might benefit from stream utilization in layers like natten.na2d. It's a developer choice currently, not a user choice.
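The halo exchange these commits refer to can be illustrated on a single process. The following is a minimal sketch, not the Modulus implementation: shards are plain Python lists, the "communication" is a direct read of the neighbor's edge elements, and the function names `conv1d_same` and `sharded_conv1d` are hypothetical. It assumes an odd-length kernel and that every shard holds at least `len(kernel) // 2` elements.

```python
def conv1d_same(signal, kernel):
    """Reference single-device result: zero-padded 'same' cross-correlation."""
    half = len(kernel) // 2
    padded = [0] * half + list(signal) + [0] * half
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(signal))
    ]


def sharded_conv1d(shards, kernel):
    """Convolve each shard after borrowing `half` halo elements from each neighbor.

    Assumption: every shard has at least `half` elements, so one neighbor
    on each side is enough to fill the halo.
    """
    half = len(kernel) // 2
    outputs = []
    for rank, local in enumerate(shards):
        if rank > 0:
            prev = shards[rank - 1]
            left = prev[len(prev) - half:]  # trailing edge of the left neighbor
        else:
            left = [0] * half  # global boundary: zero padding
        if rank < len(shards) - 1:
            right = shards[rank + 1][:half]  # leading edge of the right neighbor
        else:
            right = [0] * half
        padded = list(left) + list(local) + list(right)
        outputs.append([
            sum(k * padded[i + j] for j, k in enumerate(kernel))
            for i in range(len(local))
        ])
    return outputs


signal = list(range(1, 11))
shards = [signal[:5], signal[5:8], signal[8:]]  # uneven shards: 5, 3, 2
kernel = [1, 2, 1]
flat = [v for out in sharded_conv1d(shards, kernel) for v in out]
print(flat == conv1d_same(signal, kernel))
```

Concatenating the per-shard outputs reproduces the single-device result, which is the consistency property the sharded convolution aims for. In a real distributed run the neighbor reads would be replaced by collective or peer to peer transfers.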
…gnificant functionality changes in this commit.
Add `scatter_tensor` function to enable an easier transition to shard tensor. This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh.
But also, this adjusts the shard tensor mechanism for tracking shard info to use a dict instead of a list of tuples.
No real code changes applied here.
/blossom-ci |
/blossom-ci |
Reminder to change the target to |
- First, the ability to transpose the sharding dimensions is supported. For square submeshes, 2x2 for example, the output sharding will match the input sharding if it's uneven. This can only be supported if the number of devices in the output mesh dimension is equal to the input dimension, hence the restriction on square submeshes. Other scenarios will apply dtensor-like chunk syntax, but return a shard tensor tracking that split. Comprehensive tests on 1D and 2D meshes are included here. No testing is done at this time on 3D sharding / meshes.
- Second, the issues with torch.mean are intercepted and fixed. This uses a new dispatch intercept (below) and applies a weight to the mean, and converts the Partial placement to a Partial(sum) with the weight applied. This has a bug that appears to be present in DTensor too: reductions over non-sharded dimensions appear to falter. To be fixed in a future release.
- Third, ShardTensor has a new class attribute to accommodate operator interceptions. The only functions applied at this time are variants of aten.mean; however, it is expected that all monkey patching will be converted to this syntax.
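As an illustration of the "dtensor-like chunk syntax" mentioned above, here is a minimal sketch of how per-rank shard sizes fall out of `torch.chunk`-style splitting (ceil-sized chunks, with any shortfall on the trailing ranks). The helpers `chunk_sizes` and `shard_offsets` are hypothetical names for this sketch and are not part of Modulus or PyTorch.

```python
import math


def chunk_sizes(length, num_shards):
    """Per-rank sizes under chunk-style splitting.

    Every rank gets ceil(length / num_shards) elements until the data runs
    out; trailing ranks may receive a smaller or even empty shard.
    """
    full = math.ceil(length / num_shards)
    sizes = []
    remaining = length
    for _ in range(num_shards):
        size = min(full, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


def shard_offsets(sizes):
    """Global start index of each rank's shard, given the per-rank sizes."""
    offsets = []
    start = 0
    for size in sizes:
        offsets.append(start)
        start += size
    return offsets


sizes = chunk_sizes(10, 4)
print(sizes, shard_offsets(sizes))
```

Tracking the sizes explicitly, rather than recomputing them from the global shape, is what makes it possible to describe splits that do not follow this rule at all.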
…don't require them to trigger elsewhere. If ShardTensor is used, the patches get applied. Also, minor updates to docs.
Target updated to 0.1.0.0-rc. All smaller comments from @akshaysubr resolved. Most residual bugs fixed. This version of the release supports ShardTensor for select models to enable domain parallelism, not yet as a general tool. |
/multi-gpu-ci |
/multi-gpu-ci |
…Tests are coming in the next commit immediately after.
…sharding. Further, fixed an annoying bug in other distributed tests where OS environs weren't cleared after testing, and some tests would fail but only if others ran first. Now, all distributed tests use a context manager to change OS environment variables locally only.
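A context manager of the kind described here might look like the sketch below. The name `modified_environ` and the example variable are assumptions made for illustration; the actual utility in the test suite may differ.

```python
import os
from contextlib import contextmanager


@contextmanager
def modified_environ(**updates):
    """Temporarily set environment variables and restore the originals on exit."""
    # Remember the previous value (or None if the variable was unset).
    saved = {key: os.environ.get(key) for key in updates}
    os.environ.update({key: str(value) for key, value in updates.items()})
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # variable did not exist before
            else:
                os.environ[key] = old


with modified_environ(SKETCH_WORLD_SIZE=4):
    print(os.environ["SKETCH_WORLD_SIZE"])
print("SKETCH_WORLD_SIZE" in os.environ)
```

Because the restore step runs in `finally`, a failing test cannot leak its settings into the tests that run after it.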
/multi-gpu-ci |
- Enable dynamic (off by default) wrapping of layers by shard tensor. They get turned on automatically when a shard tensor is created.
- Rename the utils to manage env variables.
Tests are failing with unusual CPU errors on ORD. Moving to github runners ...
/multi-gpu-ci |
/blossom-ci |
/blossom-ci |
akshaysubr approved these changes Feb 20, 2025
pzharrington approved these changes Feb 20, 2025
ktangsali pushed a commit that referenced this pull request Mar 18, 2025
* Enable mesh-based parallelism as the configuration backend, even for simple DDP sharding
* Fix small typo in docstring
* Remove unnecessary functions with new interface
* Adding first implementation of ShardTensor prototype. Still several pieces are WIP but this has basic functionality supported for creation and forward usage.
* Working implementation of ShardTensor, though still somewhat incomplete.
* Adding work-in-progress examples. Be careful of sharp edges!
* A few more example pieces before natten will work out of the box. Most of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
* Fix naming scheme
* Minor name change
* Add monkey patching for na2d operation with shard tensors
* Fix bug in shard tensor inference of global size. Check against sharding in unbind op rules.
* Enable backwards gradients for halo sharding and natten patch
* Convolution 2d backwards works, though would be better to catch torch.ops.aten.convolution.default.
* Fix missing import and ensure tensors are contiguous before allgather_v
* Clean up and remove unnecessary noise and printouts for debugging
* Unify (and correct!) the sharded convolution implementation. There was also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
* Remove noise from sharding utils.
* For smaller tensors, the alltoall step of halo reductions might be significant overhead. I'm implementing here an option to switch to peer to peer message passing, since it might benefit from stream utilization in layers like natten.na2d. It's a developer choice currently, not a user choice.
* Remove shard_utils file, it is a subfolder.
* Add modulus ShardTensor api documentation
* Clean up doc strings, type annotations and mesh implementation. No significant functionality changes in this commit.
* Add significant docstring / type annotation cleanup to ShardTensor. Add `scatter_tensor` function to enable an easier transition to shard tensor. This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh.
* Remove neighborhood attention prototypes
* Remove the rest of these examples since they are outdated and unnecessary
* Mostly, this commit is adding type annotations and doc strings. But also, this adjusts the shard tensor mechanism for tracking shard info to use a dict instead of a list of tuples.
* Clean up and document conv patches. No real code changes applied here.
* Clean up and improve documentation and type hints for shard utils worker functions
* Adding basic tests for shard tensor initialization and redistribution. There appears to be one corner case in redistribute to fix. TBD. Tests for grad propagation are coming.
* Add full working example of multilevel parallelism with pytorch FSDP and modulus ShardTensor
* Add missing type annotations
* Ensure scatter_tensor is available to import from modulus.distributed
* Update changelog and ensure wrapt is an optional dependency
* Update fsdp_and_shard_tensor.rst Update tutorial based on feedback from @pzharrington
* Update `__init__.py` Remove wildcard import.
* Update shard_tensor.py fix spacing
* This is an essential bug fix for a missing import
* Update branch to pass CI tests.
* This commit provides several pieces:
  - First, the ability to transpose the sharding dimensions is supported. For square submeshes, 2x2 for example, the output sharding will match the input sharding if it's uneven. This can only be supported if the number of devices in the output mesh dimension is equal to the input dimension, hence the restriction on square submeshes. Other scenarios will apply dtensor-like chunk syntax, but return a shard tensor tracking that split. Comprehensive tests on 1D and 2D meshes are included here. No testing is done at this time on 3D sharding / meshes.
  - Second, the issues with torch.mean are intercepted and fixed. This uses a new dispatch intercept (below) and applies a weight to the mean, and converts the Partial placement to a Partial(sum) with the weight applied. This has a bug that appears to be present in DTensor too: reductions over non-sharded dimensions appear to falter. To be fixed in a future release.
  - Third, ShardTensor has a new class attribute to accommodate operator interceptions. The only functions applied at this time are variants of aten.mean; however, it is expected that all monkey patching will be converted to this syntax.
* Update monkey patching to ensure patches get applied by modulus, and don't require them to trigger elsewhere. If ShardTensor is used, the patches get applied. Also, minor updates to docs.
* Codify ShardTensor and FSDP in tutorials.
* Apparently, codify'ing in rst requires double ticks.
* This commit fixes gradient propagation for unevenly sharded tensors. Tests are coming in the next commit immediately after.
* Add tests for shard tensor: initialization, resharding, and gradient sharding. Further, fixed an annoying bug in other distributed tests where OS environs weren't cleared after testing, and some tests would fail but only if others ran first. Now, all distributed tests use a context manager to change OS environment variables locally only.
* Two things done here:
  - Enable dynamic (off by default) wrapping of layers by shard tensor. They get turned on automatically when a shard tensor is created.
  - Rename the utils to manage env variables.
  Tests are failing with unusual CPU errors on ORD. Moving to github runners ...
* Disable patched operations by default.
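The `torch.mean` fix in the commit list weights each local mean before a sum reduction. A minimal single-process sketch of why that weight is needed with uneven shards follows; the shards are plain lists, the final `sum` stands in for the cross-rank sum reduction, and the function names are hypothetical.

```python
import math


def weighted_local_mean(shard, global_size):
    """Local mean scaled by this shard's share of the global element count."""
    if not shard:
        return 0.0  # an empty shard contributes nothing
    local_mean = sum(shard) / len(shard)
    return local_mean * (len(shard) / global_size)


def sharded_mean(shards):
    """Sum the weighted local means, mimicking a sum reduction across ranks."""
    global_size = sum(len(shard) for shard in shards)
    return sum(weighted_local_mean(shard, global_size) for shard in shards)


shards = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [9.0]]  # uneven: 4, 2, 1
flat = [v for shard in shards for v in shard]
true_mean = sum(flat) / len(flat)

# Averaging the local means without weights is only right for equal shards.
naive = sum(sum(s) / len(s) for s in shards) / len(shards)

print(math.isclose(sharded_mean(shards), true_mean))
print(math.isclose(naive, true_mean))
```

With equally sized shards every weight is the same and the unweighted average happens to agree, which is why the problem only shows up once sharding is uneven.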
Modulus Pull Request
Description
This PR adds new capabilities to Modulus:
- `ShardTensor` is an extension to PyTorch `DTensor` that enables uneven sharding of tensors across `DeviceMesh` objects. While some logical sharding constraints remain, this allows more dynamic and flexible operation on distributed input data, especially in cases where the input data shape and output data shape differ.
- `ShardTensor` also enables an ecosystem of operation extensions. Two major ones are included in this PR: convolutions (1D/2D/3D) and neighborhood attention. When the right components of modulus are imported, these operations (when performed on sharded tensors) will automatically compute halo regions and perform data transfers to enable results consistent with single device outputs.
- `ShardTensor`, as well as an example of integrating multiple levels of parallelism by combining shard tensor and PyTorch `FSDP`.
Checklist
Dependencies
Adds a dependency on `wrapt` for monkey-patching operations on sharded inputs.
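To make the idea of patching operations for sharded inputs concrete, here is a minimal sketch that uses only the standard library (`functools.wraps`) rather than `wrapt`, so it is not the mechanism this PR uses. `FakeShardTensor`, `total`, `sharded_total`, and `patch_for_sharded` are all hypothetical names introduced for illustration.

```python
import functools


class FakeShardTensor:
    """Hypothetical stand-in for a sharded tensor: a list of local shards."""

    def __init__(self, shards):
        self.shards = shards


def total(values):
    """Original single-device operation."""
    return sum(values)


def sharded_total(tensor):
    """Sharded variant: reduce each shard locally, then combine the partials."""
    return sum(sum(shard) for shard in tensor.shards)


def patch_for_sharded(original, sharded_impl):
    """Wrap `original` so sharded inputs are routed to `sharded_impl`."""

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        if any(isinstance(arg, FakeShardTensor) for arg in args):
            return sharded_impl(*args, **kwargs)
        return original(*args, **kwargs)  # unchanged path for regular inputs

    return wrapper


# Rebinding the name is the monkey patch: callers keep calling `total`.
total = patch_for_sharded(total, sharded_total)

print(total([1, 2, 3]))
print(total(FakeShardTensor([[1, 2], [3]])))
```

Regular inputs take the original code path untouched, so the patch only changes behavior when a sharded input is actually passed in.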