diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 472e3d6e6..7d0be906d 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import copy
 import logging
 from typing import Dict, List, Optional, Set, Tuple, Union
 
@@ -102,6 +103,11 @@ def __init__(
             EmbeddingStorageEstimator(topology=topology, constraints=constraints),
         ]
 
+        # Initialize caching for enumerate()
+        self._last_stored_search_space: Optional[List[ShardingOption]] = None
+        self._last_stored_module: Optional[nn.Module] = None
+        self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None
+
     def enumerate(
         self,
         module: nn.Module,
@@ -118,6 +124,12 @@ def enumerate(
             List[ShardingOption]: valid sharding options with values populated.
         """
 
+        if (
+            self._last_stored_module == module
+            and self._last_stored_sharders == sharders
+        ):
+            return copy.deepcopy(self._last_stored_search_space)  # pyre-ignore
+
         self._sharder_map = {
             sharder_name(sharder.module_type): sharder for sharder in sharders
         }
@@ -230,8 +242,20 @@ def enumerate(
 
         self.populate_estimates(sharding_options)
 
+        self._last_stored_module = module
+        self._last_stored_sharders = sharders
+
+        # Cache a copy of the sharding options, to avoid unexpected modifications to the list
+        self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
+    @property
+    def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
+        # NOTE: This is the last search space stored by enumerate(...). Do not use
+        # this property in place of actually calling enumerate(...), as the result
+        # will vary for each module/sharders pair passed in.
+        return self._last_stored_search_space
+
     def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
         for estimator in self._estimators:
             estimator.estimate(sharding_options, self._sharder_map)
diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 8d7f40af3..e360382a3 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -39,6 +39,7 @@
 )
 from torchrec.distributed.planner.types import (
     Enumerator,
+    hash_planner_context_inputs,
     ParameterConstraints,
     Partitioner,
     PerfModel,
@@ -280,7 +281,7 @@ def collective_plan(
             sharders,
         )
 
-    def hash_planner_context_inputs(self) -> str:
+    def hash_planner_context_inputs(self) -> int:
         """
         Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats. These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
 
@@ -288,17 +289,13 @@ def hash_planner_context_inputs(self) -> str:
         Returns:
-            Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
+            A hash capturing topology, batch size, enumerator, storage reservation, and constraints.
""" - hashable_list = [ + return hash_planner_context_inputs( self._topology, self._batch_size, self._enumerator, self._storage_reservation, - frozenset(self._constraints.items()) if self._constraints else None, - ] - serialized_list = str(hashable_list).encode("utf-8") - hash_object = hashlib.sha256(serialized_list) - hash_digest = hash_object.hexdigest() - return hash_digest + self._constraints, + ) def plan( self, @@ -499,6 +496,7 @@ def plan( best_plan=last_proposal, constraints=self._constraints, sharders=sharders, + enumerator=self._enumerator, debug=self._debug, ) diff --git a/torchrec/distributed/planner/storage_reservations.py b/torchrec/distributed/planner/storage_reservations.py index 5a2077db3..6e7d38300 100644 --- a/torchrec/distributed/planner/storage_reservations.py +++ b/torchrec/distributed/planner/storage_reservations.py @@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation): def __init__(self, percentage: float) -> None: assert percentage >= 0 and percentage <= 1 self._percentage: float = percentage + self._last_reserved_topology: Optional[Topology] = None def reserve( self, @@ -174,8 +175,14 @@ def reserve( ) -> Topology: reserved_topology = copy.deepcopy(topology) _reserve_storage_percentage(reserved_topology, self._percentage) + self._last_reserved_topology = reserved_topology return reserved_topology + @property + def last_reserved_topology(self) -> Optional[Topology]: + "Returns a copy of the cached value of the most recent output from the reserve() method." + return copy.deepcopy(self._last_reserved_topology) + class HeuristicalStorageReservation(StorageReservation): """ @@ -206,6 +213,7 @@ def __init__( self._dense_storage: Optional[Storage] = None self._kjt_storage: Optional[Storage] = None + self._last_reserved_topology: Optional[Topology] = None def reserve( self, @@ -215,6 +223,7 @@ def reserve( sharders: List[ModuleSharder[nn.Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None, ) -> Topology: + # TODO: enable proper caching of topology values through _last_reserved_topology reserved_topology = copy.deepcopy(topology) batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters( @@ -262,8 +271,14 @@ def reserve( message=negative_storage_solution, ) + self._last_reserved_topology = copy.deepcopy(reserved_topology) return reserved_topology + @property + def last_reserved_topology(self) -> Optional[Topology]: + "Cached value of the most recent output from the reserve() method." 
+        return self._last_reserved_topology
+
 
 class InferenceStorageReservation(StorageReservation):
     """
@@ -291,6 +306,7 @@ def __init__(
 
         self._dense_storage: Optional[Storage] = None
         self._kjt_storage: Optional[Storage] = None
+        self._last_reserved_topology: Optional[Topology] = None
 
     def reserve(
         self,
@@ -324,4 +340,10 @@ def reserve(
             multiplier=1,
         )
 
+        self._last_reserved_topology = copy.deepcopy(reserved_topology)
+
         return reserved_topology
+
+    @property
+    def last_reserved_topology(self) -> Optional[Topology]:
+        return copy.deepcopy(self._last_reserved_topology)
diff --git a/torchrec/distributed/planner/tests/test_planners.py b/torchrec/distributed/planner/tests/test_planners.py
index 60c1ffdc6..64f96c4d0 100644
--- a/torchrec/distributed/planner/tests/test_planners.py
+++ b/torchrec/distributed/planner/tests/test_planners.py
@@ -12,7 +12,7 @@
 
 import torch
 from torch import nn
-from torchrec import EmbeddingConfig
+from torchrec import EmbeddingBagCollection, EmbeddingConfig
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -306,6 +306,22 @@ def test_passing_info_through_constraints(self) -> None:
 
 class TestEmbeddingShardingHashPlannerContextInputs(unittest.TestCase):
     def setUp(self) -> None:
+        eb_config = EmbeddingBagConfig(
+            name="table_0",
+            embedding_dim=160,
+            num_embeddings=10000,
+            feature_names=["f1"],
+            data_type=DataType.FP16,
+        )
+        module = EmbeddingBagCollection(
+            tables=[eb_config],
+            is_weighted=False,
+            device=torch.device(
+                "meta"
+            ),  # meta device is enough, since we only need the search space
+        )
+        sharders = [EmbeddingBagCollectionSharder()]
+
         self.topology = Topology(
             local_world_size=8,
             world_size=1,
@@ -315,10 +331,20 @@ def setUp(self) -> None:
         self.enumerator = EmbeddingEnumerator(
             topology=self.topology, batch_size=self.batch_size
         )
+        self.enumerator.enumerate(module, sharders)  # pyre-ignore
+
         self.storage_reservation = HeuristicalStorageReservation(percentage=0.15)
         self.perf_model = NoopPerfModel(topology=self.topology)
         self.constraints = {"table1": ParameterConstraints()}
 
+        self.storage_reservation.reserve(
+            topology=self.topology,
+            batch_size=self.batch_size,
+            module=module,
+            sharders=sharders,  # pyre-ignore
+            constraints=self.constraints,
+        )
+
     def test_hash_equality(self) -> None:
         planner1 = EmbeddingShardingPlanner(
             topology=self.topology,
diff --git a/torchrec/distributed/planner/tests/test_types.py b/torchrec/distributed/planner/tests/test_types.py
index cd9d2a650..c598ef72a 100644
--- a/torchrec/distributed/planner/tests/test_types.py
+++ b/torchrec/distributed/planner/tests/test_types.py
@@ -8,11 +8,19 @@
 # pyre-strict
 
 import unittest
-from typing import cast
+from typing import cast, Dict, Optional
 from unittest.mock import MagicMock
 
 import torch
+from torch import multiprocessing
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.planner import EmbeddingShardingPlanner
+from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
+from torchrec.distributed.planner.perf_models import NoopPerfModel
+from torchrec.distributed.planner.storage_reservations import (
+    HeuristicalStorageReservation,
+)
 from torchrec.distributed.planner.types import (
     ParameterConstraints,
@@ -20,6 +28,10 @@
     ShardingOption,
     Topology,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
         self.assertNotEqual(
             hash(pc1), hash(pc2), "Hashes should be different for different instances"
         )
+
+
+def _test_hashing_consistency(
+    rank: int,
+    world_size: int,
+    backend: str,
+    return_hash_dict: Dict[str, int],
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        topology = Topology(
+            local_world_size=8,
+            world_size=1,
+            compute_device="cuda",
+        )
+        batch_size = 128
+        enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
+        eb_config = EmbeddingBagConfig(
+            name="table_0",
+            embedding_dim=160,
+            num_embeddings=10000,
+            feature_names=["f1"],
+            data_type=DataType.FP16,
+        )
+        module = EmbeddingBagCollection(
+            tables=[eb_config],
+            is_weighted=False,
+            device=torch.device(
+                "meta"
+            ),  # meta device is enough, since we only need the search space
+        )
+        sharders = [EmbeddingBagCollectionSharder()]
+        enumerator.enumerate(module, sharders)  # pyre-ignore
+        storage_reservation = HeuristicalStorageReservation(percentage=0.15)
+        constraints = {"table1": ParameterConstraints()}
+
+        storage_reservation.reserve(
+            topology=topology,
+            batch_size=batch_size,
+            module=module,
+            sharders=sharders,  # pyre-ignore
+            constraints=constraints,
+        )
+        perf_model = NoopPerfModel(topology=topology)
+
+        planner = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=enumerator,
+            storage_reservation=storage_reservation,
+            performance_model=perf_model,
+            constraints=constraints,
+        )
+
+        h = planner.hash_planner_context_inputs()
+        return_hash_dict[str(rank)] = h
+
+
+class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):
+
+    def test_hash_consistency(self) -> None:
+        # spawn two ranks and compare their planner-context hashes
+        world_size = 2
+        return_hash_dict = multiprocessing.Manager().dict()
+        self._run_multi_process_test(
+            callable=_test_hashing_consistency,
+            world_size=world_size,
+            backend="nccl" if torch.cuda.is_available() else "gloo",
+            return_hash_dict=return_hash_dict,
+        )
+        hashes = list(return_hash_dict.values())
+        assert hashes[0] == hashes[1], "hash values are different."
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 4e0498666..e51547dbf 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -12,7 +12,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -401,12 +401,15 @@ def __repr__(self) -> str:
         topology_repr += str(self._comms_bandwidths) + "\n"
         return topology_repr
 
-    def _hash(self) -> str:
+    def _hash(self) -> int:
         """
         Compute a consistent hash value for this Topology instance.
 
         Returns:
-            str: A hash value for this Topology instance.
+            int: A hash value for this Topology instance.
+
+        NOTE: __hash__ is intentionally not overridden here, since fields outside
+        the hashable list below would not be captured by the hash.
         """
 
-        # Compute hbms and ddrs from the decives
+        # Compute hbms and ddrs from the devices
@@ -430,10 +433,7 @@ def _hash(self) -> str:
             self._uneven_sharding_perf_multiplier,
         ]
 
-        serialized_list = str(hashable_list).encode("utf-8")
-        hash_object = hashlib.sha256(serialized_list)
-        hash_digest = hash_object.hexdigest()
-        return hash_digest
+        return hash_sha256_to_int(hashable_list)
 
 
 # ---- INPUT / OUTPUT ----- #
@@ -743,25 +743,25 @@ class ParameterConstraints:
     key_value_params: Optional[KeyValueParams] = None
 
     def __hash__(self) -> int:
-        return hash(
-            (
-                tuple(self.sharding_types) if self.sharding_types else None,
-                tuple(self.compute_kernels) if self.compute_kernels else None,
-                self.min_partition,
-                tuple(self.pooling_factors),
-                tuple(self.num_poolings) if self.num_poolings else None,
-                tuple(self.batch_sizes) if self.batch_sizes else None,
-                self.is_weighted,
-                self.cache_params,
-                self.enforce_hbm,
-                self.stochastic_rounding,
-                self.bounds_check_mode,
-                tuple(self.feature_names) if self.feature_names else None,
-                self.output_dtype,
-                self.device_group,
-                self.key_value_params,
-            )
-        )
+        hashable_list = [
+            tuple(self.sharding_types) if self.sharding_types else None,
+            tuple(self.compute_kernels) if self.compute_kernels else None,
+            self.min_partition,
+            tuple(self.pooling_factors),
+            tuple(self.num_poolings) if self.num_poolings else None,
+            tuple(self.batch_sizes) if self.batch_sizes else None,
+            self.is_weighted,
+            self.cache_params,
+            self.enforce_hbm,
+            self.stochastic_rounding,
+            self.bounds_check_mode,
+            tuple(self.feature_names) if self.feature_names else None,
+            self.output_dtype,
+            self.device_group,
+            self.key_value_params,
+        ]
+
+        return hash_sha256_to_int(hashable_list)
 
 
 class PlannerErrorType(Enum):
@@ -803,6 +803,10 @@ def reserve(
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
     ) -> Topology: ...
 
+    @property
+    @abc.abstractmethod
+    def last_reserved_topology(self) -> Optional[Topology]: ...
+
 
 class PerfModel(abc.ABC):
     @abc.abstractmethod
@@ -967,3 +971,56 @@ class CriticalPathEstimate:
 
     def total(self) -> float:
         return self.comms_estimate + self.comp_estimate
+
+
+# ---- Types Utils ---- #
+def hash_sha256_to_int(hashable_list: List[Any]) -> int:  # pyre-ignore
+    """
+    Hashes the given data using SHA256 and returns the hash as an integer.
+    """
+    serialized_list = str(hashable_list).encode("utf-8")
+    hash_object = hashlib.sha256(serialized_list)
+    hash_digest = hash_object.hexdigest()
+    return int(hash_digest, 16)
+
+
+def hash_planner_context_inputs(
+    topology: Topology,
+    batch_size: int,
+    enumerator: Enumerator,
+    storage_reservation: StorageReservation,
+    constraints: Optional[Dict[str, ParameterConstraints]],
+    # pyre-ignore
+    hash_function: Callable[[List[Any]], int] = hash_sha256_to_int,
+) -> int:
+    assert hasattr(
+        enumerator, "last_stored_search_space"
+    ), "This enumerator is not compatible with hashing"
+    assert (
+        enumerator.last_stored_search_space is not None  # pyre-ignore
+    ), "Unable to hash planner context without an enumerator that has a precomputed search space"
+    search_space = enumerator.last_stored_search_space
+    storage_reservation_policy = type(storage_reservation).__name__
+
+    assert (
+        storage_reservation._last_reserved_topology is not None  # pyre-ignore
+    ), "Unable to hash planner context without a storage reservation that has a precomputed topology"
+
+    hashable_list = [
+        topology,
+        batch_size,
+        [
+            [
+                shard_option.fqn,
+                shard_option.sharding_type,
+                shard_option.compute_kernel,
+                tuple(shard_option.shards),
+                shard_option.cache_params,
+            ]
+            for shard_option in search_space
+        ],
+        storage_reservation_policy,
+        storage_reservation._last_reserved_topology,
+        constraints.items() if constraints else None,
+    ]
+    return hash_function(hashable_list)
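Usage note: the pieces above make `hash_planner_context_inputs()` order-sensitive. `enumerate(...)` must run first to populate the enumerator's `last_stored_search_space`, and `reserve(...)` must run to populate the storage reservation's `_last_reserved_topology`; otherwise the asserts in `hash_planner_context_inputs` fire. Below is a minimal sketch of the intended call order, mirroring the new tests in this patch; the table name, constraint key, and import paths are taken from the test files and are illustrative rather than canonical:

```python
import torch

from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.distributed.planner.types import ParameterConstraints, Topology
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig

topology = Topology(local_world_size=8, world_size=1, compute_device="cuda")
batch_size = 128
constraints = {"table_0": ParameterConstraints()}

# Meta device is enough here: only the search space is needed, not real weights.
module = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=160,
            num_embeddings=10000,
            feature_names=["f1"],
            data_type=DataType.FP16,
        )
    ],
    device=torch.device("meta"),
)
sharders = [EmbeddingBagCollectionSharder()]

# Step 1: enumerate() populates the enumerator's cached search space.
enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
enumerator.enumerate(module, sharders)

# Step 2: reserve() populates the storage reservation's cached topology.
storage_reservation = HeuristicalStorageReservation(percentage=0.15)
storage_reservation.reserve(
    topology=topology,
    batch_size=batch_size,
    module=module,
    sharders=sharders,
    constraints=constraints,
)

# Step 3: with both caches populated, the planner-context hash is well defined.
planner = EmbeddingShardingPlanner(
    topology=topology,
    batch_size=batch_size,
    enumerator=enumerator,
    storage_reservation=storage_reservation,
    constraints=constraints,
)
print(planner.hash_planner_context_inputs())
```

One design consequence worth noting: `hash_sha256_to_int` serializes the hashable list via `str(...)`, so the hash is only stable across processes if the `repr` of every element (notably `Topology` and the `ShardingOption` fields) is deterministic for identical inputs — which is exactly what `TestConsistentHashingBetweenProcesses` verifies.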