Description
Describe the bug
Goal
Build an image-generation service with StableDiffusionXLPipeline
that:
- Keeps ~50 LoRA adapters resident in GPU VRAM.
- For each request:
• activate ≤ 5 specific LoRAs via pipeline.set_adapters(...)
• run inference
• deactivate them (ready for the next request).
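The per-request lifecycle above can be sketched as a context manager. This is only an illustration: `activate_loras` is a hypothetical helper name, and it assumes the adapters named in `names` are already loaded into the pipeline. The `enable_lora`, `set_adapters`, and `disable_lora` calls are the same pipeline methods used in the reproduction script.

```python
from contextlib import contextmanager

MAX_ACTIVE_LORAS = 5  # per-request limit stated in the goal


@contextmanager
def activate_loras(pipeline, names, weights):
    """Activate a small set of already-loaded LoRAs for one request.

    Hypothetical helper: `pipeline` is expected to expose enable_lora(),
    set_adapters(), and disable_lora().
    """
    if len(names) > MAX_ACTIVE_LORAS:
        raise ValueError(f"at most {MAX_ACTIVE_LORAS} LoRAs per request")
    if len(names) != len(weights):
        raise ValueError("names and weights must have the same length")

    pipeline.enable_lora()
    pipeline.set_adapters(
        adapter_names=list(names),
        adapter_weights=list(weights),
    )
    try:
        yield pipeline
    finally:
        # Leave the pipeline with no active LoRA, ready for the next request.
        pipeline.disable_lora()
```

A request handler would then wrap inference in `with activate_loras(pipeline, names, weights): ...`, so deactivation also happens when inference raises.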
Issue
pipeline.set_adapters() becomes progressively slower as the number of unique LoRAs ever loaded grows, even though each call still enables at most five adapters.
| # LoRAs ever loaded | set_adapters() time (s) |
|---|---|
| 3 | ~0.1031 |
| 6 | ~0.1843 |
| 9 | ~0.2614 |
| 12 | ~0.3522 |
| 45 | ~1.2470 |
| 57 | ~1.5435 |
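To put a number on the growth, here is a small least-squares fit over the six measurements in the table. It only restates the data above; `least_squares` is an illustrative helper, not part of any library. On these numbers the slope comes out somewhere between 25 and 30 ms of extra set_adapters() time per LoRA ever loaded, which is consistent with roughly linear scaling.

```python
# (LoRAs ever loaded, set_adapters() seconds), copied from the table above.
measurements = [
    (3, 0.1031),
    (6, 0.1843),
    (9, 0.2614),
    (12, 0.3522),
    (45, 1.2470),
    (57, 1.5435),
]


def least_squares(points):
    """Return (slope, intercept) of the ordinary least-squares line."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in points)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


slope, intercept = least_squares(measurements)
print(f"{slope * 1000:.1f} ms per loaded LoRA, {intercept * 1000:.1f} ms fixed cost")
```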
What I’ve tried
1. Load LoRAs from disk for every request: ~0.8 s per LoRA, too slow.
2. Keep LoRAs in RAM (SpooledTemporaryFile) + pipeline.delete_adapters(): roughly as slow as (1).
3. Keep all 50 LoRAs on the GPU and just switch with set_adapters(): fastest so far, but still shows the O(N)-style growth above.
Question
Is this increasing latency expected?
Is there a recommended pattern for caching many LoRAs on the GPU and switching between small subsets without paying an O(total LoRAs) cost every time?
Any guidance (or confirmation it’s a current limitation) would be greatly appreciated!
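One possible middle ground between approaches (2) and (3), offered as an unverified sketch rather than a recommended pattern: cap the number of resident adapters and evict the least recently used ones, so that N stays bounded. `AdapterLRU`, `load_fn`, and `capacity` are hypothetical names, and the sketch assumes the pipeline exposes `delete_adapters()` for removing an adapter by name. Whether this helps in practice depends on how often evicted LoRAs need to be reloaded, given the ~0.8 s per-LoRA load cost reported above.

```python
from collections import OrderedDict


class AdapterLRU:
    """Keep at most `capacity` adapters resident, evicting the least recently used."""

    def __init__(self, pipeline, load_fn, capacity):
        self._pipeline = pipeline
        self._load_fn = load_fn      # hypothetical: load_fn(pipeline, name) loads one LoRA
        self._capacity = capacity
        self._resident = OrderedDict()  # adapter name -> True, oldest first

    def ensure_loaded(self, names):
        """Make sure every adapter in `names` is resident, then evict extras."""
        for name in names:
            if name in self._resident:
                self._resident.move_to_end(name)  # mark as recently used
            else:
                self._load_fn(self._pipeline, name)
                self._resident[name] = True

        protected = set(names)  # never evict adapters needed by this request
        while len(self._resident) > self._capacity:
            victim = next((n for n in self._resident if n not in protected), None)
            if victim is None:
                break  # the request alone exceeds capacity; nothing safe to evict
            self._pipeline.delete_adapters([victim])
            del self._resident[victim]

    def resident(self):
        """Return resident adapter names, least recently used first."""
        return list(self._resident)
```

Each request would call `ensure_loaded(names)` before `set_adapters(...)`. A capacity close to the working set keeps reloads rare while limiting how many adapters set_adapters() has to consider.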
Reproduction
Code
import os
import time
from typing import List
from pydantic import BaseModel
from diffusers import StableDiffusionXLPipeline, AutoencoderTiny
import torch
from diffusers.utils import logging
logging.disable_progress_bar()
logging.set_verbosity_error()
pipeline = None
class Lora(BaseModel):
    name: str
    strength: float

def timeit(func):
    # Print the wall-clock duration of each call to the wrapped function.
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        duration = end - start
        print(f"{func.__name__} executed in {duration:.4f} seconds")
        return result
    return wrapper
@timeit
def load_model():
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        vae=AutoencoderTiny.from_pretrained(
            'madebyollin/taesdxl',
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
    ).to("cuda")
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
@timeit
def set_adapters(pipeline, adapter_names, adapter_weights):
    pipeline.set_adapters(
        adapter_names=adapter_names,
        adapter_weights=adapter_weights,
    )

@timeit
def fuse_lora(pipeline):
    pipeline.fuse_lora()

@timeit
def inference(pipeline, req, generator=None):
    return pipeline(
        prompt=req.prompt,
        negative_prompt=req.negative_prompt,
        width=req.width,
        height=req.height,
        num_inference_steps=req.steps,
        guidance_scale=req.guidance_scale,
        generator=generator,
    ).images
def apply_loras(pipeline, loras: list[Lora]) -> None:
    if not loras:
        pipeline.disable_lora()
        return
    pipeline.enable_lora()
    for lora in loras:
        try:
            # The same weights are loaded under many adapter names to simulate many LoRAs.
            pipeline.load_lora_weights(
                "ostris/super-cereal-sdxl-lora",
                weight_name="cereal_box_sdxl_v1.safetensors",
                adapter_name=lora.name,
                token=os.getenv("HUGGINGFACE_HUB_TOKEN", None),
            )
        except ValueError:
            continue  # LoRA already loaded, skip
        except Exception as e:
            print(f"Failed to load LoRA {lora}: {e}")
            continue
    set_adapters(
        pipeline,
        adapter_names=[lora.name for lora in loras],
        adapter_weights=[lora.strength for lora in loras],
    )
    fuse_lora(pipeline)
def generate_images(req, pipeline):
    generator = torch.Generator(device="cuda").manual_seed(42)
    apply_loras(pipeline, req.loras)
    images = inference(
        pipeline,
        req,
        generator=generator,
    )
    pipeline.unfuse_lora()
    return images

class GenerationRequest(BaseModel):
    prompt: str
    loras: List[Lora] = []
    negative_prompt: str = ""
    width: int = 512
    height: int = 512
    steps: int = 30
    guidance_scale: float = 7
def test_lora_group(pipeline, lora_group: List[str], group_number: int):
    test_req = GenerationRequest(
        prompt="a simple test image",
        loras=[Lora(name=lora_name, strength=0.8) for lora_name in lora_group],
        width=256,
        height=256,
        steps=10,
    )
    try:
        generate_images(test_req, pipeline)
        return True, lora_group
    except Exception as e:
        print(f"Group {group_number} failed: {e}")
        return False, lora_group

def chunk_loras(lora_list: List[str], chunk_size: int = 5) -> List[List[str]]:
    """Split LoRA names into groups of the specified size."""
    chunks = []
    for i in range(0, len(lora_list), chunk_size):
        chunks.append(lora_list[i:i + chunk_size])
    return chunks
def test():
    global pipeline
    # Load the pipeline if not already loaded
    if pipeline is None:
        print("Loading pipeline for LoRA tests...")
        pipeline = load_model()
    # Split LoRAs into groups of up to 5
    all_loras = [f"cereal_{i}" for i in range(1, 50)]  # 49 example LoRA names
    lora_groups = chunk_loras(all_loras, 5)
    successful_groups = 0
    failed_loras = []
    print(f"Testing {len(all_loras)} LoRAs in {len(lora_groups)} groups of up to 5...")
    for i, lora_group in enumerate(lora_groups, 1):
        print(f"Testing group {i}/{len(lora_groups)}: {lora_group}")
        success, group = test_lora_group(pipeline, lora_group, i)
        if success:
            successful_groups += 1
        else:
            failed_loras.extend(group)
    if failed_loras:
        # failed_loras holds plain adapter names (str), so join them directly
        print(f"✗ Failed LoRAs: {', '.join(failed_loras)}")

def test_all_loras():
    test()
    # Run a second pass once all LoRAs are loaded
    test()

if __name__ == "__main__":
    test_all_loras()
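One caveat about the timings: CUDA work is launched asynchronously, so a wall-clock timer around a GPU call can end up charging time to whichever later call happens to wait for the device. A variant of the timeit decorator that synchronizes before and after the call may help rule this out. This is a sketch under that assumption; `timeit_synced` and its `sync` parameter are hypothetical names, and the measurements in this report were not taken with it.

```python
import time
from functools import wraps


def timeit_synced(sync=None):
    """Like timeit, but calls `sync()` before starting and before stopping the clock.

    `sync` is any zero-argument callable that blocks until pending device
    work has finished; pass None to time without synchronization.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if sync is not None:
                sync()  # drain work queued by earlier calls
            start = time.perf_counter()
            result = func(*args, **kwargs)
            if sync is not None:
                sync()  # wait for work queued by this call
            duration = time.perf_counter() - start
            print(f"{func.__name__} executed in {duration:.4f} seconds")
            return result
        return wrapper
    return decorator
```

In the script above this would be applied as `@timeit_synced(sync=torch.cuda.synchronize)` in place of `@timeit`.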
Logs
The example above produced the following results on an RTX 5090:
Loading pipeline for LoRA tests...
load_model executed in 3.3787 seconds
Testing 49 LoRAs in 10 groups of up to 5...
Testing group 1/10: ['cereal_1', 'cereal_2', 'cereal_3', 'cereal_4', 'cereal_5']
/aibabe-diffusers/worker-env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
warnings.warn(
set_adapters executed in 0.1883 seconds
fuse_lora executed in 0.1787 seconds
inference executed in 0.4306 seconds
Testing group 2/10: ['cereal_6', 'cereal_7', 'cereal_8', 'cereal_9', 'cereal_10']
set_adapters executed in 0.6675 seconds
fuse_lora executed in 0.1478 seconds
inference executed in 0.2558 seconds
Testing group 3/10: ['cereal_11', 'cereal_12', 'cereal_13', 'cereal_14', 'cereal_15']
set_adapters executed in 0.6900 seconds
fuse_lora executed in 0.2076 seconds
inference executed in 0.3285 seconds
Testing group 4/10: ['cereal_16', 'cereal_17', 'cereal_18', 'cereal_19', 'cereal_20']
set_adapters executed in 0.7407 seconds
fuse_lora executed in 0.1937 seconds
inference executed in 0.2802 seconds
Testing group 5/10: ['cereal_21', 'cereal_22', 'cereal_23', 'cereal_24', 'cereal_25']
set_adapters executed in 0.9166 seconds
fuse_lora executed in 0.2169 seconds
inference executed in 0.2973 seconds
Testing group 6/10: ['cereal_26', 'cereal_27', 'cereal_28', 'cereal_29', 'cereal_30']
set_adapters executed in 1.7071 seconds
fuse_lora executed in 0.2461 seconds
inference executed in 0.3163 seconds
Testing group 7/10: ['cereal_31', 'cereal_32', 'cereal_33', 'cereal_34', 'cereal_35']
set_adapters executed in 1.8331 seconds
fuse_lora executed in 0.2640 seconds
inference executed in 0.3335 seconds
Testing group 8/10: ['cereal_36', 'cereal_37', 'cereal_38', 'cereal_39', 'cereal_40']
set_adapters executed in 2.0735 seconds
fuse_lora executed in 0.2882 seconds
inference executed in 0.3425 seconds
Testing group 9/10: ['cereal_41', 'cereal_42', 'cereal_43', 'cereal_44', 'cereal_45']
set_adapters executed in 2.3036 seconds
fuse_lora executed in 0.3111 seconds
inference executed in 0.3606 seconds
Testing group 10/10: ['cereal_46', 'cereal_47', 'cereal_48', 'cereal_49']
set_adapters executed in 2.2296 seconds
fuse_lora executed in 0.3228 seconds
inference executed in 0.3789 seconds
Testing 49 LoRAs in 10 groups of up to 5...
Testing group 1/10: ['cereal_1', 'cereal_2', 'cereal_3', 'cereal_4', 'cereal_5']
set_adapters executed in 1.7317 seconds
fuse_lora executed in 0.3396 seconds
inference executed in 0.3807 seconds
Testing group 2/10: ['cereal_6', 'cereal_7', 'cereal_8', 'cereal_9', 'cereal_10']
set_adapters executed in 2.5145 seconds
fuse_lora executed in 0.3403 seconds
inference executed in 0.3808 seconds
Testing group 3/10: ['cereal_11', 'cereal_12', 'cereal_13', 'cereal_14', 'cereal_15']
set_adapters executed in 1.7341 seconds
fuse_lora executed in 0.3392 seconds
inference executed in 0.3820 seconds
Testing group 4/10: ['cereal_16', 'cereal_17', 'cereal_18', 'cereal_19', 'cereal_20']
set_adapters executed in 2.5293 seconds
fuse_lora executed in 0.3392 seconds
inference executed in 0.3761 seconds
Testing group 5/10: ['cereal_21', 'cereal_22', 'cereal_23', 'cereal_24', 'cereal_25']
set_adapters executed in 1.7418 seconds
fuse_lora executed in 0.3438 seconds
inference executed in 0.3790 seconds
Testing group 6/10: ['cereal_26', 'cereal_27', 'cereal_28', 'cereal_29', 'cereal_30']
set_adapters executed in 2.5821 seconds
fuse_lora executed in 0.3522 seconds
inference executed in 0.3876 seconds
Testing group 7/10: ['cereal_31', 'cereal_32', 'cereal_33', 'cereal_34', 'cereal_35']
set_adapters executed in 2.5645 seconds
fuse_lora executed in 0.3373 seconds
inference executed in 0.3804 seconds
Testing group 8/10: ['cereal_36', 'cereal_37', 'cereal_38', 'cereal_39', 'cereal_40']
set_adapters executed in 1.8011 seconds
fuse_lora executed in 0.3375 seconds
inference executed in 0.3843 seconds
Testing group 9/10: ['cereal_41', 'cereal_42', 'cereal_43', 'cereal_44', 'cereal_45']
set_adapters executed in 2.5058 seconds
fuse_lora executed in 0.3371 seconds
inference executed in 0.3810 seconds
Testing group 10/10: ['cereal_46', 'cereal_47', 'cereal_48', 'cereal_49']
set_adapters executed in 1.4389 seconds
fuse_lora executed in 0.3215 seconds
inference executed in 0.3849 seconds
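To compare the three timed stages, the log can be summarized with a short parser. The sketch below averages each stage over an excerpt (the first three groups of the second pass, copied from the log above); `summarize` is an illustrative helper name, and the full log can be passed in the same way.

```python
import re
from collections import defaultdict
from statistics import mean

# Matches lines such as "set_adapters executed in 1.7317 seconds".
TIMING_LINE = re.compile(r"^(\w+) executed in ([0-9.]+) seconds$")


def summarize(log_text):
    """Return {stage name: mean duration in seconds} for all timing lines."""
    samples = defaultdict(list)
    for line in log_text.splitlines():
        match = TIMING_LINE.match(line.strip())
        if match:
            samples[match.group(1)].append(float(match.group(2)))
    return {name: mean(values) for name, values in samples.items()}


# Excerpt: second pass, groups 1 to 3, with all 49 LoRAs already loaded.
excerpt = """\
Testing group 1/10: ['cereal_1', 'cereal_2', 'cereal_3', 'cereal_4', 'cereal_5']
set_adapters executed in 1.7317 seconds
fuse_lora executed in 0.3396 seconds
inference executed in 0.3807 seconds
Testing group 2/10: ['cereal_6', 'cereal_7', 'cereal_8', 'cereal_9', 'cereal_10']
set_adapters executed in 2.5145 seconds
fuse_lora executed in 0.3403 seconds
inference executed in 0.3808 seconds
Testing group 3/10: ['cereal_11', 'cereal_12', 'cereal_13', 'cereal_14', 'cereal_15']
set_adapters executed in 1.7341 seconds
fuse_lora executed in 0.3392 seconds
inference executed in 0.3820 seconds
"""

for stage, average in summarize(excerpt).items():
    print(f"{stage}: {average:.4f} s on average")
```

On this excerpt, set_adapters takes several times longer than inference itself, which is why the switching cost dominates request latency once all LoRAs are resident.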
System Info
- 🤗 Diffusers version: 0.34.0
- Platform: Linux-6.8.0-58-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.7.1+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.33.1
- Transformers version: 4.53.0
- Accelerate version: 1.8.1
- PEFT version: 0.15.2
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no