[fix] compile for groupoffloading #11960
Conversation
Hi @seed93, thanks for the PR! Could you share a minimal reproducible code snippet so I can verify the changes on my end?
This could help to reproduce. Also, if I remove
I investigated compilation a bit. There are still some recompilations, of two types:

Output with regional compilation (fullgraph=False):

Iteration 1
V0724 21:56:11.637000 2089972 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] Recompiling function new_forward in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:187
V0724 21:56:11.637000 2089972 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:56:11.637000 2089972 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] - 0/0: ___check_type_id(function_reference.forward, 94573921257568)
V0724 21:56:14.094000 2089972 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] Recompiling function torch_dynamo_resume_in_new_forward_at_188 in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:188
V0724 21:56:14.094000 2089972 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:56:14.094000 2089972 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] - 3/0: ___check_obj_id(function_reference.forward, 140026546073344)
V0724 21:56:14.114000 2089972 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] Recompiling function torch_dynamo_resume_in_new_forward_at_188 in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:188
V0724 21:56:14.114000 2089972 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] triggered by the following guard failure(s):
V0724 21:56:14.114000 2089972 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] - 3/1: ___check_type_id(function_reference.forward, 94573921257568)
V0724 21:56:14.114000 2089972 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] - 3/0: ___check_obj_id(function_reference.forward, 140026546073344)
Iteration 2
V0724 21:56:14.390000 2089972 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] Recompiling function pre_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:328
V0724 21:56:14.390000 2089972 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:56:14.390000 2089972 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] - 1/0: ___check_obj_id(self.group.onload_self, 140037214748608)
V0724 21:56:14.690000 2089972 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] Recompiling function pre_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:328
V0724 21:56:14.690000 2089972 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] triggered by the following guard failure(s):
V0724 21:56:14.690000 2089972 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] - 1/1: ___check_type_id(self.next_group, 94574045200112)
V0724 21:56:14.690000 2089972 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] - 1/0: ___check_obj_id(self.group.onload_self, 140037214748608)
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10

Output with full compilation (fullgraph=False):

(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ TORCH_LOGS="recompiles" CUDA_VISIBLE_DEVICES=3 python3 dump12.py
Iteration 1
V0724 21:59:57.527000 2094194 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] Recompiling function new_forward in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:187
V0724 21:59:57.527000 2094194 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:59:57.527000 2094194 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] - 0/0: ___check_type_id(function_reference.forward, 94565963950176)
V0724 21:59:59.011000 2094194 torch/_dynamo/guards.py:3006] [0/2] [__recompiles] Recompiling function new_forward in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:187
V0724 21:59:59.011000 2094194 torch/_dynamo/guards.py:3006] [0/2] [__recompiles] triggered by the following guard failure(s):
V0724 21:59:59.011000 2094194 torch/_dynamo/guards.py:3006] [0/2] [__recompiles] - 0/1: ___check_type_id(module, 94566115730144)
V0724 21:59:59.011000 2094194 torch/_dynamo/guards.py:3006] [0/2] [__recompiles] - 0/0: ___check_type_id(function_reference.forward.args[0], 94566115730144)
V0724 21:59:59.040000 2094194 torch/_dynamo/guards.py:3006] [0/3] [__recompiles] Recompiling function new_forward in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:187
V0724 21:59:59.040000 2094194 torch/_dynamo/guards.py:3006] [0/3] [__recompiles] triggered by the following guard failure(s):
V0724 21:59:59.040000 2094194 torch/_dynamo/guards.py:3006] [0/3] [__recompiles] - 0/2: ___check_type_id(function_reference.forward, 94565963950176)
V0724 21:59:59.040000 2094194 torch/_dynamo/guards.py:3006] [0/3] [__recompiles] - 0/1: ___check_type_id(module, 94566115730144)
V0724 21:59:59.040000 2094194 torch/_dynamo/guards.py:3006] [0/3] [__recompiles] - 0/0: ___check_type_id(function_reference.forward, 94565963950176)
V0724 21:59:59.053000 2094194 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] Recompiling function pre_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:328
V0724 21:59:59.053000 2094194 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:59:59.053000 2094194 torch/_dynamo/guards.py:3006] [1/1] [__recompiles] - 1/0: ___check_type_id(self.group.onload_leader, 94566115730144)
V0724 21:59:59.070000 2094194 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] Recompiling function torch_dynamo_resume_in_new_forward_at_188 in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:188
V0724 21:59:59.070000 2094194 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] triggered by the following guard failure(s):
V0724 21:59:59.070000 2094194 torch/_dynamo/guards.py:3006] [3/1] [__recompiles] - 3/0: ___check_obj_id(function_reference.forward, 140015984505344)
V0724 22:00:00.955000 2094194 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] Recompiling function torch_dynamo_resume_in_new_forward_at_188 in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:188
V0724 22:00:00.955000 2094194 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:00.955000 2094194 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] - 3/1: ___check_obj_id(function_reference.forward, 140015984503936)
V0724 22:00:00.955000 2094194 torch/_dynamo/guards.py:3006] [3/2] [__recompiles] - 3/0: ___check_obj_id(function_reference.forward, 140015984505344)
V0724 22:00:00.975000 2094194 torch/_dynamo/guards.py:3006] [3/3] [__recompiles] Recompiling function torch_dynamo_resume_in_new_forward_at_188 in /home/aryan/work/diffusers/src/diffusers/hooks/hooks.py:188
V0724 22:00:00.975000 2094194 torch/_dynamo/guards.py:3006] [3/3] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:00.975000 2094194 torch/_dynamo/guards.py:3006] [3/3] [__recompiles] - 3/2: ___check_type_id(function_reference.forward, 94565963950176)
V0724 22:00:00.975000 2094194 torch/_dynamo/guards.py:3006] [3/3] [__recompiles] - 3/1: ___check_obj_id(function_reference.forward, 140015984503936)
V0724 22:00:00.975000 2094194 torch/_dynamo/guards.py:3006] [3/3] [__recompiles] - 3/0: ___check_obj_id(function_reference.forward, 140015984505344)
V0724 22:00:01.247000 2094194 torch/_dynamo/guards.py:3006] [5/1] [__recompiles] Recompiling function post_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:347
V0724 22:00:01.247000 2094194 torch/_dynamo/guards.py:3006] [5/1] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:01.247000 2094194 torch/_dynamo/guards.py:3006] [5/1] [__recompiles] - 5/0: ___check_type_id(self.group.offload_leader, 94566115729184)
Iteration 2
V0724 22:00:01.330000 2094194 torch/_dynamo/guards.py:3006] [2/1] [__recompiles] Recompiling function torch_dynamo_resume_in_pre_forward_at_339 in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:339
V0724 22:00:01.330000 2094194 torch/_dynamo/guards.py:3006] [2/1] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:01.330000 2094194 torch/_dynamo/guards.py:3006] [2/1] [__recompiles] - 2/0: ___check_obj_id(self.next_group, 140026653035680)
V0724 22:00:01.401000 2094194 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] Recompiling function pre_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:328
V0724 22:00:01.401000 2094194 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:01.401000 2094194 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] - 1/0: ___check_type_id(self.group.onload_leader, 94566115730144)
V0724 22:00:01.401000 2094194 torch/_dynamo/guards.py:3006] [1/2] [__recompiles] - 1/1: ___check_obj_id(self.group.onload_self, 140026652950464)
V0724 22:00:01.416000 2094194 torch/_dynamo/guards.py:3006] [12/1] [__recompiles] Recompiling function torch_dynamo_resume_in_pre_forward_at_341 in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:341
V0724 22:00:01.416000 2094194 torch/_dynamo/guards.py:3006] [12/1] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:01.416000 2094194 torch/_dynamo/guards.py:3006] [12/1] [__recompiles] - 12/0: ___check_obj_id(self.group.non_blocking, 140026652950496)
V0724 22:00:01.472000 2094194 torch/_dynamo/guards.py:3006] [1/3] [__recompiles] Recompiling function pre_forward in /home/aryan/work/diffusers/src/diffusers/hooks/group_offloading.py:328
V0724 22:00:01.472000 2094194 torch/_dynamo/guards.py:3006] [1/3] [__recompiles] triggered by the following guard failure(s):
V0724 22:00:01.472000 2094194 torch/_dynamo/guards.py:3006] [1/3] [__recompiles] - 1/2: ___check_type_id(self.next_group, 94566087736832)
V0724 22:00:01.472000 2094194 torch/_dynamo/guards.py:3006] [1/3] [__recompiles] - 1/0: ___check_type_id(self.group.onload_leader, 94566115730144)
V0724 22:00:01.472000 2094194 torch/_dynamo/guards.py:3006] [1/3] [__recompiles] - 1/1: ___check_obj_id(self.group.onload_self, 140026652950464)
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10

Code used for testing is slightly different from the example you shared. We should always do the compilation after fully setting up the model -- in this case, we should apply group offloading before compilation.

code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.hooks import apply_group_offloading
from diffusers.models import ModelMixin


class TransformerBlock(nn.Module):
    def __init__(self, d_model=1024, n_heads=16, d_ff=4096, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff

        # Self-attention layer
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Feed-forward layer
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention with residual connection
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward with residual connection
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


class SimpleModel(ModelMixin):
    _repeated_blocks = ["TransformerBlock"]

    def __init__(self, d_model=1024, n_heads=16, d_ff=4096, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # 2 transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout),
            TransformerBlock(d_model, n_heads, d_ff, dropout),
        ])

        # Input projection (if needed)
        self.input_proj = nn.Linear(1024, d_model)
        # Output projection (if needed)
        self.output_proj = nn.Linear(d_model, 1024)

    def forward(self, x):
        # Project input to d_model dimensions
        x = self.input_proj(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Project back to original dimensions
        x = self.output_proj(x)
        return x


model = SimpleModel()
model.eval()

apply_group_offloading(
    model,
    onload_device=torch.device("cuda:0"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    non_blocking=True,
    use_stream=True,
)

# model.compile_repeated_blocks(fullgraph=False, mode="default", dynamic=False)
model = torch.compile(model, fullgraph=False, mode="default", dynamic=False)

with torch.inference_mode():
    x = torch.randn(2, 1024).to("cuda")
    for i in range(10):
        print(f"Iteration {i+1}")
        x = model(x)
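About the two guard types in the logs above: ___check_type_id guards on an object's type and ___check_obj_id on its identity, so swapping the hook's function_reference.forward for a new callable between calls invalidates the guard and forces a recompile. (The "regional" variant in the logs corresponds to the commented-out compile_repeated_blocks line, the "full" variant to torch.compile on the whole model.) Below is a minimal sketch of that mechanism (toy names, not diffusers internals), assuming a recent PyTorch; running it with TORCH_LOGS="recompiles" should show a similar guard-failure log:

import torch


class FunctionReference:  # stand-in for the hook's function_reference
    def __init__(self, forward):
        self.forward = forward


ref = FunctionReference(lambda x: x + 1)


@torch.compile(fullgraph=False)
def new_forward(x):
    return ref.forward(x)


x = torch.randn(4)
new_forward(x)                 # first call compiles and guards on ref.forward
ref.forward = lambda x: x + 1  # same behaviour, but a new function object
new_forward(x)                 # guard failure on ref.forward -> recompile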
I thought the default value of fullgraph is False, according to https://docs.pytorch.org/docs/stable/generated/torch.compile.html
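That matches the linked docs: fullgraph defaults to False. A quick way to confirm locally (assuming a recent PyTorch where torch.compile is introspectable):

import inspect
import torch

# Prints the default value of torch.compile's fullgraph parameter: False
print(inspect.signature(torch.compile).parameters["fullgraph"].default)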
What does this PR do?
For a compiled module, when it runs with group offloading, it will throw an error like this:
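A minimal sketch of the failing setup, assuming the error comes from compiling before the group-offloading hooks are attached (the reverse of the order recommended in the discussion above); SimpleModel here refers to the toy model defined earlier in the thread:

import torch
from diffusers.hooks import apply_group_offloading

model = SimpleModel()  # the toy model from the discussion above
model.eval()

# Compile first, then attach group offloading -- the ordering assumed to
# trigger the error this PR fixes. The recommended order is the opposite:
# apply group offloading first, then compile.
model = torch.compile(model, fullgraph=False, mode="default", dynamic=False)
apply_group_offloading(
    model,
    onload_device=torch.device("cuda:0"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)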
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@a-r-r-o-w