@@ -12,6 +12,7 @@
 from typing import cast, List, OrderedDict, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from hypothesis import given, settings, strategies as st, Verbosity
@@ -27,13 +28,15 @@
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import ParameterConstraints
+from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSingleRankBase,
 )
 from torchrec.distributed.test_utils.test_sharding import (
     copy_state_dict,
     create_test_sharder,
     SharderType,
+    sharding_single_rank_test_single_process,
 )
 from torchrec.distributed.tests.test_sequence_model import (
     TestEmbeddingCollectionSharder,
@@ -45,6 +48,7 @@
     EmbeddingBagConfig,
     EmbeddingConfig,
 )
+from torchrec.optim import RowWiseAdagrad
 
 
 def _load_split_embedding_weights(
@@ -540,6 +544,72 @@ def test_ssd_mixed_kernels(
         self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
         self._compare_models(m1, m2, is_deterministic=is_deterministic)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                # TODO: uncomment when SSD ckpt supports CW sharding
+                # ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                # TODO: uncomment when SSD ckpt supports CW sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    def test_ssd_mixed_kernels_with_vbe(
+        self,
+        sharding_type: str,
+        dtype: DataType,
+    ) -> None:
+        self._set_table_weights_precision(dtype)
+        fused_params = {
+            "prefetch_pipeline": True,
+        }
+        constraints = {
+            table.name: ParameterConstraints(
+                min_partition=4,
+                compute_kernels=(
+                    [EmbeddingComputeKernel.FUSED.value]
+                    if i % 2 == 0
+                    else [EmbeddingComputeKernel.KEY_VALUE.value]
+                ),
+                sharding_types=[sharding_type],
+            )
+            for i, table in enumerate(self.tables)
+        }
+        optimizer_config = (RowWiseAdagrad, {"lr": 0.001, "eps": 0.001})
+        pg = dist.GroupMember.WORLD
+
+        assert pg is not None, "Process group is not initialized"
+        sharding_single_rank_test_single_process(
+            pg=pg,
+            device=self.device,
+            rank=0,
+            world_size=1,
+            # pyre-fixme[6]: The intake type should be `type[TestSparseNNBase]`
+            model_class=TestSparseNN,
+            embedding_groups={},
+            tables=self.tables,
+            # pyre-fixme[6]
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
+            optim=EmbOptimType.EXACT_SGD,
+            # The optimizer config here will overwrite the SGD optimizer above
+            apply_optimizer_in_backward_config={
+                "embedding_bags": optimizer_config,
+                "embeddings": optimizer_config,
+            },
+            constraints=constraints,
+            variable_batch_per_feature=True,
+        )
+
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
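For context on the `apply_optimizer_in_backward_config` argument above: it maps parameter namespaces (`"embedding_bags"`, `"embeddings"`) to an `(optimizer_class, kwargs)` pair. Below is a minimal sketch of the mechanism such a config typically drives, assuming the pair is forwarded to `_apply_optimizer_in_backward` from `torch.distributed.optim`; the single-table collection is hypothetical and not part of this commit.

```python
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim import RowWiseAdagrad

# Hypothetical single-table collection, standing in for self.tables.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t0",
            num_embeddings=100,
            embedding_dim=8,
            feature_names=["f0"],
        )
    ],
)

# Attach RowWiseAdagrad to every embedding parameter. The optimizer step
# then runs inside backward as each gradient is accumulated, rather than
# in a separate optimizer.step() call.
apply_optimizer_in_backward(
    RowWiseAdagrad,
    ebc.parameters(),
    {"lr": 0.001, "eps": 0.001},
)
```

Because the step runs while gradients are being accumulated in backward, these parameters are handled there instead of by the dense optimizer, which matches the diff's comment that the config overwrites the SGD optimizer passed via `optim=EmbOptimType.EXACT_SGD`.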
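The `variable_batch_per_feature=True` flag exercises VBE (variable batch size per feature) input. As a rough illustration of what such an input looks like, a VBE `KeyedJaggedTensor` declares a per-feature batch size through `stride_per_key_per_rank`; the feature names and batch sizes below are made up for the sketch.

```python
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Single-rank VBE input: "f0" carries a local batch of 2 bags, "f1" of 3,
# so the two features no longer share one batch dimension.
kjt = KeyedJaggedTensor(
    keys=["f0", "f1"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 1, 1]),  # one id per bag: 2 + 3 bags
    stride_per_key_per_rank=[[2], [3]],
)
```

Here `f0` contributes two bags and `f1` three in the same forward pass, which is the shape of input the new test pushes through the alternating FUSED and KEY_VALUE kernels.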