@@ -152,9 +152,15 @@ def skip_forward(
152
152
if hasattr (module , 'skip_forward' ):
153
153
module .forward = module .skip_forward
154
154
remove_weights (module , ignore_modules )
155
+ else :
156
+ logger .warning (
157
+ f"Fail to skip forward since { module .__class__ .__name__ } "
158
+ f"does not have `skip_forward`." )
155
159
156
160
157
161
def forward_after_recv (forward_fn ):
162
+ if hasattr (forward_fn , "__wrapped_by_forward_after_recv__" ):
163
+ return forward_fn
158
164
159
165
def forward_after_recv_fn (
160
166
position_ids ,
@@ -176,10 +182,13 @@ def forward_after_recv_fn(
176
182
** kwargs ,
177
183
)
178
184
185
+ forward_after_recv_fn .__wrapped_by_forward_after_recv__ = True
179
186
return forward_after_recv_fn
180
187
181
188
182
189
def forward_before_send (forward_fn ):
190
+ if hasattr (forward_fn , "__wrapped_by_forward_before_send__" ):
191
+ return forward_fn
183
192
184
193
def forward_before_send_fn (
185
194
position_ids ,
@@ -204,6 +213,7 @@ def forward_before_send_fn(
204
213
pp_send (hidden_states )
205
214
return output
206
215
216
+ forward_before_send_fn .__wrapped_by_forward_before_send__ = True
207
217
return forward_before_send_fn
208
218
209
219
@@ -411,6 +421,8 @@ def __pp_init__(self):
411
421
for module in self .epilogue :
412
422
skip_forward (module )
413
423
424
+ self .model .__pp_init__ ()
425
+
414
426
def __post_init__ (self ):
415
427
# 1. mixed precision
416
428
quant_config_dict = self .model_config .quant_config_dict
0 commit comments