From e9f076c29b40225fbad32ad51b0b08fee5c0e8b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com>
Date: Thu, 6 Jun 2024 17:08:04 +0100
Subject: [PATCH] allow using a local SDXL model file

---
 gradio_app.py | 53 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 14 deletions(-)

diff --git a/gradio_app.py b/gradio_app.py
index d0eecd8..bf67511 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -23,7 +23,7 @@ from PIL import Image
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-from diffusers import AutoencoderKL, UNet2DConditionModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline
 from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from lib_omost.pipeline import StableDiffusionXLOmostPipeline
@@ -35,21 +35,46 @@
 
 # SDXL
 
-sdxl_name = 'SG161222/RealVisXL_V4.0'
+# sdxl_name = 'SG161222/RealVisXL_V4.0'
+
+use_local_model = True
+sdxl_name = "base/sd_xl_base_1.0"
+
+if use_local_model:
+    try:
+        base_model_dir = os.environ["SDXL_MODELS_DIR"]
+    except KeyError:
+        raise SystemExit("Please set the SDXL_MODELS_DIR environment variable, e.g. /path/to/ComfyUI/models/checkpoints/")
+    model_file = os.path.join(base_model_dir, f"{sdxl_name}.safetensors")
+    print("using local model file:", model_file)
+
+    pipe = StableDiffusionXLPipeline.from_single_file(
+        model_file,
+        torch_dtype=torch.float16,
+        variant="fp16"
+    )
+    tokenizer = pipe.tokenizer
+    tokenizer_2 = pipe.tokenizer_2
+    text_encoder = pipe.text_encoder
+    text_encoder_2 = pipe.text_encoder_2
+    vae = pipe.vae
+    unet = pipe.unet
+
+else:  # HF diffusers format
 # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
 
-tokenizer = CLIPTokenizer.from_pretrained(
-    sdxl_name, subfolder="tokenizer")
-tokenizer_2 = CLIPTokenizer.from_pretrained(
-    sdxl_name, subfolder="tokenizer_2")
-text_encoder = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16")
-text_encoder_2 = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16")
-vae = AutoencoderKL.from_pretrained(
-    sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16")  # bfloat16 vae
-unet = UNet2DConditionModel.from_pretrained(
-    sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16")
+    tokenizer = CLIPTokenizer.from_pretrained(
+        sdxl_name, subfolder="tokenizer")
+    tokenizer_2 = CLIPTokenizer.from_pretrained(
+        sdxl_name, subfolder="tokenizer_2")
+    text_encoder = CLIPTextModel.from_pretrained(
+        sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16")
+    text_encoder_2 = CLIPTextModel.from_pretrained(
+        sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16")
+    vae = AutoencoderKL.from_pretrained(
+        sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16")  # bfloat16 vae
+    unet = UNet2DConditionModel.from_pretrained(
+        sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
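
Note: below is a minimal stand-alone sketch of the loading path this patch adds, for trying it outside gradio_app.py. It assumes a diffusers version with single-file SDXL support; the SDXL_MODELS_DIR variable and the "base/sd_xl_base_1.0" checkpoint name are conventions of this patch (a ComfyUI-style checkpoints folder), not library defaults.

    import os
    import torch
    from diffusers import StableDiffusionXLPipeline

    # Assumed layout: $SDXL_MODELS_DIR/base/sd_xl_base_1.0.safetensors
    base_model_dir = os.environ["SDXL_MODELS_DIR"]
    model_file = os.path.join(base_model_dir, "base", "sd_xl_base_1.0.safetensors")

    # from_single_file() reads one monolithic .safetensors checkpoint and splits
    # it into the usual SDXL components (tokenizers, text encoders, VAE, UNet),
    # which the patch then hands to StableDiffusionXLOmostPipeline.
    pipe = StableDiffusionXLPipeline.from_single_file(model_file, torch_dtype=torch.float16)
    print(type(pipe.unet).__name__)  # -> UNet2DConditionModel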