Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option to use local SDXL model file #80

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from lib_omost.pipeline import StableDiffusionXLOmostPipeline
Expand All @@ -35,21 +35,46 @@

# SDXL

# SDXL checkpoint selection and loading.
#
# Two loading paths are supported:
#   * use_local_model=True  -> load a single local .safetensors checkpoint
#     (e.g. a ComfyUI-style checkpoint file) from the directory named by the
#     SDXL_MODELS_DIR environment variable, via from_single_file().
#   * use_local_model=False -> load an HF-diffusers-format repo/directory,
#     component by component.
#
# Either way, the same six components are exposed to the rest of the module:
# tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet.
#
# sdxl_name = 'SG161222/RealVisXL_V4.0'

use_local_model = True
sdxl_name = "base/sd_xl_base_1.0"

if use_local_model:
    try:
        base_model_dir = os.environ["SDXL_MODELS_DIR"]
    except KeyError:
        # Fail fast: the original code only printed here and then crashed
        # with a NameError on base_model_dir a line later.
        raise SystemExit(
            "Please set the SDXL_MODELS_DIR environment variable, "
            "e.g. /path/to/ComfyUI/models/checkpoints/"
        )

    # os.path.join is correct whether or not SDXL_MODELS_DIR ends with a
    # path separator (plain f-string concatenation required a trailing '/').
    model_file = os.path.join(base_model_dir, f"{sdxl_name}.safetensors")
    print("using local model file: ", model_file)

    # Build a throwaway pipeline just to materialize the individual
    # components from the monolithic checkpoint file.
    pipe = StableDiffusionXLPipeline.from_single_file(
        model_file,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    tokenizer = pipe.tokenizer
    tokenizer_2 = pipe.tokenizer_2
    text_encoder = pipe.text_encoder
    text_encoder_2 = pipe.text_encoder_2
    vae = pipe.vae
    unet = pipe.unet

else:  # HF diffusers format
    # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'

    tokenizer = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(
        sdxl_name, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16")
    text_encoder_2 = CLIPTextModel.from_pretrained(
        sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16")
    vae = AutoencoderKL.from_pretrained(
        sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16")  # bfloat16 vae
    unet = UNet2DConditionModel.from_pretrained(
        sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16")

# Switch the UNet and VAE attention layers to AttnProcessor2_0
# (diffusers' PyTorch-2 scaled-dot-product-attention processor, imported
# above) instead of whatever processor the checkpoints shipped with.
unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())
Expand Down