[model] support vision language model llava. (vectorch-ai#178)
liutongxuan authored Jun 28, 2024
1 parent e087247 commit 437be3f
Showing 32 changed files with 2,933 additions and 3 deletions.
28 changes: 28 additions & 0 deletions python/tests/llava_test.py
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

import torch
from scalellm import VLM

def test_pixel_value_llava_generate():
vlm = VLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
image = torch.load("images/stop_sign_pixel_values.pt")

    output = vlm.generate(image, prompt)
    print(output.outputs[0].text)

def main():
test_pixel_value_llava_generate()

if __name__ == "__main__":
main()
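The test above loads a pre-computed pixel-values tensor from disk ("provided by another online or offline component"). A minimal sketch of how such a tensor could be produced offline, assuming Hugging Face transformers and the CLIP ViT-L/14-336 image processor that LLaVA-1.5 uses; the input image file name is illustrative:

# Hypothetical offline preprocessing: build the 1x3x336x336 pixel-values tensor
# that llava_test.py loads from images/stop_sign_pixel_values.pt.
import torch
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image = Image.open("stop_sign.jpg")  # illustrative input image
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
assert tuple(pixel_values.shape) == (1, 3, 336, 336)  # matches image_input_shape
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")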
2 changes: 2 additions & 0 deletions scalellm/_C/__init__.pyi
@@ -2,6 +2,7 @@ from scalellm._C.llm_handler import LLMHandler, Message, Priority
from scalellm._C.output import (LogProb, LogProbData, RequestOutput,
SequenceOutput, Status, StatusCode, Usage)
from scalellm._C.sampling_params import SamplingParams
from scalellm._C.vlm_handler import VLMHandler

# Defined in scalellm/csrc/module.cpp
def get_metrics() -> str: ...
@@ -18,5 +19,6 @@ __all__ = [
"StatusCode",
"Usage",
"LLMHandler",
"VLMHandler",
"get_metrics",
]
47 changes: 47 additions & 0 deletions scalellm/_C/vlm_handler.pyi
@@ -0,0 +1,47 @@
from typing import Callable, List, Optional

import torch

from scalellm._C.llm_handler import Future, Priority
from scalellm._C.output import RequestOutput
from scalellm._C.sampling_params import SamplingParams

class VLMHandler:
class Options:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
model_path: str
devices: Optional[str]
block_size: int
max_cache_size: int
max_memory_utilization: float
enable_prefix_cache: bool
enable_cuda_graph: bool
cuda_graph_max_seq_len: int
cuda_graph_batch_sizes: Optional[List[int]]
max_tokens_per_batch: int
max_seqs_per_batch: int
num_handling_threads: int
image_input_type: str
image_token_id: int
image_input_shape: str
image_feature_size: int

def __init__(self, options: Options) -> None: ...
def __repr__(self) -> str: ...
def schedule_async(
self,
image: torch.Tensor,
prompt: str,
sp: SamplingParams,
priority: Priority,
stream: bool,
callback: Callable[[RequestOutput], bool],
) -> Future: ...
def start(self) -> None: ...
def stop(self) -> None: ...
def run_until_complete(self) -> None: ...
def reset(self) -> None: ...
# helper functions
def encode(self, text: str) -> List[int]: ...
def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
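For reference, a minimal sketch of driving VLMHandler through the stubbed API above, mirroring the pattern scalellm/vlm.py (below) uses; the local checkpoint path is a placeholder, and leaving the remaining Options fields at their defaults is an assumption:

# Hypothetical direct use of VLMHandler; the VLM class below wraps this pattern.
import torch
from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler

options = VLMHandler.Options()
options.model_path = "/path/to/llava-1.5-7b-hf"  # assumed local checkpoint path
options.image_input_type = "pixel_values"
options.image_token_id = 32000
options.image_input_shape = "1,3,336,336"
options.image_feature_size = 576

handler = VLMHandler(options)

outputs = []
def on_output(async_output: RequestOutput) -> bool:
    # collect the request output; returning True keeps the request active
    outputs.append(async_output)
    return True

image = torch.load("images/stop_sign_pixel_values.pt")
prompt = "<image>" * 576 + "\nUSER: What is the content of this image?\nASSISTANT:"

future = handler.schedule_async(
    image, prompt, SamplingParams(), Priority.NORMAL, False, on_output
)
future.wait()                 # block until the request has been scheduled
handler.run_until_complete()  # process all scheduled requests
handler.reset()

if outputs:
    print(outputs[-1].outputs[0].text)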
3 changes: 2 additions & 1 deletion scalellm/__init__.py
@@ -12,7 +12,7 @@

from scalellm._C import (LLMHandler, LogProb, LogProbData, Message, Priority,
RequestOutput, SamplingParams, SequenceOutput, Status,
StatusCode, Usage, get_metrics)
StatusCode, Usage, VLMHandler, get_metrics)
from scalellm.errors import ValidationError
from scalellm.llm import LLM
from scalellm.llm_engine import AsyncLLMEngine, OutputAsyncStream, OutputStream
@@ -34,5 +34,6 @@
"StatusCode",
"Usage",
"LLMHandler",
"VLMHandler",
"get_metrics",
]
4 changes: 3 additions & 1 deletion scalellm/csrc/module.cpp
@@ -11,6 +11,7 @@ namespace py = pybind11;
extern void init_sampling_params(py::module_& m);
extern void init_output(py::module_& m);
extern void init_llm_handler(py::module_& m);
extern void init_vlm_handler(py::module_& m);

// NOLINTNEXTLINE
static std::string get_metrics() { return Metrics::Instance().GetString(); }
@@ -26,6 +27,7 @@ PYBIND11_MODULE(PY_MODULE_NAME, m) {
init_sampling_params(m);
init_output(m);
init_llm_handler(m);
init_vlm_handler(m);
}

} // namespace llm::csrc
} // namespace llm::csrc
115 changes: 115 additions & 0 deletions scalellm/csrc/vlm_handler.cpp
@@ -0,0 +1,115 @@
#include "handlers/vlm_handler.h"

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_vlm_handler(py::module_& m) {
py::enum_<Priority>(m, "Priority")
.value("DEFAULT", Priority::NORMAL)
.value("LOW", Priority::LOW)
.value("NORMAL", Priority::NORMAL)
.value("HIGH", Priority::HIGH)
.export_values();

py::class_<std::future<bool>>(m, "Future")
.def("wait",
&std::future<bool>::wait,
py::call_guard<py::gil_scoped_release>())
.def("get",
&std::future<bool>::get,
py::call_guard<py::gil_scoped_release>());

auto vlm_handler =
py::class_<VLMHandler>(m, "VLMHandler")
.def(py::init<const VLMHandler::Options&>(), py::arg("options"))
.def("schedule_async",
&VLMHandler::schedule_async,
py::call_guard<py::gil_scoped_release>())
.def("start",
&VLMHandler::start,
py::call_guard<py::gil_scoped_release>())
.def("stop",
&VLMHandler::stop,
py::call_guard<py::gil_scoped_release>())
.def("run_until_complete",
&VLMHandler::run_until_complete,
py::call_guard<py::gil_scoped_release>())
.def("encode",
&VLMHandler::encode,
py::call_guard<py::gil_scoped_release>())
.def("decode",
&VLMHandler::decode,
py::call_guard<py::gil_scoped_release>())
.def("reset",
&VLMHandler::reset,
py::call_guard<py::gil_scoped_release>())
.def("__repr__", [](const VLMHandler& self) {
return "VLMHandler({})"_s.format(self.options());
});

// VLMHandler::Options
py::class_<VLMHandler::Options>(vlm_handler, "Options")
.def(py::init())
.def_readwrite("model_path", &VLMHandler::Options::model_path_)
.def_readwrite("devices", &VLMHandler::Options::devices_)
.def_readwrite("block_size", &VLMHandler::Options::block_size_)
.def_readwrite("max_cache_size", &VLMHandler::Options::max_cache_size_)
.def_readwrite("max_memory_utilization",
&VLMHandler::Options::max_memory_utilization_)
.def_readwrite("enable_prefix_cache",
&VLMHandler::Options::enable_prefix_cache_)
.def_readwrite("enable_cuda_graph",
&VLMHandler::Options::enable_cuda_graph_)
.def_readwrite("cuda_graph_max_seq_len",
&VLMHandler::Options::cuda_graph_max_seq_len_)
.def_readwrite("cuda_graph_batch_sizes",
&VLMHandler::Options::cuda_graph_batch_sizes_)
.def_readwrite("max_tokens_per_batch",
&VLMHandler::Options::max_tokens_per_batch_)
.def_readwrite("max_seqs_per_batch",
&VLMHandler::Options::max_seqs_per_batch_)
.def_readwrite("num_handling_threads",
&VLMHandler::Options::num_handling_threads_)
.def_readwrite("image_input_type",
&VLMHandler::Options::image_input_type_)
.def_readwrite("image_token_id", &VLMHandler::Options::image_token_id_)
.def_readwrite("image_input_shape",
&VLMHandler::Options::image_input_shape_)
.def_readwrite("image_feature_size",
&VLMHandler::Options::image_feature_size_)
.def("__repr__", [](const VLMHandler::Options& self) {
return "Options(model_path={}, devices={}, "
"block_size={}, max_cache_size={}, "
"max_memory_utilization={}, enable_prefix_cache={}, "
"enable_cuda_graph={}, cuda_graph_max_seq_len={}, "
"cuda_graph_batch_sizes={}, "
"max_tokens_per_batch={}, max_seqs_per_batch={}, "
"num_handling_threads={}, "
"image_input_type={}, image_token_id={},
"image_input_shape={}, image_feature_size={})"_s.format(
self.model_path_,
self.devices_,
self.block_size_,
self.max_cache_size_,
self.max_memory_utilization_,
self.enable_prefix_cache_,
self.enable_cuda_graph_,
self.cuda_graph_max_seq_len_,
self.cuda_graph_batch_sizes_,
self.max_tokens_per_batch_,
self.max_seqs_per_batch_,
self.num_handling_threads_,
self.image_input_type_,
self.image_token_id_,
self.image_input_shape_,
self.image_feature_size_);
});
}
} // namespace llm::csrc
127 changes: 127 additions & 0 deletions scalellm/vlm.py
@@ -0,0 +1,127 @@
import os
from typing import List, Optional

import torch

from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler
from scalellm.downloader import download_hf_model
from scalellm.errors import ValidationError


class VLM:
def __init__(
self,
model: str,
revision: Optional[str] = None,
allow_patterns: Optional[str] = None,
cache_dir: Optional[str] = None,
convert_to_safetensors: bool = False,
devices: Optional[str] = None,
block_size: int = 16,
max_cache_size: int = 20 * 1024 * 1024 * 1024,
max_memory_utilization: float = 0.9,
enable_prefix_cache: bool = True,
enable_cuda_graph: bool = True,
cuda_graph_max_seq_len: int = 2048,
cuda_graph_batch_sizes: Optional[List[int]] = None,
max_tokens_per_batch: int = 409600, # a big number to disable chunked prefill
max_seqs_per_batch: int = 2048, # a big number for better throughput
num_handling_threads: int = 4,
# vision encoder configuration
image_input_type: Optional[str] = None,
image_token_id: Optional[int] = None,
image_input_shape: Optional[str] = None,
image_feature_size: Optional[int] = None,
) -> None:
# download hf model if it does not exist
self._model = model
model_path = model
if not os.path.exists(model_path):
model_path = download_hf_model(
repo_id=model_path,
revision=revision,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
convert_to_safetensors=convert_to_safetensors,
)

options = VLMHandler.Options()
options.model_path = model_path
options.devices = devices
options.block_size = block_size
options.max_cache_size = max_cache_size
options.max_memory_utilization = max_memory_utilization
options.enable_prefix_cache = enable_prefix_cache
options.enable_cuda_graph = enable_cuda_graph
options.cuda_graph_max_seq_len = cuda_graph_max_seq_len
options.cuda_graph_batch_sizes = cuda_graph_batch_sizes
options.max_tokens_per_batch = max_tokens_per_batch
options.max_seqs_per_batch = max_seqs_per_batch
options.num_handling_threads = num_handling_threads
options.image_input_type = image_input_type
options.image_token_id = image_token_id
options.image_input_shape = image_input_shape
options.image_feature_size = image_feature_size
# create the LLM handler
self._handler = VLMHandler(options)

def generate(
self,
        image: Optional[torch.Tensor] = None,
        prompt: Optional[str] = None,
sampling_params: Optional[SamplingParams] = None,
priority: Priority = Priority.NORMAL,
wait_for_schedule: bool = True,
) -> RequestOutput:
# use default sampling parameters if not provided
if sampling_params is None:
sampling_params = SamplingParams()

        output = None

        def callback(async_output: RequestOutput) -> bool:
            # capture the completed request output for the caller
            nonlocal output
            output = async_output
            return True

# schedule the batch requests
future = self._handler.schedule_async(
image, prompt, sampling_params, priority, False, callback
)

# wait for batch request to be scheduled
if wait_for_schedule:
future.wait()

        # run until all scheduled requests complete
self._handler.run_until_complete()

# throw an exception if there is any error
if output is None:
raise RuntimeError("Request failed, no output received")
if output.status is not None and not output.status.ok:
raise ValidationError(output.status.code, output.status.message)
# carry over the prompt to the output
output.prompt = prompt
return output

def encode(self, text: str) -> List[int]:
return self._handler.encode(text)

def decode(
self, tokens: List[int], skip_special_tokens: bool = True
) -> Optional[str]:
return self._handler.decode(tokens, skip_special_tokens)

def __del__(self):
self._handler.reset()

def __enter__(self):
return self

def __exit__(self, *args):
self.__del__()
return False

    def __repr__(self) -> str:
        return f"VLM(model={self._model})"
4 changes: 4 additions & 0 deletions src/engine/CMakeLists.txt
@@ -10,14 +10,18 @@ cc_library(
batch.h
model_runner.h
worker.h
vlm_worker.h
engine.h
llm_engine.h
vlm_engine.h
SRCS
utils.cpp
batch.cpp
model_runner.cpp
worker.cpp
vlm_worker.cpp
llm_engine.cpp
vlm_engine.cpp
DEPS
torch
:common
