Allow sharding fine-tuned misaligned models (#126)
* fix(Makefile): re-add style target

* feat(jetstream): pad weights to support unaligned sharding

When loading large models, weights are sharded across a mesh of TPUs by
splitting the original weights into smaller tensors, each with the same
shape.
This is not possible, however, when the weight's size along the sharding
axis is not divisible by the number of TPUs, because the last TPU would
end up with a smaller tensor.
This change pads the tensor with zeros so it can be split evenly across
the TPUs.
tengomucho authored Dec 9, 2024
1 parent dadc0a6 commit e2c5ac2
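
A rough sketch of the failure mode and the zero-padding fix described above, using hypothetical shapes (not taken from the commit):

import jax.numpy as jnp

# 4097 rows cannot be split evenly across 8 TPUs: seven shards would get
# 512 rows and one would get 513, so the shards no longer share one shape.
w = jnp.zeros((4097, 256))
pad = (8 - w.shape[0] % 8) % 8             # 7 rows of zeros are needed
w_padded = jnp.pad(w, [(0, pad), (0, 0)])  # (4104, 256): 513 rows per TPU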
Showing 3 changed files with 33 additions and 2 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -66,6 +66,7 @@ style_check:
 	ruff check .
 
 style:
+	ruff check . --fix
 
 # Utilities to release to PyPi
 build_dist_install_tools:
27 changes: 26 additions & 1 deletion
@@ -6,6 +6,7 @@
 from typing import TYPE_CHECKING, Any
 
 import jax
+import jax.numpy as jnp
 import torch
 from jetstream_pt import fetch_models, quantize_model, torchjax
 from jetstream_pt.engine import PyTorchEngine
@@ -140,16 +141,40 @@ def instantiate_model_from_repo_id(
     return model
 
 
+def _get_needed_padding(value: int, multiple: int) -> int:
+    return (multiple - value % multiple) % multiple
+
+def _pad_array_up_to(v: jnp.ndarray, axis: int, multiple: int) -> jnp.ndarray:
+    a = [(0, 0) for _ in range(len(v.shape))]
+    a[axis] = (0, _get_needed_padding(v.shape[axis], multiple))
+    return jnp.pad(v, a)
+
+def pad_to_shard(env, val, axis: int):
+    # If axis is -1 (or None), no sharding is done: everything is replicated.
+    if axis == -1 or axis is None:
+        return val
+    sharding = env.sharding_by_axis(axis)
+    axis_name = sharding.spec[axis]
+    size_to_pad = env.mesh.shape[axis_name]
+    padded_val = _pad_array_up_to(val, axis, size_to_pad)
+    # Note that, even if the value is padded, the model will only use the part that is needed.
+    if val.shape != padded_val.shape:
+        logger.debug(f"Sharding resulted in padding weights from {val.shape} to {padded_val.shape}")
+    return padded_val
+
+
 def shard_weights(env, weights, weight_shardings):
     """Shard weights according to weight_shardings"""
     for k, v in weight_shardings.items():
         logger.debug(f"SHARDING {k} {v}")
     sharded = {}
     for key, val in weights.items():
-        sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
+        axis = weight_shardings.get(key, -1)
+        sharding = env.sharding_by_axis(axis)
         with jax.default_device(jax.devices("cpu")[0]):
             # Note we clone to avoid a core dump that might happen otherwise when calling device_put
             arr = torch_xla2.tensor.t2j(val.clone())
+            arr = pad_to_shard(env, arr, axis)
             arr = jax.device_put(arr, sharding)
         sharded[key] = torchjax.to_torch(arr)
     return sharded
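A minimal standalone sketch of how the new helpers behave, assuming a hypothetical weight shape and an 8-device mesh axis (the env/mesh plumbing from the diff is omitted):

import jax.numpy as jnp

def _get_needed_padding(value: int, multiple: int) -> int:
    # Extra length required so that `value` becomes a multiple of `multiple`.
    return (multiple - value % multiple) % multiple

def _pad_array_up_to(v: jnp.ndarray, axis: int, multiple: int) -> jnp.ndarray:
    # Zero-pad `v` at the end of `axis` up to the next multiple.
    a = [(0, 0) for _ in range(len(v.shape))]
    a[axis] = (0, _get_needed_padding(v.shape[axis], multiple))
    return jnp.pad(v, a)

w = jnp.ones((32003, 128))  # e.g. an embedding with an extended vocabulary
padded = _pad_array_up_to(w, axis=0, multiple=8)
assert padded.shape == (32008, 128)  # 4001 rows per device; last 5 rows are zeros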
7 changes: 6 additions & 1 deletion text-generation-inference/tests/test_decode_jetstream.py
@@ -59,8 +59,13 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
             sequence_length=512,
             expected_text="манaminationVariableßer Rog malesazine longふ Toy Champions enero Facereverse▲verbose prosecut literally disappearedअ",
         ),
+        DecodeTestParams(
+            model_id="Trendyol/Trendyol-LLM-7b-base-v0.1",
+            sequence_length=512,
+            expected_text="\nThe clocks were striking thirteen, and the clocks were striking thirteen.",
+        ),
     ],
-    ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
+    ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
     params.do_sample = do_sample
