From cf16776063542948e5cec303abb8eef24d80b801 Mon Sep 17 00:00:00 2001
From: Li Xiang <54010254+lixiang007666@users.noreply.github.com>
Date: Thu, 14 Mar 2024 22:10:56 +0800
Subject: [PATCH] [Fix] black images issues with diffusers SD2.1 (#725)

This PR is done:

- [x] Related issue: https://github.com/siliconflow/onediff/issues/722#issuecomment-1994110356
---
 onediff_diffusers_extensions/examples/image_to_image.py | 8 +++++---
 .../register_diffusers/attention_processor_oflow.py     | 4 ++--
 src/onediff/optimization/attention_processor.py         | 4 ++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/onediff_diffusers_extensions/examples/image_to_image.py b/onediff_diffusers_extensions/examples/image_to_image.py
index d7a287181..2483a04de 100644
--- a/onediff_diffusers_extensions/examples/image_to_image.py
+++ b/onediff_diffusers_extensions/examples/image_to_image.py
@@ -1,11 +1,12 @@
+import argparse
 from PIL import Image
+import oneflow as flow
+import torch
 
-import argparse
 from onediff.infer_compiler import oneflow_compile
 from diffusers import StableDiffusionImg2ImgPipeline
-import oneflow as flow
-import torch
+
 
 
 prompt = "sea,beach,the waves crashed on the sand,blue sky whit white cloud"
 
@@ -28,6 +29,7 @@ def parse_args():
 pipe = pipe.to("cuda")
 
 pipe.unet = oneflow_compile(pipe.unet)
+pipe.vae.decoder = oneflow_compile(pipe.vae.decoder)
 
 img = Image.new("RGB", (512, 512), "#1f80f0")
 
diff --git a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py
index 99e8dd8d3..da3170341 100644
--- a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py
@@ -390,9 +390,9 @@ def head_to_batch_dim(self, tensor, out_dim=3):
 
     def get_attention_scores(self, query, key, attention_mask=None):
         if self.upcast_attention and parse_boolean_from_env(
-            "ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", True
+            "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", True
         ):
-            set_boolean_env_var("ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", False)
+            set_boolean_env_var("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", False)
         dtype = query.dtype
         # if self.upcast_attention:
         #     query = query.float()
diff --git a/src/onediff/optimization/attention_processor.py b/src/onediff/optimization/attention_processor.py
index b8ebd749b..22650ab62 100644
--- a/src/onediff/optimization/attention_processor.py
+++ b/src/onediff/optimization/attention_processor.py
@@ -90,10 +90,10 @@ def __call__(
         )
 
         if attn.upcast_attention and parse_boolean_from_env(
-            "ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", True
+            "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", True
        ):
             set_boolean_env_var(
-                "ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL", False
+                "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", False
             )
         hidden_states = flow._C.fused_multi_head_attention_inference_v2(
             query=qkv,
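
For quick verification, below is a minimal sketch of the patched example script, not part of the patch itself. It assumes an SD2.1 fp16 checkpoint such as "stabilityai/stable-diffusion-2-1" (the shipped example reads the model path via argparse, which is elided in the hunks above); after compiling both the UNet and the VAE decoder, the img2img pipeline should no longer return all-black images.

import oneflow as flow  # noqa: F401  # imported as in the example script
import torch
from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline
from onediff.infer_compiler import oneflow_compile

# Assumed checkpoint and dtype; the example takes the model path from the command line
# and its dtype handling is not shown in the hunks above.
MODEL_ID = "stabilityai/stable-diffusion-2-1"

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# As in the patch: compile the VAE decoder in addition to the UNet.
pipe.unet = oneflow_compile(pipe.unet)
pipe.vae.decoder = oneflow_compile(pipe.vae.decoder)

prompt = "sea,beach,the waves crashed on the sand,blue sky whit white cloud"
img = Image.new("RGB", (512, 512), "#1f80f0")

out = pipe(prompt=prompt, image=img).images[0]
out.save("img2img_sd21.png")  # expected: a rendered scene, not a black image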