
Commit 62c16b6

fix: skip weights defined in create_weights for pp. (#4447)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent a030a89 · commit 62c16b6

1 file changed (+12 −0)


tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 12 additions & 0 deletions
@@ -152,9 +152,15 @@ def skip_forward(
     if hasattr(module, 'skip_forward'):
         module.forward = module.skip_forward
         remove_weights(module, ignore_modules)
+    else:
+        logger.warning(
+            f"Fail to skip forward since {module.__class__.__name__} "
+            f"does not have `skip_forward`.")


 def forward_after_recv(forward_fn):
+    if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
+        return forward_fn

     def forward_after_recv_fn(
         position_ids,
@@ -176,10 +182,13 @@ def forward_after_recv_fn(
             **kwargs,
         )

+    forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
     return forward_after_recv_fn


 def forward_before_send(forward_fn):
+    if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
+        return forward_fn

     def forward_before_send_fn(
         position_ids,
@@ -204,6 +213,7 @@ def forward_before_send_fn(
         pp_send(hidden_states)
         return output

+    forward_before_send_fn.__wrapped_by_forward_before_send__ = True
     return forward_before_send_fn


@@ -411,6 +421,8 @@ def __pp_init__(self):
         for module in self.epilogue:
             skip_forward(module)

+        self.model.__pp_init__()
+
     def __post_init__(self):
         # 1. mixed precision
         quant_config_dict = self.model_config.quant_config_dict
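
For context, the core of this change is making the pipeline-parallel wrapping idempotent: since `__pp_init__` is now also invoked on the inner `self.model`, the `forward_after_recv` / `forward_before_send` decorators could otherwise end up applied to the same forward function twice, and the new `__wrapped_by_*__` marker attributes turn a second application into a no-op. The snippet below is a minimal, self-contained sketch of that guard pattern, not TensorRT-LLM code; `forward_after_recv_demo` and the simulated recv step are illustrative only.

# Minimal sketch (not TensorRT-LLM code) of the "wrap once" guard added in
# this commit: a marker attribute on the wrapper makes a second application
# of the decorator return the existing wrapper instead of nesting another one.
def forward_after_recv_demo(forward_fn):
    # Already wrapped? Return the existing wrapper unchanged.
    if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
        return forward_fn

    def forward_after_recv_fn(*args, **kwargs):
        # The real wrapper receives hidden_states from the previous pipeline
        # stage here; this sketch only simulates that step.
        print("recv (simulated)")
        return forward_fn(*args, **kwargs)

    # Mark the wrapper so applying the decorator again becomes a no-op.
    forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
    return forward_after_recv_fn


def forward(x):
    return x + 1


wrapped_once = forward_after_recv_demo(forward)
wrapped_twice = forward_after_recv_demo(wrapped_once)
assert wrapped_once is wrapped_twice   # second application is a no-op
assert wrapped_once(1) == 2            # prints "recv (simulated)" once

Without such a marker, calling `__pp_init__` on both the outer model and `self.model` would presumably nest the recv/send wrappers and repeat the pipeline communication on every forward pass, which is what the guard prevents.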
