-
Notifications
You must be signed in to change notification settings - Fork 363
Added flux demo #3418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
cehongwang
wants to merge
32
commits into
main
Choose a base branch
from
flux-demo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,628
−169
Open
Added flux demo #3418
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
431cc4d
Added CPU offloading
cehongwang 8168f92
Chagned CPU offload to default
cehongwang 23ca669
Added support to module with graph break
cehongwang 953f339
Added back the control flag and fixed the CI
cehongwang 797c670
Chagned CPU offload to default
cehongwang 024992d
Added flux demo
cehongwang 5b4beab
changed the file place and deleted unnecessary code
cehongwang c9573a1
Fixed memory overhead and enabled Flux with Mutable Module
cehongwang 2e90e73
Supported LoRA
cehongwang ef5bca8
Refined Flux demo, solved a bug of device mismatch, and prototyped Cu…
cehongwang a34d25c
Enabled Cuda Graph
cehongwang b8fafae
Enabled weight streaming and CudaGraph. Supported MTTM saving with dy…
cehongwang b6a96d8
Changed the Refitting test to disable CPU offload
cehongwang 53d06f3
Fixed Cuda Error
cehongwang 51c3a90
Fixed the bug of SDXL Cuda Error
cehongwang 3920a63
Changed the way to enable CudaGraph for MTTM
cehongwang 0cb1dc2
Finalize the refit revision
cehongwang 6066d51
Fixed the comments
cehongwang d23853d
Correct the flux export example
cehongwang b7b433a
Added a textbox to display time the generation process takes
cehongwang 7d2e1c3
Added perf script
cehongwang b941b75
added back control flag
cehongwang 13bd604
trying to add quantization to Flux
cehongwang e6e817a
Enable int8 and fp8 quantization for FLUX
cehongwang 41f1f80
Optimized FLUX compilation memory usage
cehongwang 1346fd4
Optimized lowering and decomposition to benchmark quantization again
cehongwang 084724e
Fixed the benchmark typo
cehongwang fb373a0
Use MutableTorchTensorRTModule to do quantization
cehongwang c67ee2f
Added quantization debug script
cehongwang 9c7edb2
Fixed fp16 quantization error
cehongwang f536ac6
Added converter registration
cehongwang 27a2001
Deleted unnecessary files
cehongwang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import argparse | ||
import re | ||
import time | ||
|
||
import gradio as gr | ||
import modelopt.torch.quantization as mtq | ||
import register_sdpa | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can avoid copying the import sys
import os
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), '../dynamo'))
from register_sdpa import * |
||
import torch | ||
import torch_tensorrt | ||
from diffusers import FluxPipeline | ||
|
||
parser = argparse.ArgumentParser( | ||
description="Run Flux quantization with different dtypes" | ||
) | ||
|
||
parser.add_argument( | ||
"--dtype", | ||
choices=["fp8", "int8", "fp16"], | ||
default="fp16", | ||
help="Select the data type to use (fp8 or int8 or fp16)", | ||
) | ||
args = parser.parse_args() | ||
# Update enabled precisions based on dtype argument | ||
|
||
if args.dtype == "fp8": | ||
enabled_precisions = {torch.float8_e4m3fn, torch.float16} | ||
ptq_config = mtq.FP8_DEFAULT_CFG | ||
elif args.dtype == "int8": | ||
enabled_precisions = {torch.int8, torch.float16} | ||
ptq_config = mtq.INT8_DEFAULT_CFG | ||
ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None | ||
elif args.dtype == "fp16": | ||
enabled_precisions = {torch.float16} | ||
print(f"\nUsing {args.dtype}") | ||
|
||
|
||
DEVICE = "cuda:0" | ||
pipe = FluxPipeline.from_pretrained( | ||
"black-forest-labs/FLUX.1-dev", | ||
torch_dtype=torch.float16, | ||
) | ||
|
||
|
||
pipe.to(DEVICE).to(torch.float16) | ||
backbone = pipe.transformer | ||
backbone.eval() | ||
|
||
|
||
def filter_func(name): | ||
pattern = re.compile( | ||
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" | ||
) | ||
return pattern.match(name) is not None | ||
|
||
|
||
def do_calibrate( | ||
pipe, | ||
prompt: str, | ||
) -> None: | ||
""" | ||
Run calibration steps on the pipeline using the given prompts. | ||
""" | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=20, | ||
generator=torch.Generator("cuda").manual_seed(0), | ||
).images[0] | ||
|
||
|
||
def forward_loop(mod): | ||
# Switch the pipeline's backbone, run calibration | ||
pipe.transformer = mod | ||
do_calibrate( | ||
pipe=pipe, | ||
prompt="test", | ||
) | ||
|
||
|
||
if args.dtype != "fp16": | ||
backbone = mtq.quantize(backbone, ptq_config, forward_loop) | ||
mtq.disable_quantizer(backbone, filter_func) | ||
|
||
batch_size = 2 | ||
|
||
BATCH = torch.export.Dim("batch", min=1, max=8) | ||
dynamic_shapes = { | ||
"hidden_states": {0: BATCH}, | ||
"encoder_hidden_states": {0: BATCH}, | ||
"pooled_projections": {0: BATCH}, | ||
"timestep": {0: BATCH}, | ||
"txt_ids": {}, | ||
"img_ids": {}, | ||
"guidance": {0: BATCH}, | ||
"joint_attention_kwargs": {}, | ||
"return_dict": None, | ||
} | ||
|
||
settings = { | ||
"strict": False, | ||
"allow_complex_guards_as_runtime_asserts": True, | ||
"enabled_precisions": enabled_precisions, | ||
"truncate_double": True, | ||
"min_block_size": 1, | ||
"debug": False, | ||
"use_python_runtime": True, | ||
"immutable_weights": False, | ||
"offload_module_to_cpu": True, | ||
} | ||
|
||
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) | ||
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) | ||
cehongwang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pipe.transformer = trt_gm | ||
|
||
|
||
def generate_image(prompt, inference_step, batch_size=2): | ||
start_time = time.time() | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=inference_step, | ||
num_images_per_prompt=batch_size, | ||
).images | ||
end_time = time.time() | ||
return image, end_time - start_time | ||
|
||
|
||
generate_image(["Test"], 2) | ||
torch.cuda.empty_cache() | ||
|
||
|
||
def model_change(model): | ||
if model == "Torch Model": | ||
pipe.transformer = backbone | ||
backbone.to(DEVICE) | ||
else: | ||
backbone.to("cpu") | ||
pipe.transformer = trt_gm | ||
torch.cuda.empty_cache() | ||
|
||
|
||
def load_lora(path): | ||
|
||
pipe.load_lora_weights( | ||
path, | ||
adapter_name="lora1", | ||
) | ||
pipe.set_adapters(["lora1"], adapter_weights=[1]) | ||
pipe.fuse_lora() | ||
pipe.unload_lora_weights() | ||
print("LoRA loaded! Begin refitting") | ||
generate_image(["Test"], 2) | ||
print("Refitting Finished!") | ||
|
||
|
||
# Create Gradio interface | ||
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: | ||
gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") | ||
|
||
with gr.Row(): | ||
with gr.Column(): | ||
# Input components | ||
prompt_input = gr.Textbox( | ||
label="Prompt", placeholder="Enter your prompt here...", lines=3 | ||
) | ||
model_dropdown = gr.Dropdown( | ||
choices=["Torch Model", "Torch-TensorRT Accelerated Model"], | ||
value="Torch-TensorRT Accelerated Model", | ||
label="Model Variant", | ||
) | ||
|
||
lora_upload_path = gr.Textbox( | ||
label="LoRA Path", | ||
placeholder="Enter the LoRA checkpoint path here", | ||
value="/home/TensorRT/examples/apps/NGRVNG.safetensors", | ||
lines=2, | ||
) | ||
num_steps = gr.Slider( | ||
minimum=20, maximum=100, value=20, step=1, label="Inference Steps" | ||
) | ||
batch_size = gr.Slider( | ||
minimum=1, maximum=8, value=1, step=1, label="Batch Size" | ||
) | ||
|
||
generate_btn = gr.Button("Generate Image") | ||
load_lora_btn = gr.Button("Load LoRA") | ||
|
||
with gr.Column(): | ||
# Output component | ||
output_image = gr.Gallery(label="Generated Image") | ||
time_taken = gr.Textbox( | ||
label="Generation Time (seconds)", interactive=False | ||
) | ||
|
||
# Connect the button to the generation function | ||
model_dropdown.change(model_change, inputs=[model_dropdown]) | ||
load_lora_btn.click( | ||
fn=load_lora, | ||
inputs=[ | ||
lora_upload_path, | ||
], | ||
) | ||
|
||
# Update generate button click to include time output | ||
generate_btn.click( | ||
fn=generate_image, | ||
inputs=[ | ||
prompt_input, | ||
num_steps, | ||
batch_size, | ||
], | ||
outputs=[output_image, time_taken], | ||
) | ||
|
||
# Launch the interface | ||
if __name__ == "__main__": | ||
demo.launch() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does this file do ?