set_adapters performance degrades with the number of inactive adapters #11816

Closed
@hrazjan

Description

Describe the bug

Goal

Build an image-generation service with StableDiffusionXLPipeline that:

  1. Keeps ~50 LoRA adapters resident in GPU VRAM.
  2. For each request:
    • activate ≤ 5 specific LoRAs via pipeline.set_adapters(...)
    • run inference
    • deactivate them (ready for the next request).
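
A minimal sketch of this per-request loop, assuming all LoRAs were loaded once at startup with load_lora_weights (the handle_request wrapper and its names are illustrative, not the actual service code):

def handle_request(pipeline, loras, prompt):
    # Activate only the adapters this request needs
    pipeline.set_adapters(
        adapter_names=[lora.name for lora in loras],
        adapter_weights=[lora.strength for lora in loras],
    )
    images = pipeline(prompt=prompt).images
    # Deactivate the LoRA layers so the next request starts clean
    pipeline.disable_lora()
    return images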

Issue

pipeline.set_adapters() becomes progressively slower as the number of unique LoRAs ever loaded grows,
even though each call still enables only up to five adapters.

LoRAs ever loaded | set_adapters() time (s)
----------------- | -----------------------
3                 | ~0.1031
6                 | ~0.1843
9                 | ~0.2614
12                | ~0.3522
45                | ~1.2470
57                | ~1.5435
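
My unverified reading of the diffusers/PEFT internals is that set_adapters walks every LoRA-bearing module and updates bookkeeping for every adapter that module has ever received, not only the ones being activated. As illustrative pseudocode (not the real library source):

for module in lora_modules_in_unet_and_text_encoders:   # thousands of layers
    for adapter in module.loaded_adapter_names:         # grows with every LoRA ever loaded
        is_active = adapter in requested_adapters
        module.scaling[adapter] = requested_weight(adapter) if is_active else 0.0

An inner loop over all loaded adapters would explain the roughly linear growth in the table above.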

What I’ve tried

  1. Load LoRAs from disk for every request: ~0.8 s per LoRA, too slow.
  2. Keep LoRAs in RAM (SpooledTemporaryFile) and call pipeline.delete_adapters() after each request (sketched below): roughly as slow as (1).
  3. Keep all 50 LoRAs on the GPU and just switch with set_adapters(): fastest so far, but it still shows the O(N)-style growth above.
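
For reference, attempt (2) looked roughly like the sketch below. The lora_paths mapping and helper names are hypothetical, but load_lora_weights does accept an in-memory state dict:

from safetensors.torch import load_file

# Read every LoRA state dict into CPU RAM once at startup (lora_paths is hypothetical)
cached = {name: load_file(path) for name, path in lora_paths.items()}

def run_request(pipeline, loras):
    for lora in loras:
        # Re-insert only the adapters this request needs, straight from RAM
        pipeline.load_lora_weights(cached[lora.name], adapter_name=lora.name)
    pipeline.set_adapters([lora.name for lora in loras], [lora.strength for lora in loras])
    # ... run inference ...
    pipeline.delete_adapters([lora.name for lora in loras])  # drop them afterwards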

Question

Is this increasing latency expected?
Is there a recommended pattern for caching many LoRAs on the GPU and switching between small subsets without paying an O(total LoRAs) cost every time?

Any guidance (or confirmation it’s a current limitation) would be greatly appreciated!
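
(For anyone digging in: the loaded-vs-active distinction is easy to inspect, since both of these methods exist on LoRA-enabled pipelines.)

print(pipeline.get_list_adapters())    # every adapter ever loaded, per component
print(pipeline.get_active_adapters())  # only the adapters enabled by the last set_adapters call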

Reproduction

Code
import os
import time
from typing import List
from pydantic import BaseModel
from diffusers import StableDiffusionXLPipeline, AutoencoderTiny
import torch
from diffusers.utils import logging
logging.disable_progress_bar()
logging.set_verbosity_error() 

pipeline = None

class Lora(BaseModel):
    name: str
    strength: float

def timeit(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        duration = end - start
        print(f"{func.__name__} executed in {duration:.4f} seconds")
        return result
    return wrapper

@timeit
def load_model():
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        vae=AutoencoderTiny.from_pretrained(
            'madebyollin/taesdxl',
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
    ).to("cuda")
    pipeline.set_progress_bar_config(disable=True)

    return pipeline

@timeit
def set_adapters(pipeline, adapter_names, adapter_weights):
    pipeline.set_adapters(
        adapter_names=adapter_names,
        adapter_weights=adapter_weights,
    )

@timeit
def fuse_lora(pipeline):
    pipeline.fuse_lora()

@timeit
def inference(pipeline, req, generator=None):
    return pipeline(
        prompt=req.prompt,
        negative_prompt=req.negative_prompt,
        width=req.width,
        height=req.height,
        num_inference_steps=req.steps,
        guidance_scale=req.guidance_scale,
        generator=generator,
    ).images

def apply_loras(pipeline, loras: list[Lora]) -> None:
    if not loras:
        pipeline.disable_lora()
        return

    pipeline.enable_lora()
    for lora in loras:
        try:
            # Every adapter name loads the same LoRA file; only the name differs.
            pipeline.load_lora_weights(
                "ostris/super-cereal-sdxl-lora",
                weight_name="cereal_box_sdxl_v1.safetensors",
                adapter_name=lora.name,
                token=os.getenv("HUGGINGFACE_HUB_TOKEN", None),
            )
        except ValueError:
            continue  # LoRA already loaded under this name, skip
        except Exception as e:
            print(f"Failed to load LoRA {lora}: {e}")
            continue
    set_adapters(
        pipeline,
        adapter_names=[lora.name for lora in loras],
        adapter_weights=[lora.strength for lora in loras],
    )
    fuse_lora(pipeline)

def generate_images(req, pipeline):
    generator = torch.Generator(device="cuda").manual_seed(42)

    apply_loras(pipeline, req.loras)

    images = inference(
        pipeline,
        req,
        generator=generator,
    )

    pipeline.unfuse_lora()
    
    return images

class GenerationRequest(BaseModel):
    prompt: str
    loras: List[Lora] = []
    negative_prompt: str = ""
    width: int = 512
    height: int = 512
    steps: int = 30
    guidance_scale: float = 7

def test_lora_group(pipeline, lora_group: List[str], group_number: int):
    test_req = GenerationRequest(
        prompt="a simple test image",
        loras=[Lora(name=lora_name, strength=0.8) for lora_name in lora_group],
        width=256,
        height=256,
        steps=10,
    )

    try:
        generate_images(test_req, pipeline)
        return True, lora_group
    except Exception:
        return False, lora_group

def chunk_loras(lora_list: List[str], chunk_size: int = 5) -> List[List[str]]:
    """Split LoRA names into groups of the given size"""
    chunks = []
    for i in range(0, len(lora_list), chunk_size):
        chunks.append(lora_list[i:i + chunk_size])
    return chunks

def test():
    global pipeline
    # Load the pipeline if not already loaded
    if not pipeline:
        print("Loading pipeline for LoRA tests...")
        pipeline = load_model()
    
    # Split 49 example LoRA names into groups of 5
    all_loras = [f"cereal_{i}" for i in range(1, 50)]
    lora_groups = chunk_loras(all_loras, 5)
    
    successful_groups = 0
    failed_loras = []
    
    print(f"Testing {len(all_loras)} LoRAs in {len(lora_groups)} groups of up to 5...")
    
    for i, lora_group in enumerate(lora_groups, 1):
        print(f"Testing group {i}/{len(lora_groups)}: {lora_group}")
        success, group = test_lora_group(pipeline, lora_group, i)
        if success:
            successful_groups += 1
        else:
            failed_loras.extend(group)
    
    if failed_loras:
        print(f"✗ Failed LoRAs: {', '.join(failed_loras)}")

def test_all_loras():
    test()
    # Run a second pass once all LoRAs are already loaded
    test()

if __name__ == "__main__":
    test_all_loras()

Logs

The example above gave the following results on an RTX 5090:

Loading pipeline for LoRA tests...
load_model executed in 3.3787 seconds
Testing 49 LoRAs in 10 groups of up to 5...
Testing group 1/10: ['cereal_1', 'cereal_2', 'cereal_3', 'cereal_4', 'cereal_5']
/aibabe-diffusers/worker-env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
  warnings.warn(
set_adapters executed in 0.1883 seconds
fuse_lora executed in 0.1787 seconds
inference executed in 0.4306 seconds
Testing group 2/10: ['cereal_6', 'cereal_7', 'cereal_8', 'cereal_9', 'cereal_10']
set_adapters executed in 0.6675 seconds
fuse_lora executed in 0.1478 seconds
inference executed in 0.2558 seconds
Testing group 3/10: ['cereal_11', 'cereal_12', 'cereal_13', 'cereal_14', 'cereal_15']
set_adapters executed in 0.6900 seconds
fuse_lora executed in 0.2076 seconds
inference executed in 0.3285 seconds
Testing group 4/10: ['cereal_16', 'cereal_17', 'cereal_18', 'cereal_19', 'cereal_20']
set_adapters executed in 0.7407 seconds
fuse_lora executed in 0.1937 seconds
inference executed in 0.2802 seconds
Testing group 5/10: ['cereal_21', 'cereal_22', 'cereal_23', 'cereal_24', 'cereal_25']
set_adapters executed in 0.9166 seconds
fuse_lora executed in 0.2169 seconds
inference executed in 0.2973 seconds
Testing group 6/10: ['cereal_26', 'cereal_27', 'cereal_28', 'cereal_29', 'cereal_30']
set_adapters executed in 1.7071 seconds
fuse_lora executed in 0.2461 seconds
inference executed in 0.3163 seconds
Testing group 7/10: ['cereal_31', 'cereal_32', 'cereal_33', 'cereal_34', 'cereal_35']
set_adapters executed in 1.8331 seconds
fuse_lora executed in 0.2640 seconds
inference executed in 0.3335 seconds
Testing group 8/10: ['cereal_36', 'cereal_37', 'cereal_38', 'cereal_39', 'cereal_40']
set_adapters executed in 2.0735 seconds
fuse_lora executed in 0.2882 seconds
inference executed in 0.3425 seconds
Testing group 9/10: ['cereal_41', 'cereal_42', 'cereal_43', 'cereal_44', 'cereal_45']
set_adapters executed in 2.3036 seconds
fuse_lora executed in 0.3111 seconds
inference executed in 0.3606 seconds
Testing group 10/10: ['cereal_46', 'cereal_47', 'cereal_48', 'cereal_49']
set_adapters executed in 2.2296 seconds
fuse_lora executed in 0.3228 seconds
inference executed in 0.3789 seconds
Testing 49 LoRAs in 10 groups of up to 5...
Testing group 1/10: ['cereal_1', 'cereal_2', 'cereal_3', 'cereal_4', 'cereal_5']
set_adapters executed in 1.7317 seconds
fuse_lora executed in 0.3396 seconds
inference executed in 0.3807 seconds
Testing group 2/10: ['cereal_6', 'cereal_7', 'cereal_8', 'cereal_9', 'cereal_10']
set_adapters executed in 2.5145 seconds
fuse_lora executed in 0.3403 seconds
inference executed in 0.3808 seconds
Testing group 3/10: ['cereal_11', 'cereal_12', 'cereal_13', 'cereal_14', 'cereal_15']
set_adapters executed in 1.7341 seconds
fuse_lora executed in 0.3392 seconds
inference executed in 0.3820 seconds
Testing group 4/10: ['cereal_16', 'cereal_17', 'cereal_18', 'cereal_19', 'cereal_20']
set_adapters executed in 2.5293 seconds
fuse_lora executed in 0.3392 seconds
inference executed in 0.3761 seconds
Testing group 5/10: ['cereal_21', 'cereal_22', 'cereal_23', 'cereal_24', 'cereal_25']
set_adapters executed in 1.7418 seconds
fuse_lora executed in 0.3438 seconds
inference executed in 0.3790 seconds
Testing group 6/10: ['cereal_26', 'cereal_27', 'cereal_28', 'cereal_29', 'cereal_30']
set_adapters executed in 2.5821 seconds
fuse_lora executed in 0.3522 seconds
inference executed in 0.3876 seconds
Testing group 7/10: ['cereal_31', 'cereal_32', 'cereal_33', 'cereal_34', 'cereal_35']
set_adapters executed in 2.5645 seconds
fuse_lora executed in 0.3373 seconds
inference executed in 0.3804 seconds
Testing group 8/10: ['cereal_36', 'cereal_37', 'cereal_38', 'cereal_39', 'cereal_40']
set_adapters executed in 1.8011 seconds
fuse_lora executed in 0.3375 seconds
inference executed in 0.3843 seconds
Testing group 9/10: ['cereal_41', 'cereal_42', 'cereal_43', 'cereal_44', 'cereal_45']
set_adapters executed in 2.5058 seconds
fuse_lora executed in 0.3371 seconds
inference executed in 0.3810 seconds
Testing group 10/10: ['cereal_46', 'cereal_47', 'cereal_48', 'cereal_49']
set_adapters executed in 1.4389 seconds
fuse_lora executed in 0.3215 seconds
inference executed in 0.3849 seconds

System Info

  • 🤗 Diffusers version: 0.34.0
  • Platform: Linux-6.8.0-58-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: 4.53.0
  • Accelerate version: 1.8.1
  • PEFT version: 0.15.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul
