
Commit eb425d8

Enable Consistent SHA256 Hashing with Reduced Planner Context
Summary: Even though SHA256 hashing is used, we are still not seeing the expected identical hash generated from the original planner context inputs. The problem is that the Enumerator and StorageReservation objects we were originally trying to hash contain attributes that differ between processes/instances. To resolve this, we reduced the hashing context to only the specific attributes we need from the enumerator and storage reservation, namely:

* The output of `enumerator.enumerate(...)`, which is used as the `search_space` in both the LP and OSS planners.
  * We store the output of `enumerate` in the attribute `last_stored_search_space`. **This assumes `enumerate` will have been called before we hash the planner context inputs.**
* The StorageReservation policy (i.e., whether `HeuristicalStorageReservation` or `FixedPercentageStorageReservation` is used).
* The StorageReservation initialization attributes:
  * `_percentage`
  * `_parameter_multiplier` for `HeuristicalStorageReservation`
  * `_dense_tensor_estimate` for `HeuristicalStorageReservation`

Created helper functions:

* `hash_planner_context_inputs`, called both in `planner.hash_planner_context_inputs` and at the manifold loading call site (see D75723272).
* `hash_sha256_to_int`, passed in as the default hash function in `hash_planner_context_inputs`.

Also created a multiprocess unit test to quickly check that consistent hashes are generated across different processes given the same input.

Differential Revision: D76303748
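As a sketch of the ordering contract this implies (names are taken from the diffs below; `module`, `sharders`, `topology`, and the planner objects stand in for a real setup):

# Hashing reads cached outputs, so enumerate() and reserve() must run first.
enumerator.enumerate(module, sharders)             # populates last_stored_search_space
storage_reservation.reserve(                       # populates last_reserved_topology
    topology=topology,
    batch_size=batch_size,
    module=module,
    sharders=sharders,
)
plan_hash = planner.hash_planner_context_inputs()  # now safe to hash; returns an int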
1 parent 3b6b537 commit eb425d8

File tree

torchrec/distributed/planner/enumerators.py
torchrec/distributed/planner/planners.py
torchrec/distributed/planner/storage_reservations.py
torchrec/distributed/planner/tests/test_types.py
torchrec/distributed/planner/types.py

5 files changed, +213 -34 lines changed


torchrec/distributed/planner/enumerators.py

Lines changed: 21 additions & 0 deletions
@@ -102,6 +102,11 @@ def __init__(
             EmbeddingStorageEstimator(topology=topology, constraints=constraints),
         ]
 
+        # Initialize caching for enumerate
+        self._last_stored_search_space: Optional[List[ShardingOption]] = None
+        self._last_stored_module: Optional[nn.Module] = None
+        self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None
+
     def enumerate(
         self,
         module: nn.Module,
@@ -118,6 +123,12 @@ def enumerate(
             List[ShardingOption]: valid sharding options with values populated.
         """
 
+        if (
+            self._last_stored_module == module
+            and self._last_stored_sharders == sharders
+        ):
+            return self._last_stored_search_space  # pyre-ignore
+
         self._sharder_map = {
             sharder_name(sharder.module_type): sharder for sharder in sharders
         }
@@ -230,8 +241,18 @@ def enumerate(
 
         self.populate_estimates(sharding_options)
 
+        self._last_stored_module = module
+        self._last_stored_sharders = sharders
+        self._last_stored_search_space = sharding_options
         return sharding_options
 
+    @property
+    def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
+        # NOTE: This is the last search space stored by enumerate(...); do not use
+        # this field in place of actually calling enumerate(...), as it will vary
+        # for each module/sharders pair passed in.
+        return self._last_stored_search_space
+
     def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
         for estimator in self._estimators:
             estimator.estimate(sharding_options, self._sharder_map)
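To illustrate the memoization added above, a minimal sketch (the planner-side names come from this diff; the embedding config/module import paths are assumed from the rest of torchrec):

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.types import Topology
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

topology = Topology(world_size=1, local_world_size=8, compute_device="cuda")
enumerator = EmbeddingEnumerator(topology=topology, batch_size=128)
module = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=160,
            num_embeddings=10_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),  # meta device: only the search space is needed
)
sharders = [EmbeddingBagCollectionSharder()]

space_a = enumerator.enumerate(module, sharders)  # full enumeration, result cached
space_b = enumerator.enumerate(module, sharders)  # cache hit: same list returned
assert space_b is space_a
assert enumerator.last_stored_search_space is space_a  # read-only view of the cache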

torchrec/distributed/planner/planners.py

Lines changed: 5 additions & 8 deletions
@@ -39,6 +39,7 @@
 )
 from torchrec.distributed.planner.types import (
     Enumerator,
+    hash_planner_context_inputs,
     ParameterConstraints,
     Partitioner,
     PerfModel,
@@ -280,25 +281,21 @@ def collective_plan(
             sharders,
         )
 
-    def hash_planner_context_inputs(self) -> str:
+    def hash_planner_context_inputs(self) -> int:
         """
         Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
         These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
 
         Returns:
             A hash capturing topology, batch size, enumerator, storage reservation, and constraints.
         """
-        hashable_list = [
+        return hash_planner_context_inputs(
             self._topology,
             self._batch_size,
             self._enumerator,
             self._storage_reservation,
-            frozenset(self._constraints.items()) if self._constraints else None,
-        ]
-        serialized_list = str(hashable_list).encode("utf-8")
-        hash_object = hashlib.sha256(serialized_list)
-        hash_digest = hash_object.hexdigest()
-        return hash_digest
+            self._constraints,
+        )
 
     def plan(
         self,
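For context, a hypothetical caller-side check built on the new int-valued method (the cache store and helper below are illustrative; only `hash_planner_context_inputs` comes from this commit):

# Hypothetical sketch: reuse a previously generated plan only when the
# stored context hash matches the freshly computed one.
def load_plan_if_valid(planner, cache):  # `cache`: any dict-like store (illustrative)
    current_hash = planner.hash_planner_context_inputs()  # int, process-independent
    entry = cache.get("sharding_plan")
    if entry is not None and entry["context_hash"] == current_hash:
        return entry["plan"]  # planner inputs unchanged, cached plan still valid
    return None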

torchrec/distributed/planner/storage_reservations.py

Lines changed: 15 additions & 0 deletions
@@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation):
     def __init__(self, percentage: float) -> None:
         assert percentage >= 0 and percentage <= 1
         self._percentage: float = percentage
+        self._last_reserved_topology: Optional[Topology] = None
 
     def reserve(
         self,
@@ -174,8 +175,14 @@ def reserve(
     ) -> Topology:
         reserved_topology = copy.deepcopy(topology)
         _reserve_storage_percentage(reserved_topology, self._percentage)
+        self._last_reserved_topology = reserved_topology
         return reserved_topology
 
+    @property
+    def last_reserved_topology(self) -> Optional[Topology]:
+        """Cached value of the most recent output from the reserve() method."""
+        return self._last_reserved_topology
+
 
 class HeuristicalStorageReservation(StorageReservation):
     """
@@ -206,6 +213,7 @@ def __init__(
 
         self._dense_storage: Optional[Storage] = None
         self._kjt_storage: Optional[Storage] = None
+        self._last_reserved_topology: Optional[Topology] = None
 
     def reserve(
         self,
@@ -215,6 +223,7 @@ def reserve(
         sharders: List[ModuleSharder[nn.Module]],
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
     ) -> Topology:
+        # TODO: enable proper caching of topology values through _last_reserved_topology
         reserved_topology = copy.deepcopy(topology)
 
         batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters(
@@ -262,8 +271,14 @@ def reserve(
             message=negative_storage_solution,
         )
 
+        self._last_reserved_topology = reserved_topology
         return reserved_topology
 
+    @property
+    def last_reserved_topology(self) -> Optional[Topology]:
+        """Cached value of the most recent output from the reserve() method."""
+        return self._last_reserved_topology
+
 
 class InferenceStorageReservation(StorageReservation):
     """

torchrec/distributed/planner/tests/test_types.py

Lines changed: 85 additions & 1 deletion
@@ -8,18 +8,30 @@
 # pyre-strict
 
 import unittest
-from typing import cast
+from typing import cast, Dict, Optional
 from unittest.mock import MagicMock
 
 import torch
+from torch import multiprocessing
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.planner import EmbeddingShardingPlanner
+from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
+from torchrec.distributed.planner.perf_models import NoopPerfModel
+from torchrec.distributed.planner.storage_reservations import (
+    HeuristicalStorageReservation,
+)
 
 from torchrec.distributed.planner.types import (
     ParameterConstraints,
     Shard,
     ShardingOption,
     Topology,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
         self.assertNotEqual(
             hash(pc1), hash(pc2), "Hashes should be different for different instances"
         )
+
+
+def _test_hashing_consistency(
+    rank: int,
+    world_size: int,
+    backend: str,
+    return_hash_dict: Dict[str, int],
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        topology = Topology(
+            local_world_size=8,
+            world_size=1,
+            compute_device="cuda",
+        )
+        batch_size = 128
+        enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
+        eb_config = EmbeddingBagConfig(
+            name="table_0",
+            embedding_dim=160,
+            num_embeddings=10000,
+            feature_names=["f1"],
+            data_type=DataType.FP16,
+        )
+        module = EmbeddingBagCollection(
+            tables=[eb_config],
+            is_weighted=False,
+            device=torch.device(
+                "meta"
+            ),  # Using meta device for now since we only need the search space
+        )
+        sharders = [EmbeddingBagCollectionSharder()]
+        enumerator.enumerate(module, sharders)  # pyre-ignore
+        storage_reservation = HeuristicalStorageReservation(percentage=0.15)
+        constraints = {"table1": ParameterConstraints()}
+
+        storage_reservation.reserve(
+            topology=topology,
+            batch_size=batch_size,
+            module=module,
+            sharders=sharders,  # pyre-ignore
+            constraints=constraints,
+        )
+        perf_model = NoopPerfModel(topology=topology)
+
+        planner1 = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=enumerator,
+            storage_reservation=storage_reservation,
+            performance_model=perf_model,
+            constraints=constraints,
+        )
+
+        h = planner1.hash_planner_context_inputs()
+        return_hash_dict[str(rank)] = h
+
+
+class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):
+    def test_hash_consistency(self) -> None:
+        # planner
+        world_size = 2
+        return_hash_dict = multiprocessing.Manager().dict()
+        self._run_multi_process_test(
+            callable=_test_hashing_consistency,
+            world_size=world_size,
+            backend="nccl" if torch.cuda.is_available() else "gloo",
+            return_hash_dict=return_hash_dict,
+        )
+        hashes = return_hash_dict.values()
+        assert hashes[0] == hashes[1], "hash values are different."
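The multiprocess setup matters because Python's built-in hash() for str/bytes is salted per process (PYTHONHASHSEED), so it cannot serve as a cross-process fingerprint, while a SHA256 digest of the same serialized bytes is deterministic. A standalone sketch, independent of torchrec:

import subprocess
import sys

# Each child process gets its own hash salt, so built-in hash() of the same
# string typically differs between runs; the sha256 digest never does.
code = (
    "import hashlib;"
    "print(hash('planner-context'),"
    " int(hashlib.sha256(b'planner-context').hexdigest(), 16))"
)
runs = [
    subprocess.run([sys.executable, "-c", code], capture_output=True, text=True)
    for _ in range(2)
]
builtin, sha = zip(*(r.stdout.split() for r in runs))
assert sha[0] == sha[1]                    # SHA256 digests always match
print("built-in hash() values:", builtin)  # typically differ across the two runs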

torchrec/distributed/planner/types.py

Lines changed: 87 additions & 25 deletions
@@ -12,7 +12,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -401,12 +401,15 @@ def __repr__(self) -> str:
         topology_repr += str(self._comms_bandwidths) + "\n"
         return topology_repr
 
-    def _hash(self) -> str:
+    def _hash(self) -> int:
         """
         Compute a consistent hash value for this Topology instance.
 
         Returns:
             int: A hash value for this Topology instance.
+
+        NOTE: We do not override the __hash__ method here, since the list
+        below may not cover every attribute of the instance.
         """
 
         # Compute hbms and ddrs from the devices
@@ -430,10 +433,7 @@ def _hash(self) -> int:
             self._uneven_sharding_perf_multiplier,
         ]
 
-        serialized_list = str(hashable_list).encode("utf-8")
-        hash_object = hashlib.sha256(serialized_list)
-        hash_digest = hash_object.hexdigest()
-        return hash_digest
+        return hash_sha256_to_int(hashable_list)
 
 
 # ---- INPUT / OUTPUT ----- #
@@ -743,25 +743,25 @@ class ParameterConstraints:
     key_value_params: Optional[KeyValueParams] = None
 
     def __hash__(self) -> int:
-        return hash(
-            (
-                tuple(self.sharding_types) if self.sharding_types else None,
-                tuple(self.compute_kernels) if self.compute_kernels else None,
-                self.min_partition,
-                tuple(self.pooling_factors),
-                tuple(self.num_poolings) if self.num_poolings else None,
-                tuple(self.batch_sizes) if self.batch_sizes else None,
-                self.is_weighted,
-                self.cache_params,
-                self.enforce_hbm,
-                self.stochastic_rounding,
-                self.bounds_check_mode,
-                tuple(self.feature_names) if self.feature_names else None,
-                self.output_dtype,
-                self.device_group,
-                self.key_value_params,
-            )
-        )
+        hashable_list = [
+            tuple(self.sharding_types) if self.sharding_types else None,
+            tuple(self.compute_kernels) if self.compute_kernels else None,
+            self.min_partition,
+            tuple(self.pooling_factors),
+            tuple(self.num_poolings) if self.num_poolings else None,
+            tuple(self.batch_sizes) if self.batch_sizes else None,
+            self.is_weighted,
+            self.cache_params,
+            self.enforce_hbm,
+            self.stochastic_rounding,
+            self.bounds_check_mode,
+            tuple(self.feature_names) if self.feature_names else None,
+            self.output_dtype,
+            self.device_group,
+            self.key_value_params,
+        ]
+
+        return hash_sha256_to_int(hashable_list)
 
 
 class PlannerErrorType(Enum):
@@ -967,3 +967,65 @@ class CriticalPathEstimate:
 
     def total(self) -> float:
         return self.comms_estimate + self.comp_estimate
+
+
+# ---- Types Utils ---- #
+def hash_sha256_to_int(hashable_list: List[Any]) -> int:  # pyre-ignore
+    """
+    Hashes the given data using SHA256 and returns the hash as an integer.
+    """
+    serialized_list = str(hashable_list).encode("utf-8")
+    hash_object = hashlib.sha256(serialized_list)
+    hash_digest = hash_object.hexdigest()
+    return int(hash_digest, 16)
+
+
+def hash_planner_context_inputs(
+    topology: Topology,
+    batch_size: int,
+    enumerator: Enumerator,
+    storage_reservation: StorageReservation,
+    constraints: Optional[Dict[str, ParameterConstraints]],
+    # pyre-ignore
+    hash_function: Callable[[List[Any]], int] = hash_sha256_to_int,
+) -> int:
+    assert hasattr(
+        enumerator, "last_stored_search_space"
+    ), "This enumerator is not compatible with hashing"
+    assert (
+        enumerator.last_stored_search_space is not None  # pyre-ignore
+    ), "Unable to hash planner context without an enumerator that has a precomputed search space"
+    search_space = enumerator.last_stored_search_space
+    storage_reservation_policy = type(storage_reservation).__name__
+
+    # TODO: Not the cleanest approach, but importing the actual storage
+    # reservation classes here would create a circular dependency - revisit
+    # and refactor this more cleanly later.
+    assert storage_reservation_policy in [
+        "HeuristicalStorageReservation",
+        "FixedPercentageStorageReservation",
+    ]
+    assert hasattr(
+        storage_reservation, "last_reserved_topology"
+    ), "This storage reservation is not compatible with hashing"
+    assert (
+        storage_reservation.last_reserved_topology is not None  # pyre-ignore
+    ), "Unable to hash planner context without a storage reservation that has a precomputed topology"
+
+    hashable_list = [
+        topology,
+        batch_size,
+        [
+            [
+                shard_option.fqn,
+                shard_option.sharding_type,
+                shard_option.compute_kernel,
+                tuple(shard_option.shards),
+                shard_option.cache_params,
+            ]
+            for shard_option in search_space
+        ],
+        storage_reservation_policy,
+        storage_reservation.last_reserved_topology,
+        frozenset(constraints.items()) if constraints else None,
+    ]
+    return hash_function(hashable_list)
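Since `hash_function` is pluggable, a caller can swap in another digest while keeping the same str()-serialization contract. A sketch with blake2b (an arbitrary, hypothetical alternative, not part of this commit):

import hashlib
from typing import Any, List

def hash_blake2b_to_int(hashable_list: List[Any]) -> int:
    # Same contract as hash_sha256_to_int: str()-serialize, digest, widen to int.
    serialized = str(hashable_list).encode("utf-8")
    return int(hashlib.blake2b(serialized).hexdigest(), 16)

# h = hash_planner_context_inputs(
#     topology, batch_size, enumerator, storage_reservation, constraints,
#     hash_function=hash_blake2b_to_int,
# )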
