Skip to content

Commit

Permalink
FEAT: support CogVideoX-5b (#2197)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Aug 30, 2024
1 parent 2f8c3ce commit f3d510e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 41 deletions.
4 changes: 3 additions & 1 deletion xinference/model/video/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import os
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

from ...constants import XINFERENCE_CACHE_DIR
from ..core import CacheableModelSpec, ModelDescription
Expand Down Expand Up @@ -44,6 +44,8 @@ class VideoModelFamilyV1(CacheableModelSpec):
model_revision: str
model_hub: str = "huggingface"
model_ability: Optional[List[str]]
default_model_config: Optional[Dict[str, Any]]
default_generate_config: Optional[Dict[str, Any]]


class VideoModelDescription(ModelDescription):
Expand Down
79 changes: 41 additions & 38 deletions xinference/model/video/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import base64
import logging
import os
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -24,10 +23,9 @@

import numpy as np
import PIL.Image
import torch

from ...constants import XINFERENCE_VIDEO_DIR
from ...device_utils import move_model_to_available_device
from ...device_utils import gpu_count, move_model_to_available_device
from ...types import Video, VideoList

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,41 +74,58 @@ def model_spec(self):
def load(self):
import torch

torch_dtype = self._kwargs.get("torch_dtype")
if sys.platform != "darwin" and torch_dtype is None:
# The following params crashes on Mac M2
self._kwargs["torch_dtype"] = torch.float16
self._kwargs["variant"] = "fp16"
self._kwargs["use_safetensors"] = True
kwargs = self._model_spec.default_model_config.copy()
kwargs.update(self._kwargs)

scheduler_cls_name = kwargs.pop("scheduler", None)

torch_dtype = kwargs.get("torch_dtype")
if isinstance(torch_dtype, str):
self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
kwargs["torch_dtype"] = getattr(torch, torch_dtype)
logger.debug("Loading video model with kwargs: %s", kwargs)

if self._model_spec.model_family == "CogVideoX":
import diffusers
from diffusers import CogVideoXPipeline

self._model = CogVideoXPipeline.from_pretrained(
self._model_path, **self._kwargs
pipeline = self._model = CogVideoXPipeline.from_pretrained(
self._model_path, **kwargs
)
else:
raise Exception(
f"Unsupported model family: {self._model_spec.model_family}"
)

if self._kwargs.get("cpu_offload", False):
if scheduler_cls_name:
logger.debug("Using scheduler: %s", scheduler_cls_name)
pipeline.scheduler = getattr(diffusers, scheduler_cls_name).from_config(
pipeline.scheduler.config, timestep_spacing="trailing"
)
if kwargs.get("compile_graph", False):
pipeline.transformer = torch.compile(
pipeline.transformer, mode="max-autotune", fullgraph=True
)
if kwargs.get("cpu_offload", False):
logger.debug("CPU offloading model")
self._model.enable_model_cpu_offload()
elif not self._kwargs.get("device_map"):
pipeline.enable_model_cpu_offload()
if kwargs.get("sequential_cpu_offload", True):
pipeline.enable_sequential_cpu_offload()
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
elif not kwargs.get("device_map"):
logger.debug("Loading model to available device")
self._model = move_model_to_available_device(self._model)
if gpu_count() > 1:
kwargs["device_map"] = "balanced"
else:
pipeline = move_model_to_available_device(self._model)
# Recommended if your computer has < 64 GB of RAM
self._model.enable_attention_slicing()
pipeline.enable_attention_slicing()

def text_to_video(
self,
prompt: str,
n: int = 1,
num_inference_steps: int = 50,
guidance_scale: int = 6,
response_format: str = "b64_json",
**kwargs,
) -> VideoList:
Expand All @@ -121,31 +136,19 @@ def text_to_video(
# from diffusers.utils import export_to_video
from ...device_utils import empty_cache

assert self._model is not None
assert callable(self._model)
generate_kwargs = self._model_spec.default_generate_config.copy()
generate_kwargs.update(kwargs)
generate_kwargs["num_videos_per_prompt"] = n
logger.debug(
"diffusers text_to_video args: %s",
kwargs,
generate_kwargs,
)
assert self._model is not None
if self._kwargs.get("cpu_offload"):
# if enabled cpu offload,
# the model.device would be CPU
device = "cuda"
else:
device = self._model.device
prompt_embeds, _ = self._model.encode_prompt(
prompt=prompt,
do_classifier_free_guidance=True,
num_videos_per_prompt=n,
max_sequence_length=226,
device=device,
dtype=torch.float16,
)
assert callable(self._model)
output = self._model(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
**kwargs,
**generate_kwargs,
)

# clean cache
Expand Down
25 changes: 24 additions & 1 deletion xinference/model/video/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@
"model_revision": "4bbfb1de622b80bc1b77b6e9aced75f816be0e38",
"model_ability": [
"text2video"
]
],
"default_model_config": {
"scheduler": "CogVideoXDDIMScheduler",
"torch_dtype": "float16"
},
"default_generate_config": {
"guidance_scale": 6
}
},
{
"model_name": "CogVideoX-5b",
"model_family": "CogVideoX",
"model_id": "THUDM/CogVideoX-5b",
"model_revision": "8d6ea3f817438460b25595a120f109b88d5fdfad",
"model_ability": [
"text2video"
],
"default_model_config": {
"scheduler": "CogVideoXDPMScheduler",
"torch_dtype": "bfloat16"
},
"default_generate_config": {
"guidance_scale": 7
}
}
]
26 changes: 25 additions & 1 deletion xinference/model/video/model_spec_modelscope.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@
"model_revision": "master",
"model_ability": [
"text2video"
]
],
"default_model_config": {
"scheduler": "CogVideoXDDIMScheduler",
"torch_dtype": "float16"
},
"default_generate_config": {
"guidance_scale": 6
}
},
{
"model_name": "CogVideoX-5b",
"model_family": "CogVideoX",
"model_hub": "modelscope",
"model_id": "ZhipuAI/CogVideoX-5b",
"model_revision": "master",
"model_ability": [
"text2video"
],
"default_model_config": {
"scheduler": "CogVideoXDPMScheduler",
"torch_dtype": "bfloat16"
},
"default_generate_config": {
"guidance_scale": 7
}
}
]

0 comments on commit f3d510e

Please sign in to comment.