
Commit fbfbcff

Shuangping Liu authored and facebook-github-bot committed
Add unit test for SSD TBE with VBE input (#3086)
Summary: Pull Request resolved: #3086

Add a new unit test in [`test_model_parallel_nccl_ssd_single_gpu.py`](https://www.internalfb.com/code/fbsource/[5f477259031a]/fbcode/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py) for SSD TBE with VBE input.

### Context

* This test is a prerequisite for testing the incoming FBGEMM & TorchRec changes that merge VBE output.
* For SSD TBE, the tensor wrapped in a shard is a [`PartiallyMaterializedTensor`](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py) (PMT), which requires special handling when copying the state dict from an unsharded tensor. Specifically:
  - It lacks certain properties such as `ndim`.
  - Its `copy_` method is a no-op; writes must go through the [wrapped C++ object](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp?lines=417) of the PMT.
  - Only the `ROW_WISE`, `TABLE_WISE`, and `TABLE_ROW_WISE` sharding types are supported.

NOTE: SSD TBE only supports the `RowWiseAdagrad` optimizer. For **FP16**, the learning rate and eps need to be chosen carefully to avoid numerical instabilities in the unsharded model; here we use `lr = 0.001` and `eps = 0.001` to pass the test.

Reviewed By: TroyGarden

Differential Revision: D76455104

fbshipit-source-id: 729fb91dae34f353eb6908e5ad78d401c6549d90
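To make the PMT write path concrete, here is a minimal, hypothetical sketch (the helper `write_to_shard` is illustrative and not part of this commit; the `set_range` call mirrors the one used in `copy_state_dict` in the diff below):

```python
import torch
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
    PartiallyMaterializedTensor,
)


def write_to_shard(local_tensor, src: torch.Tensor) -> None:
    """Copy `src` into a local shard, handling the SSD-backed PMT case."""
    if isinstance(local_tensor, PartiallyMaterializedTensor):
        # PMT's `copy_` is a no-op, so write through the wrapped C++
        # object instead; the source rows must live on CPU.
        local_tensor.wrapped.set_range(0, 0, src.size(0), src.to(device="cpu"))
    else:
        local_tensor.copy_(src)
```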
1 parent ddbe9e0 commit fbfbcff

File tree

2 files changed (+86 -3 lines)


torchrec/distributed/test_utils/test_sharding.py

Lines changed: 16 additions & 3 deletions
@@ -16,6 +16,9 @@
 import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
+    PartiallyMaterializedTensor,
+)
 from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.distributed.optim import (
     _apply_optimizer_in_backward as apply_optimizer_in_backward,
@@ -610,12 +613,16 @@ def copy_state_dict(
 
     if isinstance(tensor, ShardedTensor):
         for local_shard in tensor.local_shards():
+            # Tensors like `PartiallyMaterializedTensor` do not provide
+            # `ndim` property, so use shape length here as a workaround
+            ndim = len(local_shard.tensor.shape)
             assert (
-                global_tensor.ndim == local_shard.tensor.ndim
-            ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {local_shard.tensor.ndim}"
+                global_tensor.ndim == ndim
+            ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {ndim}"
             assert (
                 global_tensor.dtype == local_shard.tensor.dtype
             ), f"global tensor dtype: {global_tensor.dtype}, local tensor dtype: {local_shard.tensor.dtype}"
+
             shard_meta = local_shard.metadata
             t = global_tensor.detach()
             if t.ndim == 1:
@@ -632,7 +639,13 @@ def copy_state_dict(
                 ]
             else:
                 raise ValueError("Tensors with ndim > 2 are not supported")
-            local_shard.tensor.copy_(t)
+
+            if isinstance(local_shard.tensor, PartiallyMaterializedTensor):
+                local_shard.tensor.wrapped.set_range(
+                    0, 0, t.size(0), t.to(device="cpu")
+                )
+            else:
+                local_shard.tensor.copy_(t)
     elif isinstance(tensor, DTensor):
         for local_shard, global_offset in zip(
             tensor.to_local().local_shards(),
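Before the new test itself, a quick illustration of what "VBE input" means: each feature may carry a different batch size. A minimal sketch using torchrec's `KeyedJaggedTensor` with `stride_per_key_per_rank` (the feature names and values here are made up for illustration):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with different batch sizes on a single rank:
# "f1" carries 2 samples, "f2" carries 3.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
    lengths=torch.tensor([1, 2, 1, 1, 2]),  # per-sample lengths, f1 then f2
    stride_per_key_per_rank=[[2], [3]],  # variable batch size per feature
)
```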

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 70 additions & 0 deletions
@@ -12,6 +12,7 @@
 from typing import cast, List, OrderedDict, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from hypothesis import given, settings, strategies as st, Verbosity
@@ -27,13 +28,15 @@
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import ParameterConstraints
+from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSingleRankBase,
 )
 from torchrec.distributed.test_utils.test_sharding import (
     copy_state_dict,
     create_test_sharder,
     SharderType,
+    sharding_single_rank_test_single_process,
 )
 from torchrec.distributed.tests.test_sequence_model import (
     TestEmbeddingCollectionSharder,
@@ -45,6 +48,7 @@
     EmbeddingBagConfig,
     EmbeddingConfig,
 )
+from torchrec.optim import RowWiseAdagrad
 
 
 def _load_split_embedding_weights(
@@ -540,6 +544,72 @@ def test_ssd_mixed_kernels(
         self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
         self._compare_models(m1, m2, is_deterministic=is_deterministic)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    def test_ssd_mixed_kernels_with_vbe(
+        self,
+        sharding_type: str,
+        dtype: DataType,
+    ) -> None:
+        self._set_table_weights_precision(dtype)
+        fused_params = {
+            "prefetch_pipeline": True,
+        }
+        constraints = {
+            table.name: ParameterConstraints(
+                min_partition=4,
+                compute_kernels=(
+                    [EmbeddingComputeKernel.FUSED.value]
+                    if i % 2 == 0
+                    else [EmbeddingComputeKernel.KEY_VALUE.value]
+                ),
+                sharding_types=[sharding_type],
+            )
+            for i, table in enumerate(self.tables)
+        }
+        optimizer_config = (RowWiseAdagrad, {"lr": 0.001, "eps": 0.001})
+        pg = dist.GroupMember.WORLD
+
+        assert pg is not None, "Process group is not initialized"
+        sharding_single_rank_test_single_process(
+            pg=pg,
+            device=self.device,
+            rank=0,
+            world_size=1,
+            # pyre-fixme[6]: The intake type should be `type[TestSparseNNBase]`
+            model_class=TestSparseNN,
+            embedding_groups={},
+            tables=self.tables,
+            # pyre-fixme[6]
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
+            optim=EmbOptimType.EXACT_SGD,
+            # The optimizer config here will overwrite the SGD optimizer above
+            apply_optimizer_in_backward_config={
+                "embedding_bags": optimizer_config,
+                "embeddings": optimizer_config,
+            },
+            constraints=constraints,
+            variable_batch_per_feature=True,
+        )
+
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
