Skip to content

OSS TorchRec MPZCH Modules #3162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions torchrec/modules/hash_mc_evictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import time
from dataclasses import dataclass
from enum import Enum, unique
from typing import List, Optional, Tuple

import torch
from pyre_extensions import none_throws

from torchrec.sparse.jagged_tensor import JaggedTensor

logger: logging.Logger = logging.getLogger(__name__)


@unique
class HashZchEvictionPolicyName(Enum):
# eviction based on the time the ID is last seen during training,
# and a single TTL
SINGLE_TTL_EVICTION = "SINGLE_TTL_EVICTION"
# eviction based on the time the ID is last seen during training,
# and per-feature TTLs
PER_FEATURE_TTL_EVICTION = "PER_FEATURE_TTL_EVICTION"
# eviction based on least recently seen ID within the probe range
LRU_EVICTION = "LRU_EVICTION"


@torch.jit.script
@dataclass
class HashZchEvictionConfig:
features: List[str]
single_ttl: Optional[int] = None
per_feature_ttl: Optional[List[int]] = None


@torch.fx.wrap
def get_kernel_from_policy(
policy_name: Optional[HashZchEvictionPolicyName],
) -> int:
return (
1
if policy_name is not None
and policy_name == HashZchEvictionPolicyName.LRU_EVICTION
else 0
)


class HashZchEvictionScorer:
def __init__(self, config: HashZchEvictionConfig) -> None:
self._config: HashZchEvictionConfig = config

def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
return torch.empty(0, device=device)

def gen_threshold(self) -> int:
return -1


class HashZchSingleTtlScorer(HashZchEvictionScorer):
def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
assert (
self._config.single_ttl is not None and self._config.single_ttl > 0
), "To use scorer HashZchSingleTtlScorer, a positive single_ttl is required."

return torch.full_like(
feature.values(),
# pyre-ignore [58]
self._config.single_ttl + int(time.time() / 3600),
dtype=torch.int32,
device=device,
)

def gen_threshold(self) -> int:
return int(time.time() / 3600)


class HashZchPerFeatureTtlScorer(HashZchEvictionScorer):
def __init__(self, config: HashZchEvictionConfig) -> None:
super().__init__(config)

assert self._config.per_feature_ttl is not None and len(
self._config.features
) == len(
# pyre-ignore [6]
self._config.per_feature_ttl
), "To use scorer HashZchPerFeatureTtlScorer, a 1:1 mapping between features and per_feature_ttl is required."

self._per_feature_ttl = torch.IntTensor(self._config.per_feature_ttl)

def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor:
feature_split = feature.weights()
assert feature_split.size(0) == self._per_feature_ttl.size(0)

scores = self._per_feature_ttl.repeat_interleave(feature_split) + int(
time.time() / 3600
)

return scores.to(device=device)

def gen_threshold(self) -> int:
return int(time.time() / 3600)


@torch.fx.wrap
def get_eviction_scorer(
policy_name: str, config: HashZchEvictionConfig
) -> HashZchEvictionScorer:
if policy_name == HashZchEvictionPolicyName.SINGLE_TTL_EVICTION:
return HashZchSingleTtlScorer(config)
elif policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION:
return HashZchPerFeatureTtlScorer(config)
elif policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
return HashZchSingleTtlScorer(config)
else:
return HashZchEvictionScorer(config)


class HashZchThresholdEvictionModule(torch.nn.Module):
"""
This module manages the computation of eviction score for input IDs. Based on the selected
eviction policy, a scorer is initiated to generate a score for each ID. The kernel
will use this score to make eviction decisions.

Args:
policy_name: an enum value that indicates the eviction policy to use.
config: a config that contains information needed to run the eviction policy.

Example::
module = HashZchThresholdEvictionModule(...)
score = module(feature)
"""

_eviction_scorer: HashZchEvictionScorer

def __init__(
self,
policy_name: HashZchEvictionPolicyName,
config: HashZchEvictionConfig,
) -> None:
super().__init__()

self._policy_name: HashZchEvictionPolicyName = policy_name
self._config: HashZchEvictionConfig = config
self._eviction_scorer = get_eviction_scorer(
policy_name=self._policy_name,
config=self._config,
)

logger.info(
f"HashZchThresholdEvictionModule: {self._policy_name=}, {self._config=}"
)

def forward(
self, feature: JaggedTensor, device: torch.device
) -> Tuple[torch.Tensor, int]:
"""
Args:
feature: a jagged tensor that contains the input IDs, and their lengths and
weights (feature split).
device: device of the tensor.

Returns:
a tensor that contains the eviction score for each ID, plus an eviction threshold.
"""
return (
self._eviction_scorer.gen_score(feature, device),
self._eviction_scorer.gen_threshold(),
)


class HashZchOptEvictionModule(torch.nn.Module):
"""
This module manages the eviction of IDs from the ZCH table based on the selected eviction policy.
Args:
policy_name: an enum value that indicates the eviction policy to use.
Example:
module = HashZchOptEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION)
"""

def __init__(
self,
policy_name: HashZchEvictionPolicyName,
) -> None:
super().__init__()

self._policy_name: HashZchEvictionPolicyName = policy_name

def forward(self, feature: JaggedTensor, device: torch.device) -> Tuple[None, int]:
"""
Does not apply to this Eviction Policy. Returns None and -1.
Args:
feature: No op
Returns:
None, -1
"""
return None, -1


@torch.fx.wrap
def get_eviction_module(
policy_name: HashZchEvictionPolicyName, config: Optional[HashZchEvictionConfig]
) -> torch.nn.Module:
if policy_name in (
HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION,
HashZchEvictionPolicyName.LRU_EVICTION,
):
return HashZchThresholdEvictionModule(policy_name, none_throws(config))
else:
return HashZchOptEvictionModule(policy_name)


class HashZchEvictionModule(torch.nn.Module):
"""
This module manages the eviction of IDs from the ZCH table based on the selected eviction policy.
Args:
policy_name: an enum value that indicates the eviction policy to use.
device: device of the tensor.
config: an optional config required if threshold based eviction is selected.
Example:
module = HashZchEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION)
"""

def __init__(
self,
policy_name: HashZchEvictionPolicyName,
device: torch.device,
config: Optional[HashZchEvictionConfig],
) -> None:
super().__init__()

self._policy_name: HashZchEvictionPolicyName = policy_name
self._device: torch.device = device
self._eviction_module: torch.nn.Module = get_eviction_module(
self._policy_name, config
)

logger.info(f"HashZchEvictionModule: {self._policy_name=}, {self._device=}")

def forward(self, feature: JaggedTensor) -> Tuple[Optional[torch.Tensor], int]:
"""
Args:
feature: a jagged tensor that contains the input IDs, and their lengths and
weights (feature split).

Returns:
For threshold eviction, a tensor that contains the eviction score for each ID, plus an eviction threshold. Otherwise None and -1.
"""
return self._eviction_module(feature, self._device)
Loading
Loading