|
7 | 7 |
|
8 | 8 | # pyre-strict
|
9 | 9 | import logging as logger
|
10 |
| -from collections import Counter, OrderedDict |
| 10 | +from collections import OrderedDict |
11 | 11 | from typing import Dict, Iterable, List, Optional
|
12 | 12 |
|
13 | 13 | import torch
|
|
33 | 33 | }
|
34 | 34 |
|
35 | 35 | # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
|
36 |
| -SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection) |
| 36 | +SUPPORTED_MODULES_TO_PREFIX = { |
| 37 | + ShardedEmbeddingCollection: ".embeddings", |
| 38 | + ShardedEmbeddingBagCollection: ".embedding_bags", |
| 39 | +} |
37 | 40 |
|
38 | 41 |
|
39 | 42 | class ModelDeltaTracker:
|
@@ -277,61 +280,32 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
|
277 | 280 | if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
|
278 | 281 | return self._fqn_to_feature_map
|
279 | 282 |
|
280 |
| - table_to_feature_names: Dict[str, List[str]] = OrderedDict() |
281 |
| - table_to_fqn: Dict[str, str] = OrderedDict() |
| 283 | + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
282 | 284 | for fqn, named_module in self._model.named_modules():
|
283 |
| - split_fqn = fqn.split(".") |
284 | 285 | # Skipping partial FQNs present in fqns_to_skip
|
285 | 286 | # TODO: Validate if we need to support more complex patterns for skipping fqns
|
286 |
| - should_skip = False |
287 |
| - for fqn_to_skip in self._fqns_to_skip: |
288 |
| - if fqn_to_skip in split_fqn: |
289 |
| - logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") |
290 |
| - should_skip = True |
291 |
| - break |
292 |
| - if should_skip: |
293 |
| - continue |
294 |
| - # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. |
295 |
| - if isinstance(named_module, SUPPORTED_MODULES): |
| 287 | + if type(named_module) in SUPPORTED_MODULES_TO_PREFIX: |
296 | 288 | for table_name, config in named_module._table_name_to_config.items():
|
297 |
| - logger.info( |
298 |
| - f"Found {table_name} for {fqn} with features {config.feature_names}" |
| 289 | + embedding_fqn = ( |
| 290 | + self._clean_fqn_fn(fqn) |
| 291 | + + SUPPORTED_MODULES_TO_PREFIX[type(named_module)] |
| 292 | + + f".{table_name}" |
299 | 293 | )
|
300 |
| - table_to_feature_names[table_name] = config.feature_names |
301 |
| - self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module |
302 |
| - for table_name in table_to_feature_names: |
303 |
| - # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" |
304 |
| - # will incorrectly match fqn with all the table names that have the same prefix |
305 |
| - if table_name in split_fqn: |
306 |
| - embedding_fqn = self._clean_fqn_fn(fqn) |
307 |
| - if table_name in table_to_fqn: |
308 |
| - # Sanity check for validating that we don't have more then one table mapping to same fqn. |
309 |
| - logger.warning( |
310 |
| - f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" |
311 |
| - ) |
312 |
| - table_to_fqn[table_name] = embedding_fqn |
313 |
| - logger.info(f"Table to fqn: {table_to_fqn}") |
314 |
| - flatten_names = [ |
315 |
| - name for names in table_to_feature_names.values() for name in names |
316 |
| - ] |
317 |
| - # TODO: Validate if there is a better way to handle duplicate feature names. |
318 |
| - # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. |
319 |
| - if len(set(flatten_names)) != len(flatten_names): |
320 |
| - counts = Counter(flatten_names) |
321 |
| - duplicates = [item for item, count in counts.items() if count > 1] |
322 |
| - logger.warning(f"duplicate feature names found: {duplicates}") |
323 | 294 |
|
324 |
| - fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
325 |
| - for table_name in table_to_feature_names: |
326 |
| - if table_name not in table_to_fqn: |
327 |
| - # This is likely unexpected, where we can't locate the FQN associated with this table. |
328 |
| - logger.warning( |
329 |
| - f"Table {table_name} not found in {table_to_fqn}, skipping" |
330 |
| - ) |
331 |
| - continue |
332 |
| - fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ |
333 |
| - table_name |
334 |
| - ] |
| 295 | + should_skip = False |
| 296 | + for fqn_to_skip in self._fqns_to_skip: |
| 297 | + if fqn_to_skip in embedding_fqn: |
| 298 | + logger.info( |
| 299 | + f"Skipping {fqn} because it is part of fqns_to_skip" |
| 300 | + ) |
| 301 | + should_skip = True |
| 302 | + break |
| 303 | + if should_skip: |
| 304 | + continue |
| 305 | + if embedding_fqn not in fqn_to_feature_names: |
| 306 | + fqn_to_feature_names[embedding_fqn] = [] |
| 307 | + fqn_to_feature_names[embedding_fqn].extend(config.feature_names) |
| 308 | + self.tracked_modules[embedding_fqn] = named_module |
335 | 309 | self._fqn_to_feature_map = fqn_to_feature_names
|
336 | 310 | return fqn_to_feature_names
|
337 | 311 |
|
|
0 commit comments