Allow sharding fine tuned misaligned models (#126)
* fix(Makefile): re-add style target

* feat(jetstream): pad weights to support unaligned sharding

  When loading large models, weights are sharded across a mesh of TPUs, splitting the original weights into smaller tensors, each with the same shape. This is not possible, however, when an original weight's shape is not divisible by the number of TPUs, because it leaves a smaller tensor for the last TPU. This change pads the tensor with zeros, making it splittable across the TPUs.
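The padding idea described above can be sketched in a few lines. This is a minimal illustration with NumPy, not the commit's actual implementation; the function name, axis handling, and the example shapes are all hypothetical:

```python
import numpy as np

def pad_to_shardable(weight: np.ndarray, num_shards: int, axis: int = 0) -> np.ndarray:
    """Zero-pad `weight` along `axis` so its size divides evenly by `num_shards`.

    Hypothetical helper illustrating the technique; the real code shards
    across a TPU mesh rather than splitting a NumPy array.
    """
    size = weight.shape[axis]
    remainder = size % num_shards
    if remainder == 0:
        return weight  # already splittable, nothing to do
    # Pad only at the end of the sharded axis; other axes are untouched.
    pad_width = [(0, 0)] * weight.ndim
    pad_width[axis] = (0, num_shards - remainder)
    return np.pad(weight, pad_width)

# Example: a 50257-row embedding across 8 TPUs.
# 50257 % 8 == 1, so 7 zero rows are appended -> shape (50264, 16),
# and each of the 8 shards now has the same shape (6283, 16).
padded = pad_to_shardable(np.ones((50257, 16), dtype=np.float32), num_shards=8)
shards = np.split(padded, 8, axis=0)
```

Padding with zeros is safe for matrix multiplies because the extra rows or columns contribute nothing to the result, so the padded region only needs to be dropped (or masked) when reading the output back.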
1 parent dadc0a6 · commit e2c5ac2 · 3 changed files with 33 additions and 2 deletions.