Skip to content

Commit 6ad50bb

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Test commit (#3084)
Summary: Pull Request resolved: #3084 Rollback Plan: Differential Revision: D76457454
1 parent c8495ec commit 6ad50bb

File tree

2 files changed

+113
-8
lines changed

2 files changed

+113
-8
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
from typing import Dict, List, Optional, Union
9+
import logging as logger
10+
from collections import Counter, OrderedDict
11+
from typing import Dict, Iterable, List, Optional, Union
1012

1113
import torch
1214

@@ -30,7 +32,7 @@
3032
}
3133

3234
# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
33-
SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection]
35+
SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
3436

3537

3638
class ModelDeltaTracker:
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
4951
call.
5052
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
5153
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54+
fqns_to_skip (Iterable[str], optional): dot-separated FQN segments to skip; any module whose FQN contains one of them as a segment is not tracked. Default: ().
55+
5256
"""
5357

5458
DEFAULT_CONSUMER: str = "default"
@@ -59,11 +63,15 @@ def __init__(
5963
consumers: Optional[List[str]] = None,
6064
delete_on_read: bool = True,
6165
mode: TrackingMode = TrackingMode.ID_ONLY,
66+
fqns_to_skip: Iterable[str] = (),
6267
) -> None:
6368
self._model = model
6469
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
6570
self._delete_on_read = delete_on_read
6671
self._mode = mode
72+
self._fqn_to_feature_map: Dict[str, List[str]] = {}
73+
self._fqns_to_skip: Iterable[str] = fqns_to_skip
74+
self.fqn_to_feature_names()
6775
pass
6876

6977
def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
@@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
8593
"""
8694
return {}
8795

88-
def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
96+
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
    """
    Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.

    The mapping is computed by walking ``self._model.named_modules()`` and is
    cached in ``self._fqn_to_feature_map``. A non-empty cache is returned as-is
    (the same dict object) on subsequent calls; an empty cache is recomputed.

    Modules whose dot-separated FQN contains any entry of ``self._fqns_to_skip``
    as a segment are ignored.

    Returns:
        Dict[str, List[str]]: embedding FQN (with the
        "_dmp_wrapped_module.module." wrapper prefix removed) mapped to the
        feature names of the table found under that FQN.
    """
    if self._fqn_to_feature_map:
        return self._fqn_to_feature_map

    # `fqns_to_skip` is only promised to be an Iterable. It is consulted once per
    # module below, so a one-shot iterable (e.g. a generator) would be exhausted
    # after the first module and silently stop skipping. Snapshot it once and keep
    # the snapshot so later recomputations see the same entries.
    if not isinstance(self._fqns_to_skip, (tuple, list, set, frozenset)):
        self._fqns_to_skip = tuple(self._fqns_to_skip)
    skip_segments = tuple(self._fqns_to_skip)

    table_to_feature_names: Dict[str, List[str]] = OrderedDict()
    table_to_fqn: Dict[str, str] = OrderedDict()
    for fqn, named_module in self._model.named_modules():
        split_fqn = fqn.split(".")
        # Skipping partial FQNs present in fqns_to_skip
        # TODO: Validate if we need to support more complex patterns for skipping fqns
        if any(segment in split_fqn for segment in skip_segments):
            logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
            continue

        # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
        if isinstance(named_module, SUPPORTED_MODULES):
            for table_name, config in named_module._table_name_to_config.items():
                logger.info(
                    f"Found {table_name} for {fqn} with features {config.feature_names}"
                )
                table_to_feature_names[table_name] = config.feature_names
        for table_name in table_to_feature_names:
            # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
            # will incorrectly match fqn with all the table names that have the same prefix
            if table_name not in split_fqn:
                continue
            embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
            if table_name in table_to_fqn:
                # Sanity check for validating that we don't have more than one FQN mapping to the same table.
                logger.warning(
                    f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                )
            table_to_fqn[table_name] = embedding_fqn
    logger.info(f"Table to fqn: {table_to_fqn}")

    # TODO: Validate if there is a better way to handle duplicate feature names.
    # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
    feature_counts = Counter(
        name for names in table_to_feature_names.values() for name in names
    )
    duplicates = [item for item, count in feature_counts.items() if count > 1]
    if duplicates:
        logger.warning(f"duplicate feature names found: {duplicates}")

    fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
    for table_name, feature_names in table_to_feature_names.items():
        if table_name not in table_to_fqn:
            # This is likely unexpected, where we can't locate the FQN associated with this table.
            logger.warning(
                f"Table {table_name} not found in {table_to_fqn}, skipping"
            )
            continue
        fqn_to_feature_names[table_to_fqn[table_name]] = feature_names
    self._fqn_to_feature_map = fqn_to_feature_names
    return fqn_to_feature_names
96160

97161
def clear(self, consumer: Optional[str] = None) -> None:
98162
"""
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
#!/usr/bin/env python3
11+
from dataclasses import dataclass
12+
from typing import cast, Dict, Iterable, List, Optional, Union
13+
14+
import torch
15+
16+
from torch import nn
17+
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
18+
from torchrec.distributed.planner import ParameterConstraints
19+
from torchrec.distributed.types import ShardingType
20+
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
21+
from torchrec.modules.embedding_modules import (
22+
EmbeddingBagCollection,
23+
EmbeddingCollection,
24+
)
25+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
26+
27+
28+
@dataclass
class EmbeddingTableProps:
    """
    Properties of an embedding table.

    Args:
        embedding_table_config (Union[EmbeddingConfig, EmbeddingBagConfig]): config of the embedding table.
        sharding (ShardingType): sharding type of the table.
        is_weighted (bool): flag marking the table as weighted. Default: False.
    """

    embedding_table_config: Union[EmbeddingConfig, EmbeddingBagConfig]
    sharding: ShardingType
    is_weighted: bool = False

0 commit comments

Comments
 (0)