From 4e286b4258598a2d3b2df95313beb0c3c87efff4 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sat, 1 Jun 2024 17:47:27 -0500 Subject: [PATCH 01/11] Make it more habitable --- gradio_app.py | 219 +++++++++++++++++++++++++------- models/checkpoints/.placeholder | 0 models/llm/.placeholder | 0 3 files changed, 172 insertions(+), 47 deletions(-) create mode 100644 models/checkpoints/.placeholder create mode 100644 models/llm/.placeholder diff --git a/gradio_app.py b/gradio_app.py index de27039..3aaa9a0 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,8 +1,5 @@ import os -os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') -HF_TOKEN = None - import lib_omost.memory_management as memory_management import uuid @@ -11,19 +8,14 @@ import gradio as gr import tempfile -gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') -os.makedirs(gradio_temp_dir, exist_ok=True) - from threading import Thread # Phi3 Hijack from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel -Phi3PreTrainedModel._supports_sdpa = True - from PIL import Image from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer -from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline from diffusers.models.attention_processor import AttnProcessor2_0 from transformers import CLIPTextModel, CLIPTokenizer from lib_omost.pipeline import StableDiffusionXLOmostPipeline @@ -32,39 +24,110 @@ import lib_omost.canvas as omost_canvas +os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') +HF_TOKEN = None + +gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') +os.makedirs(gradio_temp_dir, exist_ok=True) + +Phi3PreTrainedModel._supports_sdpa = True # SDXL -sdxl_name = 'SG161222/RealVisXL_V4.0' -# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0' - -tokenizer = CLIPTokenizer.from_pretrained( - sdxl_name, subfolder="tokenizer") -tokenizer_2 = CLIPTokenizer.from_pretrained( - sdxl_name, subfolder="tokenizer_2") -text_encoder = CLIPTextModel.from_pretrained( - sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16") -text_encoder_2 = CLIPTextModel.from_pretrained( - sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16") -vae = AutoencoderKL.from_pretrained( - sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16") # bfloat16 vae -unet = UNet2DConditionModel.from_pretrained( - sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16") - -unet.set_attn_processor(AttnProcessor2_0()) -vae.set_attn_processor(AttnProcessor2_0()) - -pipeline = StableDiffusionXLOmostPipeline( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet, - scheduler=None, # We completely give up diffusers sampling system and use A1111's method -) +sdxl_name = 'RunDiffusion/Juggernaut-X-v10' +sdxl_names = { + 'RunDiffusion/Juggernaut-X-v10': 'Juggernaut-X-v10', + 'SG161222/RealVisXL_V4.0': 'RealVisXL_V4.0', + 'stabilityai/stable-diffusion-xl-base-1.0': 'stable-diffusion-xl-base-1.0' +} + +tokenizer = None +tokenizer_2 = None +text_encoder = None +text_encoder_2 = None +vae = None +unet = None +pipeline = None + + +def list_models(folder_path, file_extension=None): + models = [] + if file_extension == ".safetensors": + models = [(key, name) for key, name in sdxl_names.items()] + + if os.path.exists(folder_path): + for root, dirs, 
files in os.walk(folder_path): + # If we want to list only files with a specific extension + if file_extension: + for file in files: + if file.endswith(file_extension): + full_path = os.path.join(root, file) + models.append((file, full_path)) + else: + # Append dirs + for model_dir in dirs: + full_path = os.path.join(root, model_dir) + models.append((model_dir, full_path)) + return models + + +def load_pipeline(model_path): + global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline + + if model_path.endswith('.safetensors'): + temp_pipeline = StableDiffusionXLPipeline.from_single_file(model_path) + + tokenizer = temp_pipeline.tokenizer + tokenizer_2 = temp_pipeline.tokenizer_2 + text_encoder = temp_pipeline.text_encoder + text_encoder_2 = temp_pipeline.text_encoder_2 + vae = temp_pipeline.vae + unet = temp_pipeline.unet + else: + tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer") + tokenizer_2 = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer_2") + text_encoder = CLIPTextModel.from_pretrained( + sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16") + text_encoder_2 = CLIPTextModel.from_pretrained( + sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16") + vae = AutoencoderKL.from_pretrained( + sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16") + unet = UNet2DConditionModel.from_pretrained( + sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16") + + unet.set_attn_processor(AttnProcessor2_0()) + vae.set_attn_processor(AttnProcessor2_0()) + + pipeline = StableDiffusionXLOmostPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=None, # We completely give up diffusers sampling system and use A1111's method + ) + + memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) + + +def load_llm_model(model_name): + global llm_model, llm_tokenizer + + model_path = os.path.join('./models/llm', model_name) + llm_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + token=HF_TOKEN, + device_map="auto" + ) + llm_tokenizer = AutoTokenizer.from_pretrained( + model_path, + token=HF_TOKEN + ) + + memory_management.unload_all_models(llm_model) -memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) # LLM @@ -74,7 +137,8 @@ llm_model = AutoModelForCausalLM.from_pretrained( llm_name, - torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config. + torch_dtype=torch.bfloat16, + # This is computation type, not load/memory type. The loading quant type is baked in config. 
token=HF_TOKEN, device_map="auto" # This will load model to gpu with an offload system ) @@ -112,7 +176,7 @@ def resize_without_crop(image, target_width, target_height): @torch.inference_mode() -def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: float, max_new_tokens: int) -> str: +def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -175,7 +239,8 @@ def post_chat(history): try: if history: - history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)] + history = [(user, assistant) for user, assistant in history if + isinstance(user, str) and isinstance(assistant, str)] last_assistant = history[-1][1] if len(history) > 0 else None canvas = omost_canvas.Canvas.from_bot_response(last_assistant) canvas_outputs = canvas.process() @@ -188,9 +253,16 @@ def post_chat(history): @torch.inference_mode() def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height, highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt): - use_initial_latent = False eps = 0.05 + if not isinstance(pipeline, StableDiffusionXLOmostPipeline): + raise ValueError("Pipeline is not StableDiffusionXLOmostPipeline") + + if not isinstance(vae, AutoencoderKL): + raise ValueError("VAE is not AutoencoderKL") + + if not isinstance(unet, UNet2DConditionModel): + raise ValueError("UNet is not UNet2DConditionModel") image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64 @@ -198,7 +270,8 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ memory_management.load_models_to_gpu([text_encoder, text_encoder_2]) - positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt) + positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, + negative_prompt) if use_initial_latent: memory_management.load_models_to_gpu([vae]) @@ -209,7 +282,14 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1)) initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width)) initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device) - initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor + try: + if isinstance(vae.config, dict): + initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config['scaling_factor'] + else: + initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor + except Exception as e: + print('Failed to encode initial latent:', e) + initial_latent = torch.zeros(size=(1, 4, image_height // 8, image_width // 8), dtype=torch.float32) else: initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32) @@ -279,6 +359,20 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ return chatbot +def update_model_list(): + model_list = list_models('./models/checkpoints', '.safetensors') + options = [model for model, path in model_list] + paths = {model: path for model, path in model_list} + return gr.update(choices=options, value=options[0] if options else ""), paths + + +def update_llm_list(): + llm_list = list_models('./models/llm') + options 
= [os.path.basename(path) for name, path in llm_list] + paths = {os.path.basename(path): path for name, path in llm_list} + return gr.update(choices=options, value=options[0] if options else ""), paths + + css = ''' code {white-space: pre-wrap !important;} .gradio-container {max-width: none !important;} @@ -299,7 +393,8 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ with gr.Row(): clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60) retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False) - undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False) + undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, + interactive=False) seed = gr.Number(label="Random Seed", value=12345, precision=0) @@ -324,6 +419,8 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ step=1, value=4096, label="Max New Tokens") + llm_select = gr.Dropdown(label="Select LLM Model", interactive=True) + llm_refresh_btn = gr.Button("Refresh LLM List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=True, label='Image Diffusion Model'): with gr.Group(): with gr.Row(): @@ -333,13 +430,17 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ with gr.Row(): num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1) steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1) + model_select = gr.Dropdown(label="Select Model", interactive=True) + model_refresh_btn = gr.Button("Refresh Model List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=False, label='Advanced'): cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01) - highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01) + highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, + step=0.01) highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1) highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01) - n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality') + n_prompt = gr.Textbox(label="Negative Prompt", + value='lowres, bad anatomy, bad hands, cropped, worst quality') render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False) @@ -378,5 +479,29 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ chatInterface.chatbot ], outputs=[chatInterface.chatbot_state]) + model_refresh_btn.click( + fn=update_model_list, + inputs=[], + outputs=[model_select] + ) + + llm_refresh_btn.click( + fn=update_llm_list, + inputs=[], + outputs=[llm_select] + ) + + model_select.change( + fn=load_pipeline, + inputs=[model_select], + outputs=[] + ) + + llm_select.change( + fn=load_llm_model, + inputs=[llm_select], + outputs=[] + ) + if __name__ == "__main__": demo.queue().launch(inbrowser=True, server_name='0.0.0.0') diff --git a/models/checkpoints/.placeholder b/models/checkpoints/.placeholder new file mode 100644 index 0000000..e69de29 diff --git a/models/llm/.placeholder b/models/llm/.placeholder new file mode 100644 index 0000000..e69de29 From 404b798fbd8b9ad22dcd54da976c1fe75b06530e Mon Sep 17 00:00:00 2001 From: 
d8ahazard Date: Sat, 1 Jun 2024 18:54:39 -0500 Subject: [PATCH 02/11] Model Phun --- .gitignore | 4 ++++ gradio_app.py | 66 +++++++++++++++++++++++++-------------------------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/.gitignore b/.gitignore index 594e5e6..8f8054d 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ + +# Ignore everything in ./models recursively if it's not a .placeholder file +/models/** +!/models/**/.placeholder diff --git a/gradio_app.py b/gradio_app.py index 3aaa9a0..d776a7d 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -68,10 +68,13 @@ def list_models(folder_path, file_extension=None): for model_dir in dirs: full_path = os.path.join(root, model_dir) models.append((model_dir, full_path)) + else: + print(f"Folder does not exist: {folder_path}") return models def load_pipeline(model_path): + print(f"Loading model from {model_path}") global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline if model_path.endswith('.safetensors'): @@ -86,14 +89,10 @@ def load_pipeline(model_path): else: tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer") tokenizer_2 = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer_2") - text_encoder = CLIPTextModel.from_pretrained( - sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16") - text_encoder_2 = CLIPTextModel.from_pretrained( - sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16") - vae = AutoencoderKL.from_pretrained( - sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16") - unet = UNet2DConditionModel.from_pretrained( - sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16") + text_encoder = CLIPTextModel.from_pretrained(sdxl_name, subfolder="text_encoder") + text_encoder_2 = CLIPTextModel.from_pretrained(sdxl_name, subfolder="text_encoder_2") + vae = AutoencoderKL.from_pretrained(sdxl_name, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(sdxl_name, subfolder="unet") unet.set_attn_processor(AttnProcessor2_0()) vae.set_attn_processor(AttnProcessor2_0()) @@ -112,9 +111,15 @@ def load_pipeline(model_path): def load_llm_model(model_name): - global llm_model, llm_tokenizer + global llm_model, llm_tokenizer, llm_model_name + if llm_model_name == model_name and llm_model: + return - model_path = os.path.join('./models/llm', model_name) + test_model_path = os.path.join('./models/llm', model_name) + if os.path.exists(test_model_path): + model_path = test_model_path + else: + model_path = model_name llm_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, @@ -125,6 +130,7 @@ def load_llm_model(model_name): model_path, token=HF_TOKEN ) + llm_model_name = model_name memory_management.unload_all_models(llm_model) @@ -135,20 +141,9 @@ def load_llm_model(model_name): llm_name = 'lllyasviel/omost-llama-3-8b-4bits' # llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits' -llm_model = AutoModelForCausalLM.from_pretrained( - llm_name, - torch_dtype=torch.bfloat16, - # This is computation type, not load/memory type. The loading quant type is baked in config. 
- token=HF_TOKEN, - device_map="auto" # This will load model to gpu with an offload system -) - -llm_tokenizer = AutoTokenizer.from_pretrained( - llm_name, - token=HF_TOKEN -) - -memory_management.unload_all_models(llm_model) +llm_model = None +llm_model_name = None +llm_tokenizer = None @torch.inference_mode() @@ -177,6 +172,7 @@ def resize_without_crop(image, target_width, target_height): @torch.inference_mode() def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: + global llm_model, llm_tokenizer np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -188,7 +184,9 @@ def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: f conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) conversation.append({"role": "user", "content": message}) - + # Load the model if it is not loaded + if not llm_model: + load_llm_model(llm_name) memory_management.load_models_to_gpu(llm_model) input_ids = llm_tokenizer.apply_chat_template( @@ -227,7 +225,6 @@ def interrupter(): outputs = [] for text in streamer: outputs.append(text) - # print(outputs) yield "".join(outputs), interrupter return @@ -361,16 +358,16 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ def update_model_list(): model_list = list_models('./models/checkpoints', '.safetensors') - options = [model for model, path in model_list] + options = [(model, path) for model, path in model_list] paths = {model: path for model, path in model_list} - return gr.update(choices=options, value=options[0] if options else ""), paths + return gr.Dropdown.update(choices=options, value=options[0][0] if options else ""), paths def update_llm_list(): llm_list = list_models('./models/llm') - options = [os.path.basename(path) for name, path in llm_list] + options = [(os.path.basename(path), path) for name, path in llm_list] paths = {os.path.basename(path): path for name, path in llm_list} - return gr.update(choices=options, value=options[0] if options else ""), paths + return gr.Dropdown.update(choices=options, value=options[0][0] if options else ""), paths css = ''' @@ -419,7 +416,7 @@ def update_llm_list(): step=1, value=4096, label="Max New Tokens") - llm_select = gr.Dropdown(label="Select LLM Model", interactive=True) + llm_select = gr.Dropdown(label="Select LLM Model", value=llm_name, choices=[(llm_name, llm_name)], interactive=True) llm_refresh_btn = gr.Button("Refresh LLM List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=True, label='Image Diffusion Model'): with gr.Group(): @@ -430,7 +427,8 @@ def update_llm_list(): with gr.Row(): num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1) steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1) - model_select = gr.Dropdown(label="Select Model", interactive=True) + model_select = gr.Dropdown(label="Select Model", choices=list_models('./models/checkpoints', '.safetensors'), + interactive=True, value=list_models('./models/checkpoints', '.safetensors')[0][0]) model_refresh_btn = gr.Button("Refresh Model List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=False, label='Advanced'): @@ -446,8 +444,8 @@ def update_llm_list(): examples = gr.Dataset( samples=[ - ['generate an image of the fierce battle of warriors and a dragon'], - ['change the dragon to a dinosaur'] + ['Generate an image of several squirrels in business suits having a meeting in a 
park'], + ['Add a dog in the background'] ], components=[gr.Textbox(visible=False)], label='Quick Prompts' From bdb8245d49b540237fb9562509d6591096e16ef8 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sat, 1 Jun 2024 19:05:22 -0500 Subject: [PATCH 03/11] Fixes --- gradio_app.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index d776a7d..4b8cb80 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -15,7 +15,7 @@ from PIL import Image from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer -from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline +from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline from diffusers.models.attention_processor import AttnProcessor2_0 from transformers import CLIPTextModel, CLIPTokenizer from lib_omost.pipeline import StableDiffusionXLOmostPipeline @@ -61,6 +61,7 @@ def list_models(folder_path, file_extension=None): if file_extension: for file in files: if file.endswith(file_extension): + print(f"Found model: {file}") full_path = os.path.join(root, file) models.append((file, full_path)) else: @@ -68,6 +69,7 @@ def list_models(folder_path, file_extension=None): for model_dir in dirs: full_path = os.path.join(root, model_dir) models.append((model_dir, full_path)) + print(f"Found model: {model_dir}") else: print(f"Folder does not exist: {folder_path}") return models @@ -78,7 +80,7 @@ def load_pipeline(model_path): global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline if model_path.endswith('.safetensors'): - temp_pipeline = StableDiffusionXLPipeline.from_single_file(model_path) + temp_pipeline = StableDiffusionXLImg2ImgPipeline.from_single_file(model_path) tokenizer = temp_pipeline.tokenizer tokenizer_2 = temp_pipeline.tokenizer_2 @@ -249,9 +251,11 @@ def post_chat(history): @torch.inference_mode() def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height, - highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt): + highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt, model_selection): use_initial_latent = False eps = 0.05 + # Load the model + load_pipeline(model_selection) if not isinstance(pipeline, StableDiffusionXLOmostPipeline): raise ValueError("Pipeline is not StableDiffusionXLOmostPipeline") @@ -471,7 +475,7 @@ def update_llm_list(): fn=diffusion_fn, inputs=[ chatInterface.chatbot, canvas_state, num_samples, seed, image_width, image_height, highres_scale, - steps, cfg, highres_steps, highres_denoise, n_prompt + steps, cfg, highres_steps, highres_denoise, n_prompt, model_select ], outputs=[chatInterface.chatbot]).then( fn=lambda x: x, inputs=[ chatInterface.chatbot From 5fe96e9452c314a05a8db35c8b6a81b1b5a8a9f5 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 12:51:31 -0500 Subject: [PATCH 04/11] Add args, fix dropdowns/model selection. 
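Folder locations and default model names are now taken from argparse flags, and the checkpoint/LLM dropdowns list the built-in Hub repos alongside anything found locally. As a condensed sketch of the checkpoint-listing logic introduced in the diff below (the helper name list_checkpoints and the trimmed two-entry default dict are illustrative only; the actual function is list_models(llm) and reads the --checkpoints_folder flag):

    import os

    DEFAULT_CHECKPOINTS = {
        'RunDiffusion/Juggernaut-X-v10': 'Juggernaut-X-v10',
        'SG161222/RealVisXL_V4.0': 'RealVisXL_V4.0',
    }

    def list_checkpoints(folder_path):
        # Hub defaults first: the short name is the label, the repo id is the value.
        models = [(name, repo_id) for repo_id, name in DEFAULT_CHECKPOINTS.items()]
        if os.path.exists(folder_path):
            for root, _dirs, files in os.walk(folder_path):
                for file in files:
                    if file.endswith('.safetensors'):
                        # Local single-file checkpoints: label is the filename,
                        # value is the full path later fed to from_single_file().
                        models.append((file, os.path.join(root, file)))
        return models

    # The (label, value) pairs plug straight into the Gradio dropdown, e.g.
    # gr.Dropdown(choices=list_checkpoints('./models/checkpoints'), ...)
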
--- gradio_app.py | 148 +++++++++++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 62 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 4b8cb80..41e0ee3 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,28 +1,25 @@ +import argparse import os - -import lib_omost.memory_management as memory_management -import uuid - -import torch -import numpy as np -import gradio as gr import tempfile - +import uuid from threading import Thread -# Phi3 Hijack -from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel - +import gradio as gr +import numpy as np +import torch from PIL import Image -from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer -from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline +from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionXLImg2ImgPipeline from diffusers.models.attention_processor import AttnProcessor2_0 +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from transformers import CLIPTextModel, CLIPTokenizer -from lib_omost.pipeline import StableDiffusionXLOmostPipeline -from chat_interface import ChatInterface from transformers.generation.stopping_criteria import StoppingCriteriaList +# Phi3 Hijack +from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel import lib_omost.canvas as omost_canvas +import lib_omost.memory_management as memory_management +from chat_interface import ChatInterface +from lib_omost.pipeline import StableDiffusionXLOmostPipeline os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') HF_TOKEN = None @@ -32,15 +29,30 @@ Phi3PreTrainedModel._supports_sdpa = True -# SDXL +parser = argparse.ArgumentParser() +parser.add_argument("--hf_token", type=str, default=None) +parser.add_argument("--sdxl_name", type=str, default='RunDiffusion/Juggernaut-X-v10') +parser.add_argument("--llm_name", type=str, default='lllyasviel/omost-llama-3-8b-4bits') +parser.add_argument("--checkpoints_folder", type=str, + default=os.path.join(os.path.dirname(__file__), "models", "checkpoints")) +parser.add_argument("--llm_folder", type=str, default=os.path.join(os.path.dirname(__file__), "models", "llm")) +args = parser.parse_args() -sdxl_name = 'RunDiffusion/Juggernaut-X-v10' -sdxl_names = { +DEFAULT_CHECKPOINTS = { 'RunDiffusion/Juggernaut-X-v10': 'Juggernaut-X-v10', 'SG161222/RealVisXL_V4.0': 'RealVisXL_V4.0', 'stabilityai/stable-diffusion-xl-base-1.0': 'stable-diffusion-xl-base-1.0' } +DEFAULT_LLMS = { + 'lllyasviel/omost-llama-3-8b-4bits': 'omost-llama-3-8b-4bits', + 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits': 'omost-dolphin-2.9-llama3-8b-4bits', + 'lllyasviel/omost-phi-3-mini-128k-8bits': 'omost-phi-3-mini-128k-8bits', + 'lllyasviel/omost-llama-3-8b': 'omost-llama-3-8b', + 'lllyasviel/omost-dolphin-2.9-llama3-8b': 'omost-dolphin-2.9-llama3-8b', + 'lllyasviel/omost-phi-3-mini-128k': 'omost-phi-3-mini-128k' +} + tokenizer = None tokenizer_2 = None text_encoder = None @@ -48,20 +60,28 @@ vae = None unet = None pipeline = None +loaded_pipeline = None +llm_model = None +llm_model_name = None +llm_tokenizer = None -def list_models(folder_path, file_extension=None): - models = [] - if file_extension == ".safetensors": - models = [(key, name) for key, name in sdxl_names.items()] +def list_models(llm: bool = False): + if not llm: + folder_path = args.checkpoints_folder + default_model = args.sdxl_name + models = [(name, key) for key, name in 
DEFAULT_CHECKPOINTS.items()] + else: + folder_path = args.llm_folder + default_model = args.llm_name + models = [(name, key) for key, name in DEFAULT_LLMS.items()] if os.path.exists(folder_path): for root, dirs, files in os.walk(folder_path): # If we want to list only files with a specific extension - if file_extension: + if not llm: for file in files: - if file.endswith(file_extension): - print(f"Found model: {file}") + if file.endswith(".safetensors"): full_path = os.path.join(root, file) models.append((file, full_path)) else: @@ -69,15 +89,21 @@ def list_models(folder_path, file_extension=None): for model_dir in dirs: full_path = os.path.join(root, model_dir) models.append((model_dir, full_path)) - print(f"Found model: {model_dir}") + if llm: + if llm_model and llm_model_name and llm_model_name in [name for key, name in models]: + default_model = llm_model_name else: - print(f"Folder does not exist: {folder_path}") - return models + if pipeline and loaded_pipeline and loaded_pipeline in [name for key, name in models]: + default_model = loaded_pipeline + + return models, default_model def load_pipeline(model_path): + global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline, loaded_pipeline + if pipeline is not None and loaded_pipeline == model_path: + return print(f"Loading model from {model_path}") - global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline if model_path.endswith('.safetensors'): temp_pipeline = StableDiffusionXLImg2ImgPipeline.from_single_file(model_path) @@ -86,15 +112,18 @@ def load_pipeline(model_path): tokenizer_2 = temp_pipeline.tokenizer_2 text_encoder = temp_pipeline.text_encoder text_encoder_2 = temp_pipeline.text_encoder_2 + # Convert text_encoder_2 to ClipTextModel + text_encoder_2 = CLIPTextModel(config=text_encoder_2.config) vae = temp_pipeline.vae unet = temp_pipeline.unet else: - tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer") - tokenizer_2 = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer_2") - text_encoder = CLIPTextModel.from_pretrained(sdxl_name, subfolder="text_encoder") - text_encoder_2 = CLIPTextModel.from_pretrained(sdxl_name, subfolder="text_encoder_2") - vae = AutoencoderKL.from_pretrained(sdxl_name, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(sdxl_name, subfolder="unet") + tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=torch.float16) + text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch.float16) + text_encoder_2 = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder_2", + torch_dtype=torch.float16) + vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.float16) + unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", torch_dtype=torch.float16) unet.set_attn_processor(AttnProcessor2_0()) vae.set_attn_processor(AttnProcessor2_0()) @@ -108,6 +137,7 @@ def load_pipeline(model_path): unet=unet, scheduler=None, # We completely give up diffusers sampling system and use A1111's method ) + loaded_pipeline = model_path memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) @@ -116,20 +146,16 @@ def load_llm_model(model_name): global llm_model, llm_tokenizer, llm_model_name if llm_model_name == model_name and llm_model: return + print(f"Loading LLM model from 
{model_name}") - test_model_path = os.path.join('./models/llm', model_name) - if os.path.exists(test_model_path): - model_path = test_model_path - else: - model_path = model_name llm_model = AutoModelForCausalLM.from_pretrained( - model_path, + model_name, torch_dtype=torch.bfloat16, token=HF_TOKEN, device_map="auto" ) llm_tokenizer = AutoTokenizer.from_pretrained( - model_path, + model_name, token=HF_TOKEN ) llm_model_name = model_name @@ -137,17 +163,6 @@ def load_llm_model(model_name): memory_management.unload_all_models(llm_model) -# LLM - -# llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits' -llm_name = 'lllyasviel/omost-llama-3-8b-4bits' -# llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits' - -llm_model = None -llm_model_name = None -llm_tokenizer = None - - @torch.inference_mode() def pytorch2numpy(imgs): results = [] @@ -174,7 +189,7 @@ def resize_without_crop(image, target_width, target_height): @torch.inference_mode() def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: - global llm_model, llm_tokenizer + global llm_model, llm_tokenizer, llm_model_name np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -188,7 +203,7 @@ def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: f conversation.append({"role": "user", "content": message}) # Load the model if it is not loaded if not llm_model: - load_llm_model(llm_name) + load_llm_model(llm_model_name) memory_management.load_models_to_gpu(llm_model) input_ids = llm_tokenizer.apply_chat_template( @@ -361,17 +376,21 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ def update_model_list(): - model_list = list_models('./models/checkpoints', '.safetensors') + model_list, default_model = list_models(False) + if loaded_pipeline and loaded_pipeline in [path for name, path in model_list]: + default_model = loaded_pipeline options = [(model, path) for model, path in model_list] paths = {model: path for model, path in model_list} - return gr.Dropdown.update(choices=options, value=options[0][0] if options else ""), paths + return gr.update(choices=options, value=default_model if options else ""), paths def update_llm_list(): - llm_list = list_models('./models/llm') + llm_list, default_llm = list_models(True) + if llm_model and llm_model_name and llm_model_name in [name for name, path in llm_list]: + default_llm = llm_model_name options = [(os.path.basename(path), path) for name, path in llm_list] paths = {os.path.basename(path): path for name, path in llm_list} - return gr.Dropdown.update(choices=options, value=options[0][0] if options else ""), paths + return gr.update(choices=options, value=default_llm if options else ""), paths css = ''' @@ -420,7 +439,10 @@ def update_llm_list(): step=1, value=4096, label="Max New Tokens") - llm_select = gr.Dropdown(label="Select LLM Model", value=llm_name, choices=[(llm_name, llm_name)], interactive=True) + llm_models, selected = list_models(True) + llm_model_name = selected + llm_select = gr.Dropdown(label="Select LLM Model", value=selected, choices=llm_models, + interactive=True) llm_refresh_btn = gr.Button("Refresh LLM List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=True, label='Image Diffusion Model'): with gr.Group(): @@ -431,8 +453,10 @@ def update_llm_list(): with gr.Row(): num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1) steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1) - 
model_select = gr.Dropdown(label="Select Model", choices=list_models('./models/checkpoints', '.safetensors'), - interactive=True, value=list_models('./models/checkpoints', '.safetensors')[0][0]) + checkpoint_list, selected = list_models(False) + + model_select = gr.Dropdown(label="Select Model", choices=checkpoint_list, interactive=True, + value=selected) model_refresh_btn = gr.Button("Refresh Model List", variant="secondary", size="sm", min_width=60) with gr.Accordion(open=False, label='Advanced'): From b330b1a38bca331579ff49631c3cf04858ebb507 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 12:57:04 -0500 Subject: [PATCH 05/11] Use -1 for random seed, set outputs folder --- gradio_app.py | 12 +++++++++--- outputs/.placeholder | 0 2 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 outputs/.placeholder diff --git a/gradio_app.py b/gradio_app.py index 41e0ee3..74a9dd2 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -36,6 +36,7 @@ parser.add_argument("--checkpoints_folder", type=str, default=os.path.join(os.path.dirname(__file__), "models", "checkpoints")) parser.add_argument("--llm_folder", type=str, default=os.path.join(os.path.dirname(__file__), "models", "llm")) +parser.add_argument("--outputs_folder", type=str, default=os.path.join(os.path.dirname(__file__), "outputs")) args = parser.parse_args() DEFAULT_CHECKPOINTS = { @@ -65,6 +66,8 @@ llm_model_name = None llm_tokenizer = None +os.makedirs(args.outputs_folder, exist_ok=True) + def list_models(llm: bool = False): if not llm: @@ -190,6 +193,8 @@ def resize_without_crop(image, target_width, target_height): @torch.inference_mode() def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: global llm_model, llm_tokenizer, llm_model_name + if seed == -1: + seed = np.random.randint(0, 2 ** 32 - 1) np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -281,7 +286,8 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ raise ValueError("UNet is not UNet2DConditionModel") image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64 - + if seed == -1: + seed = np.random.randint(0, 2 ** 32 - 1) rng = torch.Generator(device=memory_management.gpu).manual_seed(seed) memory_management.load_models_to_gpu([text_encoder, text_encoder_2]) @@ -367,7 +373,7 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ for i in range(len(pixels)): unique_hex = uuid.uuid4().hex - image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png") + image_path = os.path.join(args.outputs_folder, f"{unique_hex}_{i}.png") image = Image.fromarray(pixels[i]) image.save(image_path) chatbot = chatbot + [(None, (image_path, 'image'))] @@ -416,7 +422,7 @@ def update_llm_list(): undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False) - seed = gr.Number(label="Random Seed", value=12345, precision=0) + seed = gr.Number(label="Random Seed", value=-1, precision=0) with gr.Accordion(open=True, label='Language Model'): with gr.Group(): diff --git a/outputs/.placeholder b/outputs/.placeholder new file mode 100644 index 0000000..e69de29 From 9db036810d173edcdcc5f58d2906c7cfd85a336d Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 12:57:16 -0500 Subject: [PATCH 06/11] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8f8054d..1cd710a 100644 --- a/.gitignore +++ b/.gitignore @@ 
-166,3 +166,4 @@ cython_debug/ # Ignore everything in ./models recursively if it's not a .placeholder file /models/** !/models/**/.placeholder +outputs/* From 456b9f880e063fbacbbe7c193d743a2f2de81d7a Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 13:08:39 -0500 Subject: [PATCH 07/11] Don't show .safetensors in model names in select --- gradio_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio_app.py b/gradio_app.py index 74a9dd2..b328bb1 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -86,7 +86,7 @@ def list_models(llm: bool = False): for file in files: if file.endswith(".safetensors"): full_path = os.path.join(root, file) - models.append((file, full_path)) + models.append((file.replace(".safetensors", ""), full_path)) else: # Append dirs for model_dir in dirs: From 58810ed0039947acff4c5fc0eabd616da7a67f0a Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 13:29:42 -0500 Subject: [PATCH 08/11] Delete loaded models when loading new models... --- gradio_app.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gradio_app.py b/gradio_app.py index b328bb1..4713630 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -106,6 +106,12 @@ def load_pipeline(model_path): global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline, loaded_pipeline if pipeline is not None and loaded_pipeline == model_path: return + if pipeline: + models_to_delete = [tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline] + for model in models_to_delete: + if model: + del model + torch.cuda.empty_cache() print(f"Loading model from {model_path}") if model_path.endswith('.safetensors'): @@ -149,6 +155,11 @@ def load_llm_model(model_name): global llm_model, llm_tokenizer, llm_model_name if llm_model_name == model_name and llm_model: return + if llm_model: + del llm_model + if llm_tokenizer: + del llm_tokenizer + torch.cuda.empty_cache() print(f"Loading LLM model from {model_name}") llm_model = AutoModelForCausalLM.from_pretrained( From 26282e64e77d90927983e7e117f2e51a0917c147 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 13:39:20 -0500 Subject: [PATCH 09/11] Don't load models to CPU only to load to GPU again right away. 
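The loaders now take a do_unload flag so a caller that is about to push the freshly loaded modules onto the GPU can skip the intermediate unload_all_models() call. A condensed sketch of the pattern on the LLM side (module and variable names follow gradio_app.py; the body is trimmed to the essentials and is not a drop-in replacement):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import lib_omost.memory_management as memory_management

    llm_model = None
    llm_tokenizer = None
    llm_model_name = None

    def load_llm_model(model_name, do_unload: bool = True):
        global llm_model, llm_tokenizer, llm_model_name
        if llm_model_name == model_name and llm_model:
            return  # already loaded, nothing to do
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto")
        llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
        llm_model_name = model_name
        if do_unload:
            # Only park the model off the GPU when it is not needed right away.
            memory_management.unload_all_models(llm_model)

    # chat_fn loads lazily with do_unload=False and then moves the model to the
    # GPU in one step, avoiding the CPU round-trip:
    #     load_llm_model(llm_model_name, False)
    #     memory_management.load_models_to_gpu(llm_model)
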
--- gradio_app.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 4713630..9e390a6 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -102,7 +102,7 @@ def list_models(llm: bool = False): return models, default_model -def load_pipeline(model_path): +def load_pipeline(model_path, do_unload: bool = True): global tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet, pipeline, loaded_pipeline if pipeline is not None and loaded_pipeline == model_path: return @@ -147,11 +147,11 @@ def load_pipeline(model_path): scheduler=None, # We completely give up diffusers sampling system and use A1111's method ) loaded_pipeline = model_path + if do_unload: + memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) - memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) - -def load_llm_model(model_name): +def load_llm_model(model_name, do_unload: bool = True): global llm_model, llm_tokenizer, llm_model_name if llm_model_name == model_name and llm_model: return @@ -173,8 +173,8 @@ def load_llm_model(model_name): token=HF_TOKEN ) llm_model_name = model_name - - memory_management.unload_all_models(llm_model) + if do_unload: + memory_management.unload_all_models(llm_model) @torch.inference_mode() @@ -219,7 +219,7 @@ def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: f conversation.append({"role": "user", "content": message}) # Load the model if it is not loaded if not llm_model: - load_llm_model(llm_model_name) + load_llm_model(llm_model_name, False) memory_management.load_models_to_gpu(llm_model) input_ids = llm_tokenizer.apply_chat_template( @@ -286,7 +286,7 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ use_initial_latent = False eps = 0.05 # Load the model - load_pipeline(model_selection) + load_pipeline(model_selection, False) if not isinstance(pipeline, StableDiffusionXLOmostPipeline): raise ValueError("Pipeline is not StableDiffusionXLOmostPipeline") From d089d29f3e6b1068e75f3796393ea1d7ff0d8662 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Jun 2024 20:01:47 -0500 Subject: [PATCH 10/11] Fix precision on 32-bit OSes --- gradio_app.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 9e390a6..1e61969 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,5 +1,6 @@ import argparse import os +import sys import tempfile import uuid from threading import Thread @@ -205,7 +206,10 @@ def resize_without_crop(image, target_width, target_height): def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: global llm_model, llm_tokenizer, llm_model_name if seed == -1: - seed = np.random.randint(0, 2 ** 32 - 1) + if sys.maxsize > 2 ** 32: + seed = np.random.randint(0, 2 ** 32 - 1) + else: + seed = np.random.randint(0, 2 ** 31 - 1) np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -298,7 +302,10 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64 if seed == -1: - seed = np.random.randint(0, 2 ** 32 - 1) + if sys.maxsize > 2 ** 32: + seed = np.random.randint(0, 2 ** 32 - 1) + else: + seed = np.random.randint(0, 2 ** 31 - 1) rng = torch.Generator(device=memory_management.gpu).manual_seed(seed) memory_management.load_models_to_gpu([text_encoder, text_encoder_2]) From 
506aa63e86cfe6d842ad2d7ac5b4059424060794 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 3 Jun 2024 09:11:26 -0500 Subject: [PATCH 11/11] Random_seed function. --- gradio_app.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 1e61969..d777ee5 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -202,14 +202,21 @@ def resize_without_crop(image, target_width, target_height): return np.array(resized_image) +def random_seed(): + if sys.maxsize > 2 ** 32: + try: + return np.random.randint(0, 2 ** 32 - 1) + except: + return np.random.randint(0, 2 ** 31 - 1) + else: + return np.random.randint(0, 2 ** 31 - 1) + + @torch.inference_mode() def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str: global llm_model, llm_tokenizer, llm_model_name if seed == -1: - if sys.maxsize > 2 ** 32: - seed = np.random.randint(0, 2 ** 32 - 1) - else: - seed = np.random.randint(0, 2 ** 31 - 1) + seed = random_seed() np.random.seed(int(seed)) torch.manual_seed(int(seed)) @@ -302,10 +309,7 @@ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_ image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64 if seed == -1: - if sys.maxsize > 2 ** 32: - seed = np.random.randint(0, 2 ** 32 - 1) - else: - seed = np.random.randint(0, 2 ** 31 - 1) + seed = random_seed() rng = torch.Generator(device=memory_management.gpu).manual_seed(seed) memory_management.load_models_to_gpu([text_encoder, text_encoder_2])
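
The final two patches fold the duplicated 32-bit/64-bit seed checks into a single random_seed() helper. A self-contained sketch of that helper and of how a UI seed of -1 (the default since PATCH 05) is resolved before seeding NumPy and PyTorch; the usage lines are condensed from chat_fn/diffusion_fn and are not the full functions:

    import sys
    import numpy as np
    import torch

    def random_seed():
        # 64-bit interpreters can draw a full 32-bit seed; 32-bit builds of
        # NumPy reject bounds above int32, so fall back to a 31-bit range.
        if sys.maxsize > 2 ** 32:
            try:
                return np.random.randint(0, 2 ** 32 - 1)
            except Exception:
                return np.random.randint(0, 2 ** 31 - 1)
        return np.random.randint(0, 2 ** 31 - 1)

    seed = -1  # the UI default; -1 means "pick a random seed for me"
    if seed == -1:
        seed = random_seed()
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))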