Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poc elastic training #1310

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
71 changes: 71 additions & 0 deletions MaxText/elastic/reshard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resharding API for elastic training."""

from typing import Any
from typing import Callable, Optional, Sequence
import jax


def default_put_array(
arr: jax.Array,
dst_sharding: jax.sharding.Sharding,
donate_input: bool,
):
if not isinstance(dst_sharding, jax.sharding.Sharding):
raise ValueError("`sharding` must contain only `Sharding` instances.")
return jax.device_put(arr, dst_sharding, donate=donate_input)


def reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate_input: bool = True,
put_array: Optional[
Callable[[jax.Array, Sequence[jax.sharding.Sharding], bool], jax.Array]
] = None,
) -> Any:
"""Reshards `x` to the specified `sharding`.

Args:
x: An array, scalar, or a nested Python container thereof.
sharding: A `Sharding` or a nested `Sharding` in a Python container (must
match the structure of `x`), specifying the target sharding.
donate_input: If `True`, donates the input arrays to reduce memory needed
for resharding. Donated buffers should not be reused.
put_array: A function that takes an array, a sharding, and a boolean
indicating whether to donate the input, and returns a copy of the array
with the specified sharding.

Returns:
A copy of `x` with the specified `sharding`.
"""
if put_array is None:
put_array = default_put_array

flat_x, tree_def = jax.tree_util.tree_flatten(x)
flat_sharding = jax.api_util.flatten_axes(
"reshard sharding", tree_def, sharding
)

if len(flat_x) != len(flat_sharding):
raise ValueError("Mismatched length between `x` and `sharding`.")

arrays = [
put_array(arr, dst_sharding, donate_input)
for arr, dst_sharding in zip(flat_x, flat_sharding)
]
return jax.tree_util.tree_unflatten(tree_def, arrays)

70 changes: 70 additions & 0 deletions MaxText/elastic/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for elastic training."""

import logging
from typing import Any, Optional, Sequence

import jax
from pathwaysutils.google_internal.elastic import utils


PyTree = Any

logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)

# pylint: disable=logging-fstring-interpolation


class ElasticUtilsSimulator(utils.ElasticUtils):
"""Utility class for elastic training.

This class will simulate slices going down and coming back up.
"""
simulated_good_slice_indices: set[int]

def __init__(
self,
devices: Sequence[jax.Device],
total_slice_count: int,
save_period: Optional[int] = None,
reshard_check_period: Optional[int] = None,
max_failures: Optional[int] = None,
):
self.simulated_good_slice_indices = set(d.slice_index for d in devices)

super().__init__(
devices,
total_slice_count,
save_period,
reshard_check_period,
max_failures,
)

def update_good_slice_indices(self, good_slice_indices: set[int]):
"""Start step handler."""
self.simulated_good_slice_indices = good_slice_indices
logger.info(f"Updated: {self.simulated_good_slice_indices=}")

@utils.timeit
def get_slice_availability(self) -> set[int]:
"""Returns the set of good and bad slices."""
good_slice_indices = self.simulated_good_slice_indices

logger.info(f"{good_slice_indices=}")

return good_slice_indices

Loading
Loading