
Change job name while uploading sharding plan to Manifold #3092

Status: Open · wants to merge 2 commits into main

24 changes: 24 additions & 0 deletions torchrec/distributed/planner/enumerators.py
@@ -7,6 +7,7 @@

# pyre-strict

import copy
import logging
from typing import Dict, List, Optional, Set, Tuple, Union

@@ -102,6 +103,11 @@ def __init__(
EmbeddingStorageEstimator(topology=topology, constraints=constraints),
]

        # Initialize the cache for enumerate()
self._last_stored_search_space: Optional[List[ShardingOption]] = None
self._last_stored_module: Optional[nn.Module] = None
self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None

def enumerate(
self,
module: nn.Module,
@@ -118,6 +124,12 @@ def enumerate(
List[ShardingOption]: valid sharding options with values populated.
"""

if (
self._last_stored_module == module
and self._last_stored_sharders == sharders
):
return copy.deepcopy(self._last_stored_search_space) # pyre-ignore

self._sharder_map = {
sharder_name(sharder.module_type): sharder for sharder in sharders
}
@@ -230,8 +242,20 @@ def enumerate(

self.populate_estimates(sharding_options)

self._last_stored_module = module
self._last_stored_sharders = sharders

        # Cache a deep copy of the sharding options so later modifications to the
        # returned list do not affect the cached search space
self._last_stored_search_space = copy.deepcopy(sharding_options)
return sharding_options

@property
def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
        # NOTE: This is the last search space stored by enumerate(...). Do not use
        # this field in place of actually calling enumerate(...), as the result varies
        # with the module/sharders passed in.
return self._last_stored_search_space

def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
for estimator in self._estimators:
estimator.estimate(sharding_options, self._sharder_map)
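The cache added to enumerate() above is a last-call memoization: the enumerator remembers the module/sharders pair from the previous call and hands back a deep copy of the previously computed search space when the same pair is passed again. A minimal standalone sketch of that pattern (the class and helper below are illustrative stand-ins, not torchrec code):

```python
import copy
from typing import Any, List, Optional, Tuple


class CachedEnumeratorSketch:
    """Illustrative last-call memoization; not the torchrec implementation."""

    def __init__(self) -> None:
        self._last_inputs: Optional[Tuple[Any, Tuple[Any, ...]]] = None
        self._last_result: Optional[List[str]] = None

    def enumerate(self, module: Any, sharders: List[Any]) -> List[str]:
        # Reuse the cached search space when called again with the same inputs.
        if self._last_inputs == (module, tuple(sharders)) and self._last_result is not None:
            # Deep copy so callers cannot mutate the cached result.
            return copy.deepcopy(self._last_result)

        result = self._build_search_space(module, sharders)
        self._last_inputs = (module, tuple(sharders))
        # Cache a copy so later edits to the returned list do not leak into the cache.
        self._last_result = copy.deepcopy(result)
        return result

    def _build_search_space(self, module: Any, sharders: List[Any]) -> List[str]:
        # Stand-in for the real enumeration work.
        return [f"{module}:{sharder}" for sharder in sharders]
```
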
14 changes: 6 additions & 8 deletions torchrec/distributed/planner/planners.py
@@ -39,6 +39,7 @@
)
from torchrec.distributed.planner.types import (
Enumerator,
hash_planner_context_inputs,
ParameterConstraints,
Partitioner,
PerfModel,
@@ -280,25 +281,21 @@ def collective_plan(
sharders,
)

def hash_planner_context_inputs(self) -> str:
def hash_planner_context_inputs(self) -> int:
"""
Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.

Returns:
Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
"""
hashable_list = [
return hash_planner_context_inputs(
self._topology,
self._batch_size,
self._enumerator,
self._storage_reservation,
frozenset(self._constraints.items()) if self._constraints else None,
]
serialized_list = str(hashable_list).encode("utf-8")
hash_object = hashlib.sha256(serialized_list)
hash_digest = hash_object.hexdigest()
return hash_digest
self._constraints,
)

def plan(
self,
@@ -499,6 +496,7 @@ def plan(
best_plan=last_proposal,
constraints=self._constraints,
sharders=sharders,
enumerator=self._enumerator,
debug=self._debug,
)

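hash_planner_context_inputs is now imported from torchrec.distributed.planner.types; its body is not part of this diff. Purely as an assumption, a helper of that shape could combine the context components into a single int along these lines, provided each component implements a deterministic, content-based __hash__:

```python
from typing import Any, Dict, Optional


def hash_planner_context_inputs_sketch(
    topology: Any,
    batch_size: int,
    enumerator: Any,
    storage_reservation: Any,
    constraints: Optional[Dict[str, Any]] = None,
) -> int:
    # Hypothetical sketch only; the real helper in torchrec.distributed.planner.types
    # may work differently. Assumes every component is hashable in a content-based way.
    constraints_key = tuple(sorted(constraints.items())) if constraints else None
    return hash((topology, batch_size, enumerator, storage_reservation, constraints_key))
```

A content-based hash matters here because the multiprocess test added in test_types.py expects identical values across ranks; Python's default, id()-based object hash would not satisfy that.
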
21 changes: 21 additions & 0 deletions torchrec/distributed/planner/storage_reservations.py
@@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation):
def __init__(self, percentage: float) -> None:
assert percentage >= 0 and percentage <= 1
self._percentage: float = percentage
self._last_reserved_topology: Optional[Topology] = None

def reserve(
self,
@@ -174,8 +175,14 @@
) -> Topology:
reserved_topology = copy.deepcopy(topology)
_reserve_storage_percentage(reserved_topology, self._percentage)
self._last_reserved_topology = reserved_topology
return reserved_topology

@property
def last_reserved_topology(self) -> Optional[Topology]:
"Returns a copy of the cached value of the most recent output from the reserve() method."
return copy.deepcopy(self._last_reserved_topology)


class HeuristicalStorageReservation(StorageReservation):
"""
@@ -206,6 +213,7 @@ def __init__(

self._dense_storage: Optional[Storage] = None
self._kjt_storage: Optional[Storage] = None
self._last_reserved_topology: Optional[Topology] = None

def reserve(
self,
@@ -215,6 +223,7 @@
sharders: List[ModuleSharder[nn.Module]],
constraints: Optional[Dict[str, ParameterConstraints]] = None,
) -> Topology:
# TODO: enable proper caching of topology values through _last_reserved_topology
reserved_topology = copy.deepcopy(topology)

batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters(
@@ -262,8 +271,14 @@
message=negative_storage_solution,
)

self._last_reserved_topology = copy.deepcopy(reserved_topology)
return reserved_topology

@property
def last_reserved_topology(self) -> Optional[Topology]:
"Cached value of the most recent output from the reserve() method."
return self._last_reserved_topology


class InferenceStorageReservation(StorageReservation):
"""
@@ -291,6 +306,7 @@ def __init__(

self._dense_storage: Optional[Storage] = None
self._kjt_storage: Optional[Storage] = None
self._last_reserved_topology: Optional[Topology] = None

def reserve(
self,
@@ -324,4 +340,9 @@
multiplier=1,
)

self._last_reserved_topology = copy.deepcopy(reserved_topology)

return reserved_topology

    @property
    def last_reserved_topology(self) -> Optional[Topology]:
        "Returns a deep copy of the most recent output from the reserve() method."
        return copy.deepcopy(self._last_reserved_topology)
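Each reservation class now remembers the topology produced by reserve() and exposes it through a last_reserved_topology accessor; two of the three return a deep copy so callers cannot mutate the cached value. A minimal sketch of that copy-on-read pattern, using a stand-in topology type rather than torchrec's Topology:

```python
import copy
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class _TopologySketch:
    # Stand-in for torchrec's Topology, used only for this illustration.
    devices: List[str] = field(default_factory=list)


class _ReservationSketch:
    def __init__(self) -> None:
        self._last_reserved_topology: Optional[_TopologySketch] = None

    def reserve(self, topology: _TopologySketch) -> _TopologySketch:
        reserved = copy.deepcopy(topology)
        reserved.devices.append("reserved")  # stand-in for the real reservation work
        # Cache a copy so the value returned to the caller can be mutated freely.
        self._last_reserved_topology = copy.deepcopy(reserved)
        return reserved

    @property
    def last_reserved_topology(self) -> Optional[_TopologySketch]:
        # Hand out a copy so the cached topology cannot be mutated by callers.
        return copy.deepcopy(self._last_reserved_topology)
```
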
28 changes: 27 additions & 1 deletion torchrec/distributed/planner/tests/test_planners.py
@@ -12,7 +12,7 @@

import torch
from torch import nn
from torchrec import EmbeddingConfig
from torchrec import EmbeddingBagCollection, EmbeddingConfig
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -306,6 +306,22 @@ def test_passing_info_through_constraints(self) -> None:
class TestEmbeddingShardingHashPlannerContextInputs(unittest.TestCase):

def setUp(self) -> None:
eb_config = EmbeddingBagConfig(
name="table_0",
embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
module = EmbeddingBagCollection(
tables=[eb_config],
is_weighted=False,
device=torch.device(
"meta"
), # Using meta device for now since only getting search space
)
sharders = [EmbeddingBagCollectionSharder()]

self.topology = Topology(
local_world_size=8,
world_size=1,
@@ -315,10 +331,20 @@ def setUp(self) -> None:
self.enumerator = EmbeddingEnumerator(
topology=self.topology, batch_size=self.batch_size
)
self.enumerator.enumerate(module, sharders) # pyre-ignore

self.storage_reservation = HeuristicalStorageReservation(percentage=0.15)
self.perf_model = NoopPerfModel(topology=self.topology)
self.constraints = {"table1": ParameterConstraints()}

self.storage_reservation.reserve(
topology=self.topology,
batch_size=self.batch_size,
module=module,
sharders=sharders, # pyre-ignore
constraints=self.constraints,
)

def test_hash_equality(self) -> None:
planner1 = EmbeddingShardingPlanner(
topology=self.topology,
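The fixtures built in setUp support equality checks like the one below: two planners constructed from the same topology, enumerator, storage reservation, and constraints should report the same context hash. This is an illustrative sketch of the intended usage written as a test method, not the actual body of the truncated test above:

```python
    def test_hash_equality_sketch(self) -> None:
        # Assumes the fixtures from setUp in TestEmbeddingShardingHashPlannerContextInputs.
        planner1 = EmbeddingShardingPlanner(
            topology=self.topology,
            batch_size=self.batch_size,
            enumerator=self.enumerator,
            storage_reservation=self.storage_reservation,
            performance_model=self.perf_model,
            constraints=self.constraints,
        )
        planner2 = EmbeddingShardingPlanner(
            topology=self.topology,
            batch_size=self.batch_size,
            enumerator=self.enumerator,
            storage_reservation=self.storage_reservation,
            performance_model=self.perf_model,
            constraints=self.constraints,
        )
        # Identical inputs should produce identical context hashes.
        self.assertEqual(
            planner1.hash_planner_context_inputs(),
            planner2.hash_planner_context_inputs(),
        )
```
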
86 changes: 85 additions & 1 deletion torchrec/distributed/planner/tests/test_types.py
@@ -8,18 +8,30 @@
# pyre-strict

import unittest
from typing import cast
from typing import cast, Dict, Optional
from unittest.mock import MagicMock

import torch
from torch import multiprocessing
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.perf_models import NoopPerfModel
from torchrec.distributed.planner.storage_reservations import (
HeuristicalStorageReservation,
)

from torchrec.distributed.planner.types import (
ParameterConstraints,
Shard,
ShardingOption,
Topology,
)
from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.distributed.types import (
BoundsCheckMode,
CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
self.assertNotEqual(
hash(pc1), hash(pc2), "Hashes should be different for different instances"
)


def _test_hashing_consistency(
rank: int,
world_size: int,
backend: str,
return_hash_dict: Dict[str, int],
local_size: Optional[int] = None,
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
topology = Topology(
local_world_size=8,
world_size=1,
compute_device="cuda",
)
batch_size = 128
enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
eb_config = EmbeddingBagConfig(
name="table_0",
embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
module = EmbeddingBagCollection(
tables=[eb_config],
is_weighted=False,
device=torch.device(
"meta"
), # Using meta device for now since only getting search space
)
sharders = [EmbeddingBagCollectionSharder()]
enumerator.enumerate(module, sharders) # pyre-ignore
storage_reservation = HeuristicalStorageReservation(percentage=0.15)
constraints = {"table1": ParameterConstraints()}

storage_reservation.reserve(
topology=topology,
batch_size=batch_size,
module=module,
sharders=sharders, # pyre-ignore
constraints=constraints,
)
perf_model = NoopPerfModel(topology=topology)

planner1 = EmbeddingShardingPlanner(
topology=topology,
batch_size=batch_size,
enumerator=enumerator,
storage_reservation=storage_reservation,
performance_model=perf_model,
constraints=constraints,
)

h = planner1.hash_planner_context_inputs()
return_hash_dict[str(rank)] = h


class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):

def test_hash_consistency(self) -> None:
# planner
world_size = 2
return_hash_dict = multiprocessing.Manager().dict()
self._run_multi_process_test(
callable=_test_hashing_consistency,
world_size=world_size,
backend="nccl" if torch.cuda.is_available() else "gloo",
return_hash_dict=return_hash_dict,
)
hashes = return_hash_dict.values()
assert hashes[0] == hashes[1], "hash values are different."