
Enhance ParameterConstraint configuration in the benchmarking script #3082

Closed · wants to merge 1 commit

39 changes: 36 additions & 3 deletions torchrec/distributed/benchmark/benchmark_train_sparsenn.py
@@ -11,7 +11,7 @@

import copy

- from dataclasses import dataclass
+ from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import click
@@ -26,6 +26,7 @@
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import ParameterConstraints

@@ -80,6 +81,9 @@ class RunOptions:
planner_type (str): Type of sharding planner to use. Options are:
- "embedding": EmbeddingShardingPlanner (default)
- "hetero": HeteroEmbeddingShardingPlanner
pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
This is the average number of values each sample has for the feature.
num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
"""

world_size: int = 2
@@ -89,6 +93,8 @@
input_type: str = "kjt"
profile: str = ""
planner_type: str = "embedding"
pooling_factors: Optional[List[float]] = None
num_poolings: Optional[List[float]] = None


@dataclass
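For orientation, a minimal sketch of how the two new RunOptions fields could be populated; the values below are illustrative and not taken from this PR. Per the length check added to _generate_planner further down, each list carries one entry per benchmark batch.

# Illustrative values only; the field names come from the RunOptions diff above.
run_option = RunOptions(
    world_size=2,
    num_batches=5,
    pooling_factors=[3.0, 3.0, 10.0, 10.0, 10.0],  # average values per sample, one entry per batch
    num_poolings=[1.0] * 5,                        # pooling count, one entry per batch
)
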
@@ -111,7 +117,7 @@ class EmbeddingTablesConfig:

num_unweighted_features: int = 100
num_weighted_features: int = 100
- embedding_feature_dim: int = 512
+ embedding_feature_dim: int = 128

def generate_tables(
self,
@@ -286,17 +292,36 @@ def _generate_planner(
tables: Optional[List[EmbeddingBagConfig]],
weighted_tables: Optional[List[EmbeddingBagConfig]],
sharding_type: ShardingType,
- compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
+ compute_kernel: EmbeddingComputeKernel,
num_batches: int,
batch_size: int,
pooling_factors: Optional[List[float]],
num_poolings: Optional[List[float]],
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
# Create parameter constraints for tables
constraints = {}

if pooling_factors is None:
pooling_factors = [POOLING_FACTOR] * num_batches

if num_poolings is None:
num_poolings = [NUM_POOLINGS] * num_batches

batch_sizes = [batch_size] * num_batches

assert (
len(pooling_factors) == num_batches and len(num_poolings) == num_batches
), "The length of pooling_factors and num_poolings must match the number of batches."

if tables is not None:
for table in tables:
constraints[table.name] = ParameterConstraints(
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
)

if weighted_tables is not None:
@@ -305,6 +330,10 @@
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
is_weighted=True,
)

if planner_type == "embedding":
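Beyond what the diff shows, the constraints dict assembled here is handed to the planner, which reads the per-table pooling_factors, num_poolings, and batch_sizes when sizing and placing shards. A hedged sketch of that downstream usage, with an illustrative table name, topology, and values:

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints

# One entry per table, mirroring what _generate_planner builds above; values are illustrative.
constraints = {
    "table_0": ParameterConstraints(
        sharding_types=["table_wise"],
        pooling_factors=[10.0],
        num_poolings=[1.0],
        batch_sizes=[8192],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
# A sharding plan is then produced via planner.plan(...) or planner.collective_plan(...).
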
@@ -413,6 +442,10 @@ def runner(
weighted_tables=weighted_tables,
sharding_type=run_option.sharding_type,
compute_kernel=run_option.compute_kernel,
num_batches=run_option.num_batches,
batch_size=input_config.batch_size,
pooling_factors=run_option.pooling_factors,
num_poolings=run_option.num_poolings,
)

sharded_model, optimizer = _generate_sharded_model_and_optimizer(
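To make the new fallback concrete: when a run leaves pooling_factors or num_poolings unset, _generate_planner expands the planner's library defaults to one entry per batch and validates the lengths. A small standalone sketch with illustrative batch numbers:

from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR

num_batches, batch_size = 5, 8192
pooling_factors = None  # run_option.pooling_factors left unset
num_poolings = None     # run_option.num_poolings left unset

if pooling_factors is None:
    pooling_factors = [POOLING_FACTOR] * num_batches
if num_poolings is None:
    num_poolings = [NUM_POOLINGS] * num_batches
batch_sizes = [batch_size] * num_batches

# Same length check the PR adds: one pooling factor and one pooling count per batch.
assert len(pooling_factors) == num_batches and len(num_poolings) == num_batches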