 ],
 )
 @pytest.mark.parametrize("compile", [False, True])
-def test_moe_float8_training(target_fqns: list[str], compile: bool):
-    # Set token group alignment size to 16. This is required so that
-    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
-    # has the contraction dim be divisible by 16. 16 byte alignment is required
-    # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
-    set_token_group_alignment_size_m(16)
-    model_args = MoEArgs(
-        num_experts=8,
-    )
-    init_std = 0.02
-    device = torch.device("cuda")
-
-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 8192
-    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
-    torch.manual_seed(42)
-    ref_model.init_weights(init_std, device)
-
-    # target MoE for testing conversion
-    model = copy.deepcopy(ref_model)
-
-    # assert starting params are identical for both models
-    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
-        assert torch.equal(param1, param2)
-
-    # convert MoE to float8 training
-    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
-        for target_fqn in target_fqns:
-            if target_fqn in cur_fqn:
-                return True
-        return False
-
-    # quantize test model
-    config = MoETrainingConfig()
-    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
-
-    # validate that only the experts were converted
-    _validate_model_conversion(
-        model,
-        target_fqns=target_fqns,
-    )
-    if compile:
-        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
-        model = torch.compile(model, fullgraph=False)
-        ref_model = torch.compile(ref_model, fullgraph=False)
-
-    # inputs
-    batch, seq = 8, 2048
-    ref_x = torch.randn(
-        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
-    )
-    x = ref_x.detach().clone().requires_grad_(True)
-
-    # forward pass
-    ref_out = ref_model(ref_x)
-    out = model(x)
-
-    # validate output
-    out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
-    assert out_sqnr.item() >= min_out_sqnr, (
-        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
-    )
-
-    # compute loss
-    labels = torch.ones_like(ref_out)
-    ref_loss = F.mse_loss(ref_out, labels)
-    out_loss = F.mse_loss(out, labels)
-
-    # backward pass
-    ref_loss.backward()
-    out_loss.backward()
-
-    # validate input gradient
-    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    min_input_grad_sqnr = 29.0
-    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
-        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
-    )
-
-    # validate param gradients
-    min_param_grad_sqnr = 23.0
-    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
-        param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
-            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
-        )
-
-
 @pytest.mark.parametrize(
-    "target_fqns",
+    "recipe_config",
     [
-        ["experts"],
-        ["does.not.exist"],
+        # {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0},
+        {
+            "recipe": MoEScalingType.MXFP8,
+            "group_alignment_size": 32,
+            "min_out_sqnr": 28.0,
+            "min_input_grad_sqnr": 29.0,
+            "min_param_grad_sqnr": 21.0,
+        },
     ],
 )
-@pytest.mark.parametrize("compile", [False, True])
-def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
-    block_size = 32
-
-    # Token groups must be divisible by 32 for mxfp8
-    set_token_group_alignment_size_m(block_size)
-
+def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict):
+    (
+        recipe,
+        group_alignment_size,
+        min_out_sqnr,
+        min_input_grad_sqnr,
+        min_param_grad_sqnr,
+    ) = (
+        recipe_config["recipe"],
+        recipe_config["group_alignment_size"],
+        recipe_config["min_out_sqnr"],
+        recipe_config["min_input_grad_sqnr"],
+        recipe_config["min_param_grad_sqnr"],
+    )
+    # Set token group alignment size. This is required so that
+    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
+    # has the contraction dim be divisible by 16. 16 byte alignment is required
+    # for the slowest moving dim (stride 1).
+    set_token_group_alignment_size_m(group_alignment_size)
     model_args = MoEArgs(
         num_experts=8,
     )
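Note on the per-recipe `group_alignment_size`: the fp8 rowwise recipe pads token groups to multiples of 16 (16-byte alignment on the stride-1 dim at one byte per fp8 element), while mxfp8 pads to multiples of 32 so each group spans whole 32-element scaling blocks. A minimal sketch of the round-up this implies (the `round_up` helper is illustrative, not torchao's internal padding code):

```python
# Illustrative token-group alignment arithmetic; torchao's actual padding
# logic may differ.
def round_up(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple, e.g. round_up(100, 32) == 128.
    return ((n + multiple - 1) // multiple) * multiple

# fp8 rowwise: 16-byte alignment / 1 byte per fp8 element -> multiples of 16.
assert round_up(100, 16) == 112
# mxfp8: each group must cover whole 32-element scaling blocks -> multiples of 32.
assert round_up(100, 32) == 128
```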
@@ -170,15 +99,14 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    config = MoETrainingConfig(scaling_type=recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
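The conversion step keeps the FQN-substring filter from the old tests: `quantize_` walks the module tree and converts only modules whose fully qualified name contains one of the `target_fqns` (which is how the removed `["does.not.exist"]` case converted nothing). A self-contained sketch of that pattern on a made-up toy module tree:

```python
# Sketch of the FQN-substring filter pattern; the toy module names here are
# hypothetical, for illustration only.
import torch.nn as nn

def make_filter(target_fqns: list[str]):
    def filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        # Convert a module iff any target FQN is a substring of its path.
        return any(t in cur_fqn for t in target_fqns)
    return filter_fn

toy = nn.ModuleDict({"experts": nn.Linear(4, 4), "router": nn.Linear(4, 4)})
f = make_filter(["experts"])
matches = [fqn for fqn, m in toy.named_modules() if f(m, fqn)]
assert matches == ["experts"]  # router (and the root module) are left alone
```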
@@ -197,7 +125,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 28.0
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
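The `min_*_sqnr` thresholds in `recipe_config` are SQNR floors in decibels. Assuming `compute_error` follows the usual signal-to-quantization-noise definition (an assumption; check torchao's test utilities for the exact implementation), it amounts to:

```python
# SQNR in decibels: 20 * log10(||ref|| / ||ref - actual||).
# Higher values mean the quantized output tracks the bf16 reference closely.
import torch

def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - actual)
    )

ref = torch.randn(1024)
assert sqnr(ref, ref.clone()).isinf()  # identical tensors -> +inf dB
assert sqnr(ref, ref + 1e-3 * torch.randn(1024)) > 28.0  # tiny noise -> high SQNR
```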
@@ -213,13 +140,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    min_input_grad_sqnr = 30.0
     assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
         f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )

     # validate param gradients
-    min_param_grad_sqnr = 21.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
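For reference, the commented-out FP8_ROWWISE entry in the new parametrize list carries exactly the thresholds of the deleted `test_moe_float8_training` (alignment 16, SQNR floors 29.0/29.0/23.0), so re-enabling it restores the old float8 rowwise coverage under the unified test. Formatted as a proper entry:

```python
# The fp8 rowwise case as a recipe_config entry; values mirror the removed
# test_moe_float8_training and the commented-out one-liner above.
fp8_rowwise_config = {
    "recipe": MoEScalingType.FP8_ROWWISE,  # enum member per the comment in the diff
    "group_alignment_size": 16,  # 16-byte alignment / 1 byte per fp8 element
    "min_out_sqnr": 29.0,
    "min_input_grad_sqnr": 29.0,
    "min_param_grad_sqnr": 23.0,
}
```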