Skip to content

Commit c4f17e4

Browse files
kausv and facebook-github-bot
authored and committed
Simplify fqns_to_feature_names in Delta Tracker (#3081)
Summary: Each table in a collection will have an FQN. With this assumption we can simplify the logic and avoid two iterations. Existing tests pass with this logic. Differential Revision: D76432354
1 parent f659b6a commit c4f17e4

File tree

1 file changed

+25
-51
lines changed

1 file changed

+25
-51
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# pyre-strict
99
import logging as logger
10-
from collections import Counter, OrderedDict
10+
from collections import OrderedDict
1111
from typing import Dict, Iterable, List, Optional
1212

1313
import torch
@@ -33,7 +33,10 @@
3333
}
3434

3535
# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
36-
SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
36+
SUPPORTED_MODULES_TO_PREFIX = {
37+
ShardedEmbeddingCollection: ".embeddings",
38+
ShardedEmbeddingBagCollection: ".embedding_bags",
39+
}
3740

3841

3942
class ModelDeltaTracker:
@@ -277,61 +280,32 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
277280
if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
278281
return self._fqn_to_feature_map
279282

280-
table_to_feature_names: Dict[str, List[str]] = OrderedDict()
281-
table_to_fqn: Dict[str, str] = OrderedDict()
283+
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
282284
for fqn, named_module in self._model.named_modules():
283-
split_fqn = fqn.split(".")
284285
# Skipping partial FQNs present in fqns_to_skip
285286
# TODO: Validate if we need to support more complex patterns for skipping fqns
286-
should_skip = False
287-
for fqn_to_skip in self._fqns_to_skip:
288-
if fqn_to_skip in split_fqn:
289-
logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
290-
should_skip = True
291-
break
292-
if should_skip:
293-
continue
294-
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
295-
if isinstance(named_module, SUPPORTED_MODULES):
287+
if type(named_module) in SUPPORTED_MODULES_TO_PREFIX:
296288
for table_name, config in named_module._table_name_to_config.items():
297-
logger.info(
298-
f"Found {table_name} for {fqn} with features {config.feature_names}"
289+
embedding_fqn = (
290+
self._clean_fqn_fn(fqn)
291+
+ SUPPORTED_MODULES_TO_PREFIX[type(named_module)]
292+
+ f".{table_name}"
299293
)
300-
table_to_feature_names[table_name] = config.feature_names
301-
self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
302-
for table_name in table_to_feature_names:
303-
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
304-
# will incorrectly match fqn with all the table names that have the same prefix
305-
if table_name in split_fqn:
306-
embedding_fqn = self._clean_fqn_fn(fqn)
307-
if table_name in table_to_fqn:
308-
# Sanity check for validating that we don't have more than one table mapping to the same fqn.
309-
logger.warning(
310-
f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
311-
)
312-
table_to_fqn[table_name] = embedding_fqn
313-
logger.info(f"Table to fqn: {table_to_fqn}")
314-
flatten_names = [
315-
name for names in table_to_feature_names.values() for name in names
316-
]
317-
# TODO: Validate if there is a better way to handle duplicate feature names.
318-
# Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
319-
if len(set(flatten_names)) != len(flatten_names):
320-
counts = Counter(flatten_names)
321-
duplicates = [item for item, count in counts.items() if count > 1]
322-
logger.warning(f"duplicate feature names found: {duplicates}")
323294

324-
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
325-
for table_name in table_to_feature_names:
326-
if table_name not in table_to_fqn:
327-
# This is likely unexpected, where we can't locate the FQN associated with this table.
328-
logger.warning(
329-
f"Table {table_name} not found in {table_to_fqn}, skipping"
330-
)
331-
continue
332-
fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
333-
table_name
334-
]
295+
should_skip = False
296+
for fqn_to_skip in self._fqns_to_skip:
297+
if fqn_to_skip in embedding_fqn:
298+
logger.info(
299+
f"Skipping {fqn} because it is part of fqns_to_skip"
300+
)
301+
should_skip = True
302+
break
303+
if should_skip:
304+
continue
305+
if embedding_fqn not in fqn_to_feature_names:
306+
fqn_to_feature_names[embedding_fqn] = []
307+
fqn_to_feature_names[embedding_fqn].extend(config.feature_names)
308+
self.tracked_modules[embedding_fqn] = named_module
335309
self._fqn_to_feature_map = fqn_to_feature_names
336310
return fqn_to_feature_names
337311

0 commit comments

Comments
 (0)