Reflect #1358, which enables replica_parallel saving. Uses cache to avoid recomputing `devices_indices_map` for arrays with identical shape/sharding.

PiperOrigin-RevId: 700063349
cpgaffney1 authored and Orbax Authors committed Nov 25, 2024
1 parent 3a32a23 commit 46d159e
Showing 4 changed files with 47 additions and 29 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Tests and documentation for `abstract_arrays`.
 - [Experimental Feature] Support `empty NamedTuple` leaf when
   `PyTreeMetadataOptions.support_rich_types=true`.
+- Enable `replica_parallel` saving. Uses cache to avoid recomputing
+  `devices_indices_map` for arrays with identical shape/sharding.
 
 ### Fixed
 - Fix namedtuple empty value typestr when experimental support_rich_types is
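The cache keys on the array's sharding and global shape, which fully determine `devices_indices_map`. Below is a minimal standalone sketch (not Orbax code) of the replica-counting idea being memoized: every device whose shard covers the same slice of the global array is counted as one replica.

```python
import collections

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Fully replicated sharding: every device holds the entire array.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, PartitionSpec())

counts = collections.defaultdict(int)
for index in sharding.devices_indices_map((8, 2)).values():
  # Devices mapped to an identical slice of the global array are replicas.
  counts[tuple((s.start, s.stop, s.step) for s in index)] += 1

# Every distinct slice should have the same replica count; here it is the
# total device count, since the array is fully replicated.
print(set(counts.values()))
```

Because `jax.sharding.Sharding` objects are hashable, an `lru_cache` keyed on `(sharding, global_shape)` can return this count without re-walking the device map on every save.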
59 changes: 40 additions & 19 deletions checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
@@ -16,8 +16,9 @@
 
 import collections
 import dataclasses
+import functools
 import math
-from typing import Callable, Optional, Sequence
+from typing import Optional, Sequence
 
 from absl import logging
 import jax
@@ -141,26 +142,51 @@ def _hashable_slices(slices: Index) -> tuple[HashableSlice, ...]:
   return tuple([(s.start, s.stop, s.step) for s in slices])
 
 
-def _create_replica_counts_builder(
-    arr: jax.Array,
-) -> Callable[[jax.Shard], int]:
-  """Produces a mapping from addressable shards to global replication count."""
+@functools.lru_cache(maxsize=4096)
+def _sharding_num_replicas(
+    sharding: jax.sharding.Sharding, global_shape: Shape
+) -> int:
+  """Get the number of unique replicas for a sharding/shape.
+
+  Uses the devices_indices_map to get the mapping of devices to the slice of
+  the global array. This gives us the domains of every shard, which may be
+  non-unique. For any index (domain), we increment the count by one. When `n`
+  devices have the same index, this results in the replica count for that
+  index being `n`. We can assert that the number of replicas for each index
+  should be the same.
+
+  We can cache results because we typically expect `save` to be called
+  repeatedly on the same model (with changing array values).
+  The model shardings and shapes do not change during the course of a typical
+  training run.
+
+  Training typically occurs with stacked layers, so we expect the number of
+  model parameters to be significantly less than the cache size. Checkpoints
+  with unstacked layers may have thousands of parameters, but these are
+  typically used for inference, so saving is less relevant.
+
+  Args:
+    sharding: Array sharding.
+    global_shape: The global shape of the array.
+
+  Returns:
+    The number of unique replicas for the sharding/shape.
+  """
   counts = collections.defaultdict(int)
-  for index in arr.sharding.devices_indices_map(arr.shape).values():
+  for index in sharding.devices_indices_map(global_shape).values():
     counts[_hashable_slices(index)] += 1
 
-  return lambda shard: counts[_hashable_slices(shard.index)]
+  num_replicas = next(iter(counts.values()))
+  assert all(count == num_replicas for count in counts.values())
+  return num_replicas
 
 
 def calculate_replica_parallel_axis_and_local_shape(
-    arr: jax.Array, replica_count: int
+    arr: jax.Array,
 ) -> OptionalAxisAndShape:
   """Calculates a local shape for replica-parallel serialization."""
   shard0 = arr.addressable_shards[0]
-  if shard0.data.size == 0 or replica_count == 1:
-    return None, None
-  if replica_count < 1:
+  replica_count = _sharding_num_replicas(arr.sharding, arr.shape)
+  if shard0.data.size == 0 or replica_count <= 1:
     return None, None
   try:
     axis = next(
@@ -231,18 +257,13 @@ def maybe_pick_replica_parallel() -> Optional[Result]:
     # Check whether replica-parallel applies: we are dealing with non-empty
     # shards, we have more than one replica, and some dimension of the shards
     # is evenly divisible across replicas.
-    get_replica_counts = _create_replica_counts_builder(arr)
-    replica_count = get_replica_counts(shard0)
-    axis, local_shape = calculate_replica_parallel_axis_and_local_shape(
-        arr, replica_count
-    )
+    axis, local_shape = calculate_replica_parallel_axis_and_local_shape(arr)
     if axis is None or local_shape is None:
       return None
 
     rslices: list[ReplicaSlice] = []
     for shard in arr.addressable_shards:
-      # Sanity check that all shards have the same number of replicas and shape.
-      assert get_replica_counts(shard) == replica_count
+      # Sanity check that all shards have the same shape.
       assert shard.data.shape == shard0.data.shape
 
       size = local_shape[axis]
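Since `replica_count` is now computed internally, callers pass only the array. A hedged usage sketch of the updated signature follows; the commented expectations assume a fully replicated array on N devices with N dividing the leading axis, which is not guaranteed in general.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from orbax.checkpoint._src.serialization import replica_slices

# An array replicated across all local devices.
mesh = Mesh(np.array(jax.devices()), ("x",))
arr = jax.device_put(
    jnp.arange(32.0).reshape(8, 4), NamedSharding(mesh, PartitionSpec())
)

axis, local_shape = (
    replica_slices.calculate_replica_parallel_axis_and_local_shape(arr)
)
# With N > 1 replicas and 8 % N == 0, each replica writes a distinct slice:
#   axis == 0 and local_shape == (8 // N, 4)
# With a single replica, or no evenly divisible axis, both are None and the
# regular single-replica path is used instead.
print(axis, local_shape)
```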
@@ -829,7 +829,7 @@ def __init__(
       metadata_key: Optional[str] = None,
       primary_host: Optional[int] = 0,
       replica_id: Optional[int] = 0,
-      use_replica_parallel: bool = False,
+      use_replica_parallel: bool = True,
       enable_write_sharding_file: bool = True,
   ):
     """Constructor.
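Flipping the default means replica-parallel writes are now on unless explicitly disabled. A hedged sketch of opting back out, assuming the constructor above belongs to `ArrayHandler` in `type_handlers.py` (the file header is not shown in this hunk):

```python
import jax
from orbax.checkpoint import type_handlers

# New default: replicas cooperate on writing replicated arrays.
handler = type_handlers.ArrayHandler()

# Restore the previous single-replica write behavior if needed.
legacy_handler = type_handlers.ArrayHandler(use_replica_parallel=False)
type_handlers.register_type_handler(jax.Array, legacy_handler, override=True)
```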
13 changes: 4 additions & 9 deletions checkpoint/orbax/checkpoint/test_utils.py
@@ -45,6 +45,7 @@
 from orbax.checkpoint._src.multihost import multislice
 from orbax.checkpoint._src.path import atomicity
 from orbax.checkpoint._src.path import step as step_lib
+from orbax.checkpoint._src.serialization import replica_slices
 from orbax.checkpoint._src.serialization import serialization
 from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
 from orbax.checkpoint._src.serialization import type_handlers
@@ -736,15 +737,9 @@ def assert_every_n_is_x_apart(testclass, values, n, x):
 
 def get_expected_chunk_shape(arr: jax.Array) -> tuple[int, ...]:
   """Expected chunk shape for an array, accounting for replica-parallel."""
-  # TODO(cpgaffney): Enable once replica-parallel is enabled.
-  local_shape = None
-  # get_replica_counts = replica_slices._create_replica_counts_builder(arr)
-  # replica_count = get_replica_counts(arr.addressable_shards[0])
-  # _, local_shape = (
-  #     replica_slices.calculate_replica_parallel_axis_and_local_shape(
-  #         arr, replica_count
-  #     )
-  # )
+  _, local_shape = (
+      replica_slices.calculate_replica_parallel_axis_and_local_shape(arr)
+  )
   if local_shape is None:
     local_shape = arr.sharding.shard_shape(arr.shape)
   return local_shape

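With the stub removed, the helper mirrors the production code path. A rough sketch of how a test might rely on it; the exact split factor depends on the device count, so the assertion is deliberately loose.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from orbax.checkpoint import test_utils

mesh = Mesh(np.array(jax.devices()), ("x",))
arr = jax.device_put(jnp.zeros((8, 4)), NamedSharding(mesh, PartitionSpec()))

chunk_shape = test_utils.get_expected_chunk_shape(arr)
shard_shape = arr.sharding.shard_shape(arr.shape)
# Replica-parallel can only shrink one axis of the shard; it never grows it.
assert all(c <= s for c, s in zip(chunk_shape, shard_shape))
```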