From 4686f274a5d14a472e9a1622f9908cae72971943 Mon Sep 17 00:00:00 2001
From: yernar
Date: Thu, 12 Jun 2025 17:25:58 -0700
Subject: [PATCH 1/3] Enhance ParameterConstraints configuration in the
 benchmarking script

Summary:
Updated the `ParameterConstraints` in the TorchRec benchmarking script to include pooling factors, number of poolings, and batch sizes. This enhancement allows for more detailed configuration of embedding tables, improving the flexibility and precision of sharding strategies in distributed training scenarios.

Differential Revision: D76440004
---
 .../benchmark/benchmark_train_sparsenn.py     | 39 +++++++++++++++++--
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
index df6ebd917..09fc5f665 100644
--- a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
+++ b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
@@ -11,7 +11,7 @@
 
 import copy
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
 
 import click
@@ -26,6 +26,7 @@
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import ParameterConstraints
 
@@ -80,6 +81,9 @@ class RunOptions:
         planner_type (str): Type of sharding planner to use. Options are:
             - "embedding": EmbeddingShardingPlanner (default)
             - "hetero": HeteroEmbeddingShardingPlanner
+        pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
+            This is the average number of values each sample has for the feature.
+        num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
     """
 
     world_size: int = 2
@@ -89,6 +93,8 @@ class RunOptions:
     input_type: str = "kjt"
     profile: str = ""
     planner_type: str = "embedding"
+    pooling_factors: Optional[List[float]] = None
+    num_poolings: Optional[List[float]] = None
 
 
 @dataclass
@@ -111,7 +117,7 @@ class EmbeddingTablesConfig:
 
     num_unweighted_features: int = 100
     num_weighted_features: int = 100
-    embedding_feature_dim: int = 512
+    embedding_feature_dim: int = 128
 
     def generate_tables(
         self,
@@ -286,17 +292,36 @@ def _generate_planner(
     tables: Optional[List[EmbeddingBagConfig]],
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
+    compute_kernel: EmbeddingComputeKernel,
+    num_batches: int,
+    batch_size: int,
+    pooling_factors: Optional[List[float]],
+    num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
     # Create parameter constraints for tables
     constraints = {}
 
+    if pooling_factors is None:
+        pooling_factors = [POOLING_FACTOR] * num_batches
+
+    if num_poolings is None:
+        num_poolings = [NUM_POOLINGS] * num_batches
+
+    batch_sizes = [batch_size] * num_batches
+
+    assert (
+        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
+    ), "The length of pooling_factors and num_poolings must match the number of batches."
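+    # NOTE: the defaults above come from the planner constants, and the same
+    # per-batch pooling_factors / num_poolings / batch_sizes lists are applied
+    # uniformly to every table's ParameterConstraints below.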
+
     if tables is not None:
         for table in tables:
             constraints[table.name] = ParameterConstraints(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
             )
 
     if weighted_tables is not None:
@@ -305,6 +330,10 @@ def _generate_planner(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
+                is_weighted=True,
             )
 
     if planner_type == "embedding":
@@ -413,6 +442,10 @@ def runner(
             weighted_tables=weighted_tables,
             sharding_type=run_option.sharding_type,
             compute_kernel=run_option.compute_kernel,
+            num_batches=run_option.num_batches,
+            batch_size=input_config.batch_size,
+            pooling_factors=run_option.pooling_factors,
+            num_poolings=run_option.num_poolings,
         )
 
         sharded_model, optimizer = _generate_sharded_model_and_optimizer(

From 7d6306ab8903a724bddceab205fc9eb7723609f9 Mon Sep 17 00:00:00 2001
From: yernar
Date: Thu, 12 Jun 2025 17:25:58 -0700
Subject: [PATCH 2/3] Added a model config supporting SparseNN, TowerSparseNN,
 and TowerCollectionSparseNN models.

Summary:
Refactored the training benchmarking by moving the generative helper functions into a separate util file, and added a model configuration that supports SparseNN, TowerSparseNN, and TowerCollectionSparseNN models. Future commits will add support for DeepFM and DLRM models.

Differential Revision: D76539867
---
 .../benchmark/benchmark_pipeline_utils.py     | 391 ++++++++++++++
 .../benchmark/benchmark_train_pipeline.py     | 258 +++++++++
 .../benchmark/benchmark_train_sparsenn.py     | 504 ------------------
 3 files changed, 649 insertions(+), 504 deletions(-)
 create mode 100644 torchrec/distributed/benchmark/benchmark_pipeline_utils.py
 create mode 100644 torchrec/distributed/benchmark/benchmark_train_pipeline.py
 delete mode 100644 torchrec/distributed/benchmark/benchmark_train_sparsenn.py

diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
new file mode 100644
index 000000000..9ecfaa35a
--- /dev/null
+++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
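+
+# Shared helpers for benchmark_train_pipeline.py: embedding-table, model,
+# pipeline, planner, and input-data generation moved out of the old
+# benchmark_train_sparsenn.py script.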
+ +# pyre-strict + +import copy +from dataclasses import dataclass +from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed as dist +from torch import nn, optim +from torch.optim import Optimizer +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterConstraints +from torchrec.distributed.test_utils.test_input import ModelInput +from torchrec.distributed.test_utils.test_model import ( + TestEBCSharder, + TestOverArchLarge, + TestSparseNN, + TestTowerCollectionSparseNN, + TestTowerSparseNN, +) +from torchrec.distributed.train_pipeline import ( + TrainPipelineBase, + TrainPipelineFusedSparseDist, + TrainPipelineSparseDist, +) +from torchrec.distributed.train_pipeline.train_pipelines import ( + PrefetchTrainPipelineSparseDist, + TrainPipelineSemiSync, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +@dataclass +class ModelConfig: + model_name: str = "test_sparsenn" + + batch_size: int = 8192 + num_float_features: int = 10 + feature_pooling_avg: int = 10 + use_offsets: bool = False + dev_str: str = "" + long_kjt_indices: bool = True + long_kjt_offsets: bool = True + long_kjt_lengths: bool = True + pin_memory: bool = True + + def generate_model( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + dense_device: torch.device, + ) -> nn.Module: + if self.model_name == "test_sparsenn": + return TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=dense_device, + sparse_device=torch.device("meta"), + over_arch_clazz=TestOverArchLarge, + ) + elif self.model_name == "test_tower_sparsenn": + return TestTowerSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=dense_device, + sparse_device=torch.device("meta"), + num_float_features=self.num_float_features, + ) + elif self.model_name == "test_tower_collection_sparsenn": + return TestTowerCollectionSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=dense_device, + sparse_device=torch.device("meta"), + num_float_features=self.num_float_features, + ) + else: + raise RuntimeError(f"Unknown model name: {self.model_name}") + + +def generate_tables( + num_unweighted_features: int, + num_weighted_features: int, + embedding_feature_dim: int, +) -> Tuple[ + List[EmbeddingBagConfig], + List[EmbeddingBagConfig], +]: + """ + Generate embedding bag configurations for both unweighted and weighted features. + + This function creates two lists of EmbeddingBagConfig objects: + 1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}" + 2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}" + + For both types, the number of embeddings scales with the feature index, + calculated as max(i + 1, 100) * 1000. + + Args: + num_unweighted_features (int): Number of unweighted features to generate. + num_weighted_features (int): Number of weighted features to generate. + embedding_feature_dim (int): Dimension of the embedding vectors. 
+ + Returns: + Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing + two lists - the first for unweighted embedding tables and the second for + weighted embedding tables. + """ + tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_unweighted_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + return tables, weighted_tables + + +def generate_pipeline( + pipeline_type: str, + emb_lookup_stream: str, + model: nn.Module, + opt: torch.optim.Optimizer, + device: torch.device, +) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: + """ + Generate a training pipeline instance based on the configuration. + + This function creates and returns the appropriate training pipeline object + based on the pipeline type specified. Different pipeline types are optimized + for different training scenarios. + + Args: + pipeline_type (str): The type of training pipeline to use. Options include: + - "base": Basic training pipeline + - "sparse": Pipeline optimized for sparse operations + - "fused": Pipeline with fused sparse distribution + - "semi": Semi-synchronous training pipeline + - "prefetch": Pipeline with prefetching for sparse distribution + emb_lookup_stream (str): The stream to use for embedding lookups. + Only used by certain pipeline types (e.g., "fused"). + model (nn.Module): The model to be trained. + opt (torch.optim.Optimizer): The optimizer to use for training. + device (torch.device): The device to run the training on. + + Returns: + Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the + appropriate training pipeline class based on the configuration. + + Raises: + RuntimeError: If an unknown pipeline type is specified. + """ + + _pipeline_cls: Dict[ + str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] + ] = { + "base": TrainPipelineBase, + "sparse": TrainPipelineSparseDist, + "fused": TrainPipelineFusedSparseDist, + "semi": TrainPipelineSemiSync, + "prefetch": PrefetchTrainPipelineSparseDist, + } + + if pipeline_type == "semi": + return TrainPipelineSemiSync( + model=model, optimizer=opt, device=device, start_batch=0 + ) + elif pipeline_type == "fused": + return TrainPipelineFusedSparseDist( + model=model, + optimizer=opt, + device=device, + emb_lookup_stream=emb_lookup_stream, + ) + elif pipeline_type in _pipeline_cls: + Pipeline = _pipeline_cls[pipeline_type] + return Pipeline(model=model, optimizer=opt, device=device) + else: + raise RuntimeError(f"unknown pipeline option {pipeline_type}") + + +def generate_planner( + planner_type: str, + topology: Topology, + tables: Optional[List[EmbeddingBagConfig]], + weighted_tables: Optional[List[EmbeddingBagConfig]], + sharding_type: ShardingType, + compute_kernel: EmbeddingComputeKernel, + num_batches: int, + batch_size: int, + pooling_factors: Optional[List[float]], + num_poolings: Optional[List[float]], +) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: + """ + Generate an embedding sharding planner based on the specified configuration. 
+ + Args: + planner_type: Type of planner to use ("embedding" or "hetero") + topology: Network topology for distributed training + tables: List of unweighted embedding tables + weighted_tables: List of weighted embedding tables + sharding_type: Strategy for sharding embedding tables + compute_kernel: Compute kernel to use for embedding tables + num_batches: Number of batches to process + batch_size: Size of each batch + pooling_factors: Pooling factors for each feature of the table + num_poolings: Number of poolings for each feature of the table + + Returns: + An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner + + Raises: + RuntimeError: If an unknown planner type is specified + """ + # Create parameter constraints for tables + constraints = {} + + if pooling_factors is None: + pooling_factors = [POOLING_FACTOR] * num_batches + + if num_poolings is None: + num_poolings = [NUM_POOLINGS] * num_batches + + batch_sizes = [batch_size] * num_batches + + assert ( + len(pooling_factors) == num_batches and len(num_poolings) == num_batches + ), "The length of pooling_factors and num_poolings must match the number of batches." + + if tables is not None: + for table in tables: + constraints[table.name] = ParameterConstraints( + sharding_types=[sharding_type.value], + compute_kernels=[compute_kernel.value], + device_group="cuda", + pooling_factors=pooling_factors, + num_poolings=num_poolings, + batch_sizes=batch_sizes, + ) + + if weighted_tables is not None: + for table in weighted_tables: + constraints[table.name] = ParameterConstraints( + sharding_types=[sharding_type.value], + compute_kernels=[compute_kernel.value], + device_group="cuda", + pooling_factors=pooling_factors, + num_poolings=num_poolings, + batch_sizes=batch_sizes, + is_weighted=True, + ) + + if planner_type == "embedding": + return EmbeddingShardingPlanner( + topology=topology, + constraints=constraints if constraints else None, + ) + elif planner_type == "hetero": + topology_groups = {"cuda": topology} + return HeteroEmbeddingShardingPlanner( + topology_groups=topology_groups, + constraints=constraints if constraints else None, + ) + else: + raise RuntimeError(f"Unknown planner type: {planner_type}") + + +def generate_sharded_model_and_optimizer( + model: nn.Module, + sharding_type: str, + kernel_type: str, + pg: dist.ProcessGroup, + device: torch.device, + fused_params: Optional[Dict[str, Any]] = None, + planner: Optional[ + Union[ + EmbeddingShardingPlanner, + HeteroEmbeddingShardingPlanner, + ] + ] = None, +) -> Tuple[nn.Module, Optimizer]: + # Ensure fused_params is always a dictionary + fused_params_dict = {} if fused_params is None else fused_params + + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params_dict, + ) + sharders = [cast(ModuleSharder[nn.Module], sharder)] + + # Use planner if provided + plan = None + if planner is not None: + if pg is not None: + plan = planner.collective_plan(model, sharders, pg) + else: + plan = planner.plan(model, sharders) + + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(pg), + init_data_parallel=True, + device=device, + sharders=sharders, + plan=plan, + ).to(device) + optimizer = optim.SGD( + [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ], + lr=0.1, + ) + return sharded_model, optimizer + + +def generate_data( + model_class_name: str, + tables: List[EmbeddingBagConfig], + weighted_tables: 
List[EmbeddingBagConfig],
+    model_config: ModelConfig,
+    num_batches: int,
+) -> List[ModelInput]:
+    """
+    Generate model input data for benchmarking.
+
+    Args:
+        tables: List of unweighted embedding tables
+        weighted_tables: List of weighted embedding tables
+        model_config: Configuration for model generation
+        num_batches: Number of batches to generate
+
+    Returns:
+        A list of ModelInput objects representing the generated batches
+    """
+    device = torch.device(model_config.dev_str) if model_config.dev_str else None
+
+    if (
+        model_class_name == "TestSparseNN"
+        or model_class_name == "TestTowerSparseNN"
+        or model_class_name == "TestTowerCollectionSparseNN"
+    ):
+        return [
+            ModelInput.generate(
+                batch_size=model_config.batch_size,
+                tables=tables,
+                weighted_tables=weighted_tables,
+                num_float_features=model_config.num_float_features,
+                pooling_avg=model_config.feature_pooling_avg,
+                use_offsets=model_config.use_offsets,
+                device=device,
+                indices_dtype=(
+                    torch.int64 if model_config.long_kjt_indices else torch.int32
+                ),
+                offsets_dtype=(
+                    torch.int64 if model_config.long_kjt_offsets else torch.int32
+                ),
+                lengths_dtype=(
+                    torch.int64 if model_config.long_kjt_lengths else torch.int32
+                ),
+                pin_memory=model_config.pin_memory,
+            )
+            for _ in range(num_batches)
+        ]
+    else:
+        raise RuntimeError(f"Unknown model class name: {model_class_name}")
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
new file mode 100644
index 000000000..25a548696
--- /dev/null
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+#!/usr/bin/env python3
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from torch import nn
+from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
+    generate_data,
+    generate_pipeline,
+    generate_planner,
+    generate_sharded_model_and_optimizer,
+    generate_tables,
+    ModelConfig,
+)
+from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
+from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.planner import Topology
+
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    run_multi_process_func,
+)
+from torchrec.distributed.test_utils.test_input import ModelInput
+from torchrec.distributed.train_pipeline import TrainPipeline
+from torchrec.distributed.types import ShardingType
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+
+
+@dataclass
+class RunOptions:
+    """
+    Configuration options for running sparse neural network benchmarks.
+
+    This class defines the parameters that control how the benchmark is executed,
+    including distributed training settings, batch configuration, and profiling options.
+
+    Args:
+        world_size (int): Number of processes/GPUs to use for distributed training.
+            Default is 2.
+        num_batches (int): Number of batches to process during the benchmark.
+            Default is 10.
+        sharding_type (ShardingType): Strategy for sharding embedding tables across devices.
+ Default is ShardingType.TABLE_WISE (entire tables are placed on single devices). + compute_kernel (EmbeddingComputeKernel): Compute kernel to use for embedding tables. + Default is EmbeddingComputeKernel.FUSED. + input_type (str): Type of input format to use for the model. + Default is "kjt" (KeyedJaggedTensor). + profile (str): Directory to save profiling results. If empty, profiling is disabled. + Default is "" (disabled). + planner_type (str): Type of sharding planner to use. Options are: + - "embedding": EmbeddingShardingPlanner (default) + - "hetero": HeteroEmbeddingShardingPlanner + pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table. + This is the average number of values each sample has for the feature. + num_poolings (Optional[List[float]]): Number of poolings for each feature of the table. + """ + + world_size: int = 2 + num_batches: int = 10 + sharding_type: ShardingType = ShardingType.TABLE_WISE + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED + input_type: str = "kjt" + profile: str = "" + planner_type: str = "embedding" + pooling_factors: Optional[List[float]] = None + num_poolings: Optional[List[float]] = None + + +@dataclass +class EmbeddingTablesConfig: + """ + Configuration for embedding tables. + + This class defines the parameters for generating embedding tables with both weighted + and unweighted features. + + Args: + num_unweighted_features (int): Number of unweighted features to generate. + Default is 100. + num_weighted_features (int): Number of weighted features to generate. + Default is 100. + embedding_feature_dim (int): Dimension of the embedding vectors. + Default is 128. + """ + + num_unweighted_features: int = 100 + num_weighted_features: int = 100 + embedding_feature_dim: int = 128 + + +@dataclass +class PipelineConfig: + """ + Configuration for training pipelines. + + This class defines the parameters for configuring the training pipeline. + + Args: + pipeline (str): The type of training pipeline to use. Options include: + - "base": Basic training pipeline + - "sparse": Pipeline optimized for sparse operations + - "fused": Pipeline with fused sparse distribution + - "semi": Semi-synchronous training pipeline + - "prefetch": Pipeline with prefetching for sparse distribution + Default is "base". + emb_lookup_stream (str): The stream to use for embedding lookups. + Only used by certain pipeline types (e.g., "fused"). + Default is "data_dist". 
+ """ + + pipeline: str = "base" + emb_lookup_stream: str = "data_dist" + + +@cmd_conf +def main( + run_option: RunOptions, + table_config: EmbeddingTablesConfig, + model_config: ModelConfig, + pipeline_config: PipelineConfig, +) -> None: + tables, weighted_tables = generate_tables( + num_unweighted_features=table_config.num_unweighted_features, + num_weighted_features=table_config.num_weighted_features, + embedding_feature_dim=table_config.embedding_feature_dim, + ) + + # launch trainers + run_multi_process_func( + func=runner, + world_size=run_option.world_size, + tables=tables, + weighted_tables=weighted_tables, + run_option=run_option, + model_config=model_config, + pipeline_config=pipeline_config, + ) + + +def runner( + rank: int, + world_size: int, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + run_option: RunOptions, + model_config: ModelConfig, + pipeline_config: PipelineConfig, +) -> None: + # Ensure GPUs are available and we have enough of them + assert ( + torch.cuda.is_available() and torch.cuda.device_count() >= world_size + ), "CUDA not available or insufficient GPUs for the requested world_size" + + torch.autograd.set_detect_anomaly(True) + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl", + use_deterministic_algorithms=False, + ) as ctx: + unsharded_model = model_config.generate_model( + tables=tables, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + + # Create a topology for sharding + topology = Topology( + local_world_size=get_local_size(world_size), + world_size=world_size, + compute_device=ctx.device.type, + ) + + # Create a planner for sharding based on the specified type + planner = generate_planner( + planner_type=run_option.planner_type, + topology=topology, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=run_option.sharding_type, + compute_kernel=run_option.compute_kernel, + num_batches=run_option.num_batches, + batch_size=model_config.batch_size, + pooling_factors=run_option.pooling_factors, + num_poolings=run_option.num_poolings, + ) + bench_inputs = generate_data( + model_class_name=unsharded_model.__class__.__name__, + tables=tables, + weighted_tables=weighted_tables, + model_config=model_config, + num_batches=run_option.num_batches, + ) + + sharded_model, optimizer = generate_sharded_model_and_optimizer( + model=unsharded_model, + sharding_type=run_option.sharding_type.value, + kernel_type=run_option.compute_kernel.value, + # pyre-ignore + pg=ctx.pg, + device=ctx.device, + fused_params={ + "optimizer": EmbOptimType.EXACT_ADAGRAD, + "learning_rate": 0.1, + }, + planner=planner, + ) + pipeline = generate_pipeline( + pipeline_type=pipeline_config.pipeline, + emb_lookup_stream=pipeline_config.emb_lookup_stream, + model=sharded_model, + opt=optimizer, + device=ctx.device, + ) + pipeline.progress(iter(bench_inputs)) + + def _func_to_benchmark( + bench_inputs: List[ModelInput], + model: nn.Module, + pipeline: TrainPipeline, + ) -> None: + dataloader = iter(bench_inputs) + while True: + try: + pipeline.progress(dataloader) + except StopIteration: + break + + result = benchmark_func( + name=type(pipeline).__name__, + bench_inputs=bench_inputs, # pyre-ignore + prof_inputs=bench_inputs, # pyre-ignore + num_benchmarks=5, + num_profiles=2, + profile_dir=run_option.profile, + world_size=run_option.world_size, + func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline}, + rank=rank, + ) + if rank == 0: + print(result) 
+ + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py deleted file mode 100644 index 09fc5f665..000000000 --- a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py +++ /dev/null @@ -1,504 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -#!/usr/bin/env python3 - -import copy - -from dataclasses import dataclass, field -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union - -import click - -import torch -import torch.distributed as dist -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from torch import nn, optim -from torch.optim import Optimizer -from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf -from torchrec.distributed.comm import get_local_size -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR -from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner -from torchrec.distributed.planner.types import ParameterConstraints - -from torchrec.distributed.test_utils.multi_process import ( - MultiProcessContext, - run_multi_process_func, -) -from torchrec.distributed.test_utils.test_input import ( - ModelInput, - TestSparseNNInputConfig, -) -from torchrec.distributed.test_utils.test_model import ( - TestEBCSharder, - TestOverArchLarge, - TestSparseNN, -) -from torchrec.distributed.train_pipeline import ( - TrainPipeline, - TrainPipelineBase, - TrainPipelineFusedSparseDist, - TrainPipelineSparseDist, -) -from torchrec.distributed.train_pipeline.train_pipelines import ( - PrefetchTrainPipelineSparseDist, - TrainPipelineSemiSync, -) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType -from torchrec.modules.embedding_configs import EmbeddingBagConfig - - -@dataclass -class RunOptions: - """ - Configuration options for running sparse neural network benchmarks. - - This class defines the parameters that control how the benchmark is executed, - including distributed training settings, batch configuration, and profiling options. - - Args: - world_size (int): Number of processes/GPUs to use for distributed training. - Default is 2. - num_batches (int): Number of batches to process during the benchmark. - Default is 10. - sharding_type (ShardingType): Strategy for sharding embedding tables across devices. - Default is ShardingType.TABLE_WISE (entire tables are placed on single devices). - compute_kernel (EmbeddingComputeKernel): Compute kernel to use for embedding tables. - Default is EmbeddingComputeKernel.FUSED. - input_type (str): Type of input format to use for the model. - Default is "kjt" (KeyedJaggedTensor). - profile (str): Directory to save profiling results. If empty, profiling is disabled. - Default is "" (disabled). - planner_type (str): Type of sharding planner to use. Options are: - - "embedding": EmbeddingShardingPlanner (default) - - "hetero": HeteroEmbeddingShardingPlanner - pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table. 
- This is the average number of values each sample has for the feature. - num_poolings (Optional[List[float]]): Number of poolings for each feature of the table. - """ - - world_size: int = 2 - num_batches: int = 10 - sharding_type: ShardingType = ShardingType.TABLE_WISE - compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED - input_type: str = "kjt" - profile: str = "" - planner_type: str = "embedding" - pooling_factors: Optional[List[float]] = None - num_poolings: Optional[List[float]] = None - - -@dataclass -class EmbeddingTablesConfig: - """ - Configuration for embedding tables used in sparse neural network benchmarks. - - This class defines the parameters for generating embedding tables with both weighted - and unweighted features. It provides a method to generate the actual embedding bag - configurations that can be used to create embedding tables. - - Args: - num_unweighted_features (int): Number of unweighted features to generate. - Default is 100. - num_weighted_features (int): Number of weighted features to generate. - Default is 100. - embedding_feature_dim (int): Dimension of the embedding vectors. - Default is 512. - """ - - num_unweighted_features: int = 100 - num_weighted_features: int = 100 - embedding_feature_dim: int = 128 - - def generate_tables( - self, - ) -> Tuple[ - List[EmbeddingBagConfig], - List[EmbeddingBagConfig], - ]: - """ - Generate embedding bag configurations for both unweighted and weighted features. - - This method creates two lists of EmbeddingBagConfig objects: - 1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}" - 2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}" - - For both types, the number of embeddings scales with the feature index, - calculated as max(i + 1, 100) * 1000. - - Returns: - Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing - two lists - the first for unweighted embedding tables and the second for - weighted embedding tables. - """ - tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=self.embedding_feature_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(self.num_unweighted_features) - ] - weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=self.embedding_feature_dim, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(self.num_weighted_features) - ] - return tables, weighted_tables - - -@dataclass -class PipelineConfig: - """ - Configuration for training pipelines used in sparse neural network benchmarks. - - This class defines the parameters for configuring the training pipeline and provides - a method to generate the appropriate pipeline instance based on the configuration. - - Args: - pipeline (str): The type of training pipeline to use. Options include: - - "base": Basic training pipeline - - "sparse": Pipeline optimized for sparse operations - - "fused": Pipeline with fused sparse distribution - - "semi": Semi-synchronous training pipeline - - "prefetch": Pipeline with prefetching for sparse distribution - Default is "base". - emb_lookup_stream (str): The stream to use for embedding lookups. - Only used by certain pipeline types (e.g., "fused"). - Default is "data_dist". 
- """ - - pipeline: str = "base" - emb_lookup_stream: str = "data_dist" - - def generate_pipeline( - self, model: nn.Module, opt: torch.optim.Optimizer, device: torch.device - ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: - """ - Generate a training pipeline instance based on the configuration. - - This method creates and returns the appropriate training pipeline object - based on the pipeline type specified in the configuration. Different - pipeline types are optimized for different training scenarios. - - Args: - model (nn.Module): The model to be trained. - opt (torch.optim.Optimizer): The optimizer to use for training. - device (torch.device): The device to run the training on. - - Returns: - Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the - appropriate training pipeline class based on the configuration. - - Raises: - RuntimeError: If an unknown pipeline type is specified. - """ - _pipeline_cls: Dict[ - str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] - ] = { - "base": TrainPipelineBase, - "sparse": TrainPipelineSparseDist, - "fused": TrainPipelineFusedSparseDist, - "semi": TrainPipelineSemiSync, - "prefetch": PrefetchTrainPipelineSparseDist, - } - - if self.pipeline == "semi": - return TrainPipelineSemiSync( - model=model, optimizer=opt, device=device, start_batch=0 - ) - elif self.pipeline == "fused": - return TrainPipelineFusedSparseDist( - model=model, - optimizer=opt, - device=device, - emb_lookup_stream=self.emb_lookup_stream, - ) - elif self.pipeline in _pipeline_cls: - Pipeline = _pipeline_cls[self.pipeline] - return Pipeline(model=model, optimizer=opt, device=device) - else: - raise RuntimeError(f"unknown pipeline option {self.pipeline}") - - -@cmd_conf -def main( - run_option: RunOptions, - table_config: EmbeddingTablesConfig, - input_config: TestSparseNNInputConfig, - pipeline_config: PipelineConfig, -) -> None: - # sparse table config is available on each trainer - tables, weighted_tables = table_config.generate_tables() - - # launch trainers - run_multi_process_func( - func=runner, - world_size=run_option.world_size, - tables=tables, - weighted_tables=weighted_tables, - run_option=run_option, - input_config=input_config, - pipeline_config=pipeline_config, - ) - - -def _generate_data( - tables: List[EmbeddingBagConfig], - weighted_tables: List[EmbeddingBagConfig], - input_config: TestSparseNNInputConfig, - num_batches: int, -) -> List[ModelInput]: - return [ - input_config.generate_model_input( - tables=tables, - weighted_tables=weighted_tables, - ) - for _ in range(num_batches) - ] - - -def _generate_model( - tables: List[EmbeddingBagConfig], - weighted_tables: List[EmbeddingBagConfig], - dense_device: torch.device, -) -> nn.Module: - return TestSparseNN( - tables=tables, - weighted_tables=weighted_tables, - dense_device=dense_device, - sparse_device=torch.device("meta"), - over_arch_clazz=TestOverArchLarge, - ) - - -def _generate_planner( - planner_type: str, - topology: Topology, - tables: Optional[List[EmbeddingBagConfig]], - weighted_tables: Optional[List[EmbeddingBagConfig]], - sharding_type: ShardingType, - compute_kernel: EmbeddingComputeKernel, - num_batches: int, - batch_size: int, - pooling_factors: Optional[List[float]], - num_poolings: Optional[List[float]], -) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: - # Create parameter constraints for tables - constraints = {} - - if pooling_factors is None: - pooling_factors = [POOLING_FACTOR] * num_batches - - if num_poolings is None: - num_poolings = 
[NUM_POOLINGS] * num_batches - - batch_sizes = [batch_size] * num_batches - - assert ( - len(pooling_factors) == num_batches and len(num_poolings) == num_batches - ), "The length of pooling_factors and num_poolings must match the number of batches." - - if tables is not None: - for table in tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - ) - - if weighted_tables is not None: - for table in weighted_tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - is_weighted=True, - ) - - if planner_type == "embedding": - return EmbeddingShardingPlanner( - topology=topology, - constraints=constraints if constraints else None, - ) - elif planner_type == "hetero": - topology_groups = {"cuda": topology} - return HeteroEmbeddingShardingPlanner( - topology_groups=topology_groups, - constraints=constraints if constraints else None, - ) - else: - raise RuntimeError(f"Unknown planner type: {planner_type}") - - -def _generate_sharded_model_and_optimizer( - model: nn.Module, - sharding_type: str, - kernel_type: str, - pg: dist.ProcessGroup, - device: torch.device, - fused_params: Optional[Dict[str, Any]] = None, - planner: Optional[ - Union[ - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, - ] - ] = None, -) -> Tuple[nn.Module, Optimizer]: - sharder = TestEBCSharder( - sharding_type=sharding_type, - kernel_type=kernel_type, - fused_params=fused_params, - ) - - sharders = [cast(ModuleSharder[nn.Module], sharder)] - - # Use planner if provided - plan = None - if planner is not None: - if pg is not None: - plan = planner.collective_plan(model, sharders, pg) - else: - plan = planner.plan(model, sharders) - - sharded_model = DistributedModelParallel( - module=copy.deepcopy(model), - env=ShardingEnv.from_process_group(pg), - init_data_parallel=True, - device=device, - sharders=sharders, - plan=plan, - ).to(device) - optimizer = optim.SGD( - [ - param - for name, param in sharded_model.named_parameters() - if "sparse" not in name - ], - lr=0.1, - ) - return sharded_model, optimizer - - -def runner( - rank: int, - world_size: int, - tables: List[EmbeddingBagConfig], - weighted_tables: List[EmbeddingBagConfig], - run_option: RunOptions, - input_config: TestSparseNNInputConfig, - pipeline_config: PipelineConfig, -) -> None: - # Ensure GPUs are available and we have enough of them - assert ( - torch.cuda.is_available() and torch.cuda.device_count() >= world_size - ), "CUDA not available or insufficient GPUs for the requested world_size" - - torch.autograd.set_detect_anomaly(True) - with MultiProcessContext( - rank=rank, - world_size=world_size, - backend="nccl", - use_deterministic_algorithms=False, - ) as ctx: - unsharded_model = _generate_model( - tables=tables, - weighted_tables=weighted_tables, - dense_device=ctx.device, - ) - - # Create a topology for sharding - topology = Topology( - local_world_size=get_local_size(world_size), - world_size=world_size, - compute_device=ctx.device.type, - ) - - # Create a planner for sharding based on the specified type - planner = _generate_planner( - planner_type=run_option.planner_type, - topology=topology, - tables=tables, - weighted_tables=weighted_tables, - 
sharding_type=run_option.sharding_type,
-            compute_kernel=run_option.compute_kernel,
-            num_batches=run_option.num_batches,
-            batch_size=input_config.batch_size,
-            pooling_factors=run_option.pooling_factors,
-            num_poolings=run_option.num_poolings,
-        )
-
-        sharded_model, optimizer = _generate_sharded_model_and_optimizer(
-            model=unsharded_model,
-            sharding_type=run_option.sharding_type.value,
-            kernel_type=run_option.compute_kernel.value,
-            # pyre-ignore
-            pg=ctx.pg,
-            device=ctx.device,
-            fused_params={
-                "optimizer": EmbOptimType.EXACT_ADAGRAD,
-                "learning_rate": 0.1,
-            },
-            planner=planner,
-        )
-        bench_inputs = _generate_data(
-            tables=tables,
-            weighted_tables=weighted_tables,
-            input_config=input_config,
-            num_batches=run_option.num_batches,
-        )
-        pipeline = pipeline_config.generate_pipeline(
-            sharded_model, optimizer, ctx.device
-        )
-        pipeline.progress(iter(bench_inputs))
-
-        def _func_to_benchmark(
-            bench_inputs: List[ModelInput],
-            model: nn.Module,
-            pipeline: TrainPipeline,
-        ) -> None:
-            dataloader = iter(bench_inputs)
-            while True:
-                try:
-                    pipeline.progress(dataloader)
-                except StopIteration:
-                    break
-
-        result = benchmark_func(
-            name=type(pipeline).__name__,
-            bench_inputs=bench_inputs,  # pyre-ignore
-            prof_inputs=bench_inputs,  # pyre-ignore
-            num_benchmarks=5,
-            num_profiles=2,
-            profile_dir=run_option.profile,
-            world_size=run_option.world_size,
-            func_to_benchmark=_func_to_benchmark,
-            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-            rank=rank,
-        )
-        if rank == 0:
-            print(result)
-
-
-if __name__ == "__main__":
-    main()

From 26f60589930c417aade98294c7fe029a3d71bd8d Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Thu, 12 Jun 2025 22:16:55 -0700
Subject: [PATCH 3/3] Added optimizer configuration supporting optimizer type,
 learning rate, momentum, and weight decay. (#3094)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3094

This commit enhances the optimizer configuration in TorchRec: it now supports specifying the optimizer type, learning rate, momentum, and weight decay. These changes give users more control over the training process, allowing models to be tuned with different optimization strategies and hyperparameters.
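As an illustrative sketch (the field names are the ones added in this diff;
the values are hypothetical), a run combining the new knobs could be
configured as:

    run_option = RunOptions(
        dense_optimizer="adamw",  # resolved case-insensitively via optimizer_classes
        dense_lr=1e-3,
        dense_weight_decay=0.01,  # forwarded to the dense optimizer only when set
        sparse_optimizer="exact_adagrad",  # upper-cased and looked up on EmbOptimType
        sparse_lr=0.05,
        sparse_weight_decay=1e-5,  # added to fused_params only when set
    )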
Differential Revision: D76559261 --- .../benchmark/benchmark_pipeline_utils.py | 52 ++++++++++++++----- .../benchmark/benchmark_train_pipeline.py | 38 ++++++++++++-- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py index 9ecfaa35a..acf30305d 100644 --- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py +++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py @@ -294,7 +294,11 @@ def generate_sharded_model_and_optimizer( kernel_type: str, pg: dist.ProcessGroup, device: torch.device, - fused_params: Optional[Dict[str, Any]] = None, + fused_params: Dict[str, Any], + dense_optimizer: str = "SGD", + dense_lr: float = 0.1, + dense_momentum: Optional[float] = None, + dense_weight_decay: Optional[float] = None, planner: Optional[ Union[ EmbeddingShardingPlanner, @@ -302,13 +306,11 @@ def generate_sharded_model_and_optimizer( ] ] = None, ) -> Tuple[nn.Module, Optimizer]: - # Ensure fused_params is always a dictionary - fused_params_dict = {} if fused_params is None else fused_params sharder = TestEBCSharder( sharding_type=sharding_type, kernel_type=kernel_type, - fused_params=fused_params_dict, + fused_params=fused_params, ) sharders = [cast(ModuleSharder[nn.Module], sharder)] @@ -328,14 +330,40 @@ def generate_sharded_model_and_optimizer( sharders=sharders, plan=plan, ).to(device) - optimizer = optim.SGD( - [ - param - for name, param in sharded_model.named_parameters() - if "sparse" not in name - ], - lr=0.1, - ) + + # Get dense parameters + dense_params = [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ] + + # Create optimizer based on the specified type + optimizer_classes = { + "sgd": optim.SGD, + "adam": optim.Adam, + "adagrad": optim.Adagrad, + "rmsprop": optim.RMSprop, + "adadelta": optim.Adadelta, + "adamw": optim.AdamW, + "adamax": optim.Adamax, + "nadam": optim.NAdam, + "asgd": optim.ASGD, + "lbfgs": optim.LBFGS, + } + optimizer_class = optimizer_classes.get(dense_optimizer.lower(), optim.SGD) + + # Create optimizer with momentum and/or weight_decay if provided + optimizer_kwargs = {"lr": dense_lr} + + if dense_momentum is not None: + optimizer_kwargs["momentum"] = dense_momentum + + if dense_weight_decay is not None: + optimizer_kwargs["weight_decay"] = dense_weight_decay + + optimizer = optimizer_class(dense_params, **optimizer_kwargs) # pyre-ignore[6] + return sharded_model, optimizer diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index 25a548696..7914949f9 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -65,6 +65,14 @@ class RunOptions: pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table. This is the average number of values each sample has for the feature. num_poolings (Optional[List[float]]): Number of poolings for each feature of the table. + dense_optimizer (str): Optimizer to use for dense parameters. + Default is "SGD". + dense_lr (float): Learning rate for dense parameters. + Default is 0.1. + sparse_optimizer (str): Optimizer to use for sparse parameters. + Default is "EXACT_ADAGRAD". + sparse_lr (float): Learning rate for sparse parameters. + Default is 0.1. 
""" world_size: int = 2 @@ -76,6 +84,14 @@ class RunOptions: planner_type: str = "embedding" pooling_factors: Optional[List[float]] = None num_poolings: Optional[List[float]] = None + dense_optimizer: str = "SGD" + dense_lr: float = 0.1 + dense_momentum: Optional[float] = None + dense_weight_decay: Optional[float] = None + sparse_optimizer: str = "EXACT_ADAGRAD" + sparse_lr: float = 0.1 + sparse_momentum: Optional[float] = None + sparse_weight_decay: Optional[float] = None @dataclass @@ -204,6 +220,19 @@ def runner( num_batches=run_option.num_batches, ) + # Prepare fused_params for sparse optimizer + fused_params = { + "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()), + "learning_rate": run_option.sparse_lr, + } + + # Add momentum and weight_decay to fused_params if provided + if run_option.sparse_momentum is not None: + fused_params["momentum"] = run_option.sparse_momentum + + if run_option.sparse_weight_decay is not None: + fused_params["weight_decay"] = run_option.sparse_weight_decay + sharded_model, optimizer = generate_sharded_model_and_optimizer( model=unsharded_model, sharding_type=run_option.sharding_type.value, @@ -211,10 +240,11 @@ def runner( # pyre-ignore pg=ctx.pg, device=ctx.device, - fused_params={ - "optimizer": EmbOptimType.EXACT_ADAGRAD, - "learning_rate": 0.1, - }, + fused_params=fused_params, + dense_optimizer=run_option.dense_optimizer, + dense_lr=run_option.dense_lr, + dense_momentum=run_option.dense_momentum, + dense_weight_decay=run_option.dense_weight_decay, planner=planner, ) pipeline = generate_pipeline(