Skip to content

Commit 35958ec

Browse files
lizhouyufacebook-github-bot
authored and committed
OSS TorchRec MPZCH Modules (#3162)
Summary: Pull Request resolved: #3162

### Major changes
- Copy the following files from `fb` to the corresponding locations in the `torchrec` repository:
  - `fb/distributed/hash_mc_embedding.py → torchrec/distributed/hash_mc_embedding.py`
  - `fb/modules/hash_mc_evictions.py → torchrec/modules/hash_mc_evictions.py`
  - `fb/modules/hash_mc_metrics.py → torchrec/modules/hash_mc_metrics.py`
  - `fb/modules/hash_mc_modules.py → torchrec/modules/hash_mc_modules.py`
  - `fb/modules/tests/test_hash_mc_evictions.py → torchrec/modules/tests/test_hash_mc_evictions.py`
  - `fb/modules/tests/test_hash_mc_modules.py → torchrec/modules/tests/test_hash_mc_modules.py`
- Update `/modules/hash_mc_metrics.py`:
  - Replace the tensorboard module with a local file logger in the `hash_mc_metrics.py` module to avoid OSS CI test failures.
  - The original tensorboard version is kept in the `torchrec/fb` folder.
- Update the license declaration headers for the OSS files.
- Add a `unittest.skipIf` condition to `test_dynamically_switch_inference_training_mode` and `test_output_global_offset_tensor` to skip these tests when a GPU is not available.
- Update the imports in `torchrec/modules/tests/test_hash_mc_modules.py` and `torchrec/modules/tests/test_hash_mc_evictions.py` from `torchrec.fb.modules.hash_mc_*` to `torchrec.modules.hash_mc_*`, and update the BUCK file correspondingly from `/torchrec/fb/modules/hash_mc_*` to `torchrec/modules/hash_mc_*`.

### Next step
- Wait for `fbpkg` to pick up the Diff and update the existing dependencies on MPZCH modules from `torchrec/fb/module` to `torchrec/modules`.
- Wait for all the dependencies to be updated, then clean up the files in `torchrec/fb/module`.

Differential Revision: D77825114
1 parent 538bfa4 commit 35958ec

File tree

5 files changed

+1852
-0
lines changed

5 files changed

+1852
-0
lines changed

torchrec/modules/hash_mc_evictions.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
import time
10+
from dataclasses import dataclass
11+
from enum import Enum, unique
12+
from typing import List, Optional, Tuple
13+
14+
import torch
15+
from pyre_extensions import none_throws
16+
17+
from torchrec.sparse.jagged_tensor import JaggedTensor
18+
19+
# Module-level logger named after this module; the eviction modules in this
# file use it to log the policy and configuration they were built with.
logger: logging.Logger = logging.getLogger(__name__)
20+
21+
22+
@unique
class HashZchEvictionPolicyName(Enum):
    """Names of the eviction policies selectable for a hash ZCH table.

    Every member's value is the same string as its name, and ``@unique``
    guarantees that no two members share a value.
    """

    # TTL policy: an ID expires based on when it was last seen during
    # training, using one TTL for all features.
    SINGLE_TTL_EVICTION = "SINGLE_TTL_EVICTION"

    # TTL policy: same last-seen rule, but each feature carries its own TTL.
    PER_FEATURE_TTL_EVICTION = "PER_FEATURE_TTL_EVICTION"

    # LRU policy: the least recently seen ID within the probe range is evicted.
    LRU_EVICTION = "LRU_EVICTION"
32+
33+
34+
# Configuration read by the eviction scorers defined in this file.
# NOTE(review): the class is compiled by torch.jit.script, so the field
# declarations are intentionally left exactly as written.
@torch.jit.script
@dataclass
class HashZchEvictionConfig:
    # Feature names. HashZchPerFeatureTtlScorer requires exactly one
    # per_feature_ttl entry per name listed here.
    features: List[str]
    # TTL shared by all features. HashZchSingleTtlScorer asserts it is a
    # positive integer and adds it to the current time in whole hours
    # (int(time.time() / 3600)), so it is effectively expressed in hours.
    single_ttl: Optional[int] = None
    # One TTL per feature, added to the same whole-hour timestamp by
    # HashZchPerFeatureTtlScorer.
    per_feature_ttl: Optional[List[int]] = None
40+
41+
42+
@torch.fx.wrap
def get_kernel_from_policy(
    policy_name: Optional[HashZchEvictionPolicyName],
) -> int:
    """Return the integer kernel code that corresponds to an eviction policy.

    Args:
        policy_name: the selected eviction policy, or ``None`` when no policy
            is configured.

    Returns:
        ``1`` when the policy is ``LRU_EVICTION``; ``0`` for every other
        policy and for ``None``.
    """
    # No policy configured: fall back to the default code.
    if policy_name is None:
        return 0
    if policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
        return 1
    return 0
52+
53+
54+
class HashZchEvictionScorer:
    """Base eviction scorer that produces no scores.

    Subclasses override ``gen_score`` and ``gen_threshold`` to implement an
    actual policy; this default returns an empty score tensor and ``-1``.

    Args:
        config: eviction configuration, kept on the instance for subclasses.
    """

    def __init__(self, config: HashZchEvictionConfig) -> None:
        self._config: HashZchEvictionConfig = config

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Return an empty tensor on ``device``; ``feature`` is not inspected."""
        no_scores = torch.empty((0,), device=device)
        return no_scores

    def gen_threshold(self) -> int:
        """Return ``-1``, the default threshold of the base scorer."""
        return -1
63+
64+
65+
class HashZchSingleTtlScorer(HashZchEvictionScorer):
    """Scorer that gives every ID the same expiry: current hour + ``single_ttl``.

    Time is measured in whole hours since the epoch (``int(time.time() / 3600)``).
    """

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Build one score per ID in ``feature.values()``.

        Args:
            feature: jagged tensor whose values are the input IDs.
            device: device on which the score tensor is created.

        Returns:
            An int32 tensor shaped like ``feature.values()`` filled with
            ``single_ttl + current hour``.

        Raises:
            AssertionError: if ``single_ttl`` is missing or not positive.
        """
        ttl = self._config.single_ttl
        assert (
            ttl is not None and ttl > 0
        ), "To use scorer HashZchSingleTtlScorer, a positive single_ttl is required."

        ids = feature.values()
        expiry_hour = ttl + int(time.time() / 3600)
        return torch.full_like(
            ids,
            expiry_hour,
            dtype=torch.int32,
            device=device,
        )

    def gen_threshold(self) -> int:
        """Return the current time in whole hours since the epoch."""
        now_seconds = time.time()
        return int(now_seconds / 3600)
81+
82+
83+
class HashZchPerFeatureTtlScorer(HashZchEvictionScorer):
    """Scorer that gives each ID the expiry of its feature: current hour + that feature's TTL.

    Time is measured in whole hours since the epoch (``int(time.time() / 3600)``).

    Args:
        config: eviction configuration; ``per_feature_ttl`` must be set and
            have the same length as ``features``.

    Raises:
        AssertionError: if ``per_feature_ttl`` is missing or its length does
            not match ``features``.
    """

    def __init__(self, config: HashZchEvictionConfig) -> None:
        super().__init__(config)

        ttls = self._config.per_feature_ttl
        assert ttls is not None and len(self._config.features) == len(
            ttls
        ), "To use scorer HashZchPerFeatureTtlScorer, a 1:1 mapping between features and per_feature_ttl is required."

        # Kept as an int32 tensor so it can be expanded per ID in gen_score.
        self._per_feature_ttl = torch.IntTensor(ttls)

    def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
        """Build one score per ID by repeating each feature's TTL.

        Args:
            feature: jagged tensor whose ``weights()`` hold the feature split,
                i.e. how many times each feature's TTL is repeated.
            device: device the resulting scores are moved to.

        Returns:
            A tensor of ``per-feature TTL + current hour``, one entry per
            repeated TTL, on ``device``.

        Raises:
            AssertionError: if the feature split and the TTL tensor differ in length.
        """
        feature_split = feature.weights()
        assert feature_split.size(0) == self._per_feature_ttl.size(0)

        # Expand TTL i into feature_split[i] consecutive entries.
        per_id_ttl = self._per_feature_ttl.repeat_interleave(feature_split)
        current_hour = int(time.time() / 3600)
        scores = per_id_ttl + current_hour

        return scores.to(device=device)

    def gen_threshold(self) -> int:
        """Return the current time in whole hours since the epoch."""
        now_seconds = time.time()
        return int(now_seconds / 3600)
108+
109+
110+
@torch.fx.wrap
def get_eviction_scorer(
    policy_name: HashZchEvictionPolicyName, config: HashZchEvictionConfig
) -> HashZchEvictionScorer:
    """Create the scorer that implements the given eviction policy.

    ``policy_name`` is compared against ``HashZchEvictionPolicyName`` members,
    so it must be an enum member: a plain string never compares equal to a
    member and would fall through to the no-op base scorer.

    Args:
        policy_name: the eviction policy to build a scorer for.
        config: eviction configuration handed to the scorer.

    Returns:
        ``HashZchSingleTtlScorer`` for ``SINGLE_TTL_EVICTION`` and
        ``LRU_EVICTION``, ``HashZchPerFeatureTtlScorer`` for
        ``PER_FEATURE_TTL_EVICTION``, and the base ``HashZchEvictionScorer``
        for anything else.
    """
    # LRU reuses the single-TTL scorer, exactly as the separate branches did.
    if policy_name in (
        HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
        HashZchEvictionPolicyName.LRU_EVICTION,
    ):
        return HashZchSingleTtlScorer(config)
    if policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION:
        return HashZchPerFeatureTtlScorer(config)
    return HashZchEvictionScorer(config)
122+
123+
124+
class HashZchThresholdEvictionModule(torch.nn.Module):
    """
    Computes eviction scores for input IDs. A scorer matching the selected
    eviction policy is created at construction time; each forward call returns
    that scorer's per-ID scores together with its eviction threshold. The kernel
    will use this score to make eviction decisions.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.
        config: a config that contains information needed to run the eviction policy.

    Example::
        module = HashZchThresholdEvictionModule(policy_name, config)
        score, threshold = module(feature, device)
    """

    _eviction_scorer: HashZchEvictionScorer

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
        config: HashZchEvictionConfig,
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name
        self._config: HashZchEvictionConfig = config
        self._eviction_scorer = get_eviction_scorer(
            policy_name=policy_name,
            config=config,
        )

        logger.info(
            f"HashZchThresholdEvictionModule: {self._policy_name=}, {self._config=}"
        )

    def forward(
        self, feature: JaggedTensor, device: torch.device
    ) -> Tuple[torch.Tensor, int]:
        """
        Args:
            feature: a jagged tensor that contains the input IDs, and their lengths and
                weights (feature split).
            device: device of the tensor.

        Returns:
            a tensor that contains the eviction score for each ID, plus an eviction threshold.
        """
        scorer = self._eviction_scorer
        score = scorer.gen_score(feature, device)
        threshold = scorer.gen_threshold()
        return score, threshold
175+
176+
177+
class HashZchOptEvictionModule(torch.nn.Module):
    """
    Eviction module that produces neither a score nor a real threshold: forward
    always returns ``(None, -1)``. The policy name is stored on the module but
    is not consulted by forward.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.

    Example::
        module = HashZchOptEvictionModule(policy_name=policy_name)
        score, threshold = module(feature, device)  # (None, -1)
    """

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name

    def forward(self, feature: JaggedTensor, device: torch.device) -> Tuple[None, int]:
        """
        No-op: scoring does not apply to this module.

        Args:
            feature: unused.
            device: unused.

        Returns:
            None, -1
        """
        return (None, -1)
203+
204+
205+
@torch.fx.wrap
def get_eviction_module(
    policy_name: HashZchEvictionPolicyName, config: Optional[HashZchEvictionConfig]
) -> torch.nn.Module:
    """Build the eviction module for the given policy.

    Args:
        policy_name: the eviction policy to use.
        config: eviction configuration; must not be ``None`` for the
            threshold-based policies (``none_throws`` is applied to it).

    Returns:
        ``HashZchThresholdEvictionModule`` for ``SINGLE_TTL_EVICTION``,
        ``PER_FEATURE_TTL_EVICTION`` and ``LRU_EVICTION``; otherwise
        ``HashZchOptEvictionModule``.
    """
    threshold_policies = (
        HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
        HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION,
        HashZchEvictionPolicyName.LRU_EVICTION,
    )
    if policy_name not in threshold_policies:
        return HashZchOptEvictionModule(policy_name)
    return HashZchThresholdEvictionModule(policy_name, none_throws(config))
217+
218+
219+
class HashZchEvictionModule(torch.nn.Module):
    """
    Entry-point eviction module: builds the policy-specific eviction module via
    ``get_eviction_module`` and forwards every call to it together with the
    configured device.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.
        device: device of the tensor.
        config: an optional config required if threshold based eviction is selected.

    Example::
        module = HashZchEvictionModule(
            policy_name=HashZchEvictionPolicyName.LRU_EVICTION,
            device=torch.device("cpu"),
            config=HashZchEvictionConfig(features=["f1"], single_ttl=24),
        )
        score, threshold = module(feature)
    """

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
        device: torch.device,
        config: Optional[HashZchEvictionConfig],
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name
        self._device: torch.device = device
        self._eviction_module: torch.nn.Module = get_eviction_module(
            policy_name, config
        )

        logger.info(f"HashZchEvictionModule: {self._policy_name=}, {self._device=}")

    def forward(self, feature: JaggedTensor) -> Tuple[Optional[torch.Tensor], int]:
        """
        Args:
            feature: a jagged tensor that contains the input IDs, and their lengths and
                weights (feature split).

        Returns:
            For threshold eviction, a tensor that contains the eviction score for each ID, plus an eviction threshold. Otherwise None and -1.
        """
        delegate = self._eviction_module
        return delegate(feature, self._device)

0 commit comments

Comments
 (0)