test/prototype/moe_training/test_fsdp.py (41 additions, 4 deletions)
@@ -6,8 +6,8 @@
 ######################################################################
 #
 # To run these unit tests, use the following command:
-#
-# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
+# cd ao/test/prototype/moe_training
+# torchrun --nproc_per_node=8 test_fsdp.py
 #
 #######################################################################

@@ -31,7 +31,40 @@
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
 from torchao.quantization.quant_api import quantize_

-from .testing_utils import _validate_model_conversion
+from torch import nn
+
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
+

 # this test requires torchtitan
 try:
@@ -55,7 +88,7 @@ def test_moe_float8_training_fsdp():
     set_token_group_alignment_size_m(16)

     # define model args
-    target_fqns = ["experts"]
+    target_fqns = ["experts", "shared_expert"]
     model_args = MoEArgs(
         num_experts=8,
     )
@@ -146,3 +179,7 @@ def setup_distributed():
     world_size = int(os.environ["WORLD_SIZE"])
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
+
+
+if __name__ == "__main__":
+    test_moe_float8_training_fsdp()
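
For context, here is a minimal sketch (not taken from this PR) of how the `_validate_model_conversion` helper added above is typically exercised: convert only the targeted submodules with `quantize_` and a name-based filter function, then assert that only those modules carry `ScaledGroupedMMTensor` params. The `convert_and_validate` wrapper and `moe_module_filter_fn` names are hypothetical, and the actual conversion call inside `test_moe_float8_training_fsdp` (not shown in the visible hunks) may differ.

```python
# Hypothetical sketch, not part of this PR: wiring the inlined helper into a
# convert-then-validate step. Assumes quantize_'s filter_fn selects modules by
# fully-qualified name, matching the target_fqns used in the test.
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_


def convert_and_validate(model: nn.Module, target_fqns: list[str]) -> None:
    # Convert only modules whose FQN contains one of the target substrings,
    # e.g. "experts" or "shared_expert".
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return any(target_fqn in cur_fqn for target_fqn in target_fqns)

    quantize_(model, config=MoETrainingConfig(), filter_fn=moe_module_filter_fn)

    # _validate_model_conversion is the helper inlined into test_fsdp.py above;
    # it asserts ScaledGroupedMMTensor params appear only under target_fqns.
    _validate_model_conversion(model, target_fqns)
```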