forked from vectorch-ai/ScaleLLM
Commit
[model] support vision language model llava. (vectorch-ai#178)
1 parent e087247 · commit 437be3f
Showing 32 changed files with 2,933 additions and 3 deletions.
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

import torch
from scalellm import VLM


def test_pixel_value_llava_generate():
    vlm = VLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # one <image> placeholder token per image feature (576 for a 336x336 input)
    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    image = torch.load("images/stop_sign_pixel_values.pt")

    output = vlm.generate(image, prompt)
    print(output.outputs[0].text)


def main():
    test_pixel_value_llava_generate()


if __name__ == "__main__":
    main()
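The precomputed pixel-values tensor above comes from a separate preprocessing step. A minimal sketch of one way to produce it offline, assuming the Hugging Face transformers image processor and a local stop_sign.jpg (neither is part of this commit):

import torch
from PIL import Image
from transformers import AutoProcessor

# Load the processor that matches the llava checkpoint used above.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# stop_sign.jpg is a hypothetical local image file.
image = Image.open("stop_sign.jpg").convert("RGB")
# The CLIP image processor resizes and normalizes to (1, 3, 336, 336),
# matching the image_input_shape passed to VLM above.
pixel_values = processor.image_processor(
    images=image, return_tensors="pt"
)["pixel_values"]
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")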
@@ -0,0 +1,47 @@
from typing import Callable, List, Optional

import torch

from scalellm._C.llm_handler import Future, Priority
from scalellm._C.output import RequestOutput
from scalellm._C.sampling_params import SamplingParams


class VLMHandler:
    class Options:
        def __init__(self) -> None: ...
        def __repr__(self) -> str: ...
        model_path: str
        devices: Optional[str]
        block_size: int
        max_cache_size: int
        max_memory_utilization: float
        enable_prefix_cache: bool
        enable_cuda_graph: bool
        cuda_graph_max_seq_len: int
        cuda_graph_batch_sizes: Optional[List[int]]
        max_tokens_per_batch: int
        max_seqs_per_batch: int
        num_handling_threads: int
        image_input_type: str
        image_token_id: int
        image_input_shape: str
        image_feature_size: int

    def __init__(self, options: Options) -> None: ...
    def __repr__(self) -> str: ...
    def schedule_async(
        self,
        image: torch.Tensor,
        prompt: str,
        sp: SamplingParams,
        priority: Priority,
        stream: bool,
        callback: Callable[[RequestOutput], bool],
    ) -> Future: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...
    def run_until_complete(self) -> None: ...
    def reset(self) -> None: ...
    # helper functions
    def encode(self, text: str) -> List[int]: ...
    def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
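For orientation, a hedged sketch of how this stubbed async API fits together: schedule_async hands back a Future that resolves once the request has been scheduled, the callback receives the final RequestOutput, and run_until_complete drives the engine. The wiring mirrors VLM.generate later in this commit; the checkpoint and tensor paths are hypothetical, and the remaining options are assumed to have usable C++ defaults.

import torch
from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler

options = VLMHandler.Options()
options.model_path = "/path/to/llava-1.5-7b"  # hypothetical local checkpoint
options.image_input_type = "pixel_values"
options.image_token_id = 32000
options.image_input_shape = "1,3,336,336"
options.image_feature_size = 576
handler = VLMHandler(options)

result = None

def on_output(out: RequestOutput) -> bool:
    global result
    result = out
    return True  # matches the VLM wrapper's callback

prompt = "<image>" * 576 + "\nUSER: Describe the image.\nASSISTANT:"
future = handler.schedule_async(
    torch.load("pixel_values.pt"),  # hypothetical precomputed tensor
    prompt,
    SamplingParams(),
    Priority.NORMAL,
    False,  # stream: deliver one final output instead of incremental deltas
    on_output,
)
future.wait()  # returns once the request has been scheduled
handler.run_until_complete()  # blocks until the callback has fired
print(result.outputs[0].text if result else "no output")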
@@ -0,0 +1,115 @@
#include "handlers/vlm_handler.h" | ||
|
||
#include <pybind11/functional.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/stl_bind.h> | ||
|
||
namespace llm::csrc { | ||
namespace py = pybind11; | ||
using namespace pybind11::literals; | ||
|
||
void init_vlm_handler(py::module_& m) { | ||
py::enum_<Priority>(m, "Priority") | ||
.value("DEFAULT", Priority::NORMAL) | ||
.value("LOW", Priority::LOW) | ||
.value("NORMAL", Priority::NORMAL) | ||
.value("HIGH", Priority::HIGH) | ||
.export_values(); | ||
|
||
py::class_<std::future<bool>>(m, "Future") | ||
.def("wait", | ||
&std::future<bool>::wait, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("get", | ||
&std::future<bool>::get, | ||
py::call_guard<py::gil_scoped_release>()); | ||
|
||
auto vlm_handler = | ||
py::class_<VLMHandler>(m, "VLMHandler") | ||
.def(py::init<const VLMHandler::Options&>(), py::arg("options")) | ||
.def("schedule_async", | ||
&VLMHandler::schedule_async, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("start", | ||
&VLMHandler::start, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("stop", | ||
&VLMHandler::stop, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("run_until_complete", | ||
&VLMHandler::run_until_complete, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("encode", | ||
&VLMHandler::encode, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("decode", | ||
&VLMHandler::decode, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("reset", | ||
&VLMHandler::reset, | ||
py::call_guard<py::gil_scoped_release>()) | ||
.def("__repr__", [](const VLMHandler& self) { | ||
return "VLMHandler({})"_s.format(self.options()); | ||
}); | ||
|
||
// VLMHandler::Options | ||
py::class_<VLMHandler::Options>(vlm_handler, "Options") | ||
.def(py::init()) | ||
.def_readwrite("model_path", &VLMHandler::Options::model_path_) | ||
.def_readwrite("devices", &VLMHandler::Options::devices_) | ||
.def_readwrite("block_size", &VLMHandler::Options::block_size_) | ||
.def_readwrite("max_cache_size", &VLMHandler::Options::max_cache_size_) | ||
.def_readwrite("max_memory_utilization", | ||
&VLMHandler::Options::max_memory_utilization_) | ||
.def_readwrite("enable_prefix_cache", | ||
&VLMHandler::Options::enable_prefix_cache_) | ||
.def_readwrite("enable_cuda_graph", | ||
&VLMHandler::Options::enable_cuda_graph_) | ||
.def_readwrite("cuda_graph_max_seq_len", | ||
&VLMHandler::Options::cuda_graph_max_seq_len_) | ||
.def_readwrite("cuda_graph_batch_sizes", | ||
&VLMHandler::Options::cuda_graph_batch_sizes_) | ||
.def_readwrite("max_tokens_per_batch", | ||
&VLMHandler::Options::max_tokens_per_batch_) | ||
.def_readwrite("max_seqs_per_batch", | ||
&VLMHandler::Options::max_seqs_per_batch_) | ||
.def_readwrite("num_handling_threads", | ||
&VLMHandler::Options::num_handling_threads_) | ||
.def_readwrite("image_input_type", | ||
&VLMHandler::Options::image_input_type_) | ||
.def_readwrite("image_token_id", &VLMHandler::Options::image_token_id_) | ||
.def_readwrite("image_input_shape", | ||
&VLMHandler::Options::image_input_shape_) | ||
.def_readwrite("image_feature_size", | ||
&VLMHandler::Options::image_feature_size_) | ||
.def("__repr__", [](const VLMHandler::Options& self) { | ||
return "Options(model_path={}, devices={}, " | ||
"block_size={}, max_cache_size={}, " | ||
"max_memory_utilization={}, enable_prefix_cache={}, " | ||
"enable_cuda_graph={}, cuda_graph_max_seq_len={}, " | ||
"cuda_graph_batch_sizes={}, " | ||
"max_tokens_per_batch={}, max_seqs_per_batch={}, " | ||
"num_handling_threads={}, " | ||
"image_input_type={}, image_token_id={}, | ||
"image_input_shape={}, image_feature_size={})"_s.format( | ||
self.model_path_, | ||
self.devices_, | ||
self.block_size_, | ||
self.max_cache_size_, | ||
self.max_memory_utilization_, | ||
self.enable_prefix_cache_, | ||
self.enable_cuda_graph_, | ||
self.cuda_graph_max_seq_len_, | ||
self.cuda_graph_batch_sizes_, | ||
self.max_tokens_per_batch_, | ||
self.max_seqs_per_batch_, | ||
self.num_handling_threads_, | ||
self.image_input_type_, | ||
self.image_token_id_, | ||
self.image_input_shape_, | ||
self.image_feature_size_); | ||
}); | ||
} | ||
} // namespace llm::csrc |
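Two details worth noting in the bindings above. Every method is wrapped in py::call_guard<py::gil_scoped_release>, so a blocking call like run_until_complete does not hold the GIL while the engine's handling threads call back into Python. And the enum registers "DEFAULT" with Priority::NORMAL rather than a distinct enumerator, so on the Python side the two names alias the same value. A quick sketch of the observable behavior:

from scalellm._C import Priority

# DEFAULT is bound to the same C++ value as NORMAL above, so they compare equal.
assert Priority.DEFAULT == Priority.NORMAL
assert Priority.LOW != Priority.HIGH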
@@ -0,0 +1,127 @@
import os
from typing import List, Optional

import torch

from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler
from scalellm.downloader import download_hf_model
from scalellm.errors import ValidationError


class VLM:
    def __init__(
        self,
        model: str,
        revision: Optional[str] = None,
        allow_patterns: Optional[str] = None,
        cache_dir: Optional[str] = None,
        convert_to_safetensors: bool = False,
        devices: Optional[str] = None,
        block_size: int = 16,
        max_cache_size: int = 20 * 1024 * 1024 * 1024,
        max_memory_utilization: float = 0.9,
        enable_prefix_cache: bool = True,
        enable_cuda_graph: bool = True,
        cuda_graph_max_seq_len: int = 2048,
        cuda_graph_batch_sizes: Optional[List[int]] = None,
        max_tokens_per_batch: int = 409600,  # a big number to disable chunked prefill
        max_seqs_per_batch: int = 2048,  # a big number for better throughput
        num_handling_threads: int = 4,
        # vision encoder configuration
        image_input_type: Optional[str] = None,
        image_token_id: Optional[int] = None,
        image_input_shape: Optional[str] = None,
        image_feature_size: Optional[int] = None,
    ) -> None:
        # download the hf model if it does not exist locally
        self._model = model
        model_path = model
        if not os.path.exists(model_path):
            model_path = download_hf_model(
                repo_id=model_path,
                revision=revision,
                allow_patterns=allow_patterns,
                cache_dir=cache_dir,
                convert_to_safetensors=convert_to_safetensors,
            )

        options = VLMHandler.Options()
        options.model_path = model_path
        options.devices = devices
        options.block_size = block_size
        options.max_cache_size = max_cache_size
        options.max_memory_utilization = max_memory_utilization
        options.enable_prefix_cache = enable_prefix_cache
        options.enable_cuda_graph = enable_cuda_graph
        options.cuda_graph_max_seq_len = cuda_graph_max_seq_len
        options.cuda_graph_batch_sizes = cuda_graph_batch_sizes
        options.max_tokens_per_batch = max_tokens_per_batch
        options.max_seqs_per_batch = max_seqs_per_batch
        options.num_handling_threads = num_handling_threads
        options.image_input_type = image_input_type
        options.image_token_id = image_token_id
        options.image_input_shape = image_input_shape
        options.image_feature_size = image_feature_size
        # create the VLM handler
        self._handler = VLMHandler(options)

    def generate(
        self,
        image: Optional[torch.Tensor] = None,
        prompt: Optional[str] = None,
        sampling_params: Optional[SamplingParams] = None,
        priority: Priority = Priority.NORMAL,
        wait_for_schedule: bool = True,
    ) -> RequestOutput:
        # use default sampling parameters if not provided
        if sampling_params is None:
            sampling_params = SamplingParams()

        output = None

        def callback(async_output: RequestOutput) -> bool:
            # capture the final output in the enclosing scope
            nonlocal output
            output = async_output
            return True

        # schedule the request (stream=False: one final output, no deltas)
        future = self._handler.schedule_async(
            image, prompt, sampling_params, priority, False, callback
        )

        # wait for the request to be scheduled
        if wait_for_schedule:
            future.wait()

        # run until all scheduled requests complete
        self._handler.run_until_complete()

        # throw an exception if there is any error
        if output is None:
            raise RuntimeError("Request failed, no output received")
        if output.status is not None and not output.status.ok:
            raise ValidationError(output.status.code, output.status.message)
        # carry over the prompt to the output
        output.prompt = prompt
        return output

    def encode(self, text: str) -> List[int]:
        return self._handler.encode(text)

    def decode(
        self, tokens: List[int], skip_special_tokens: bool = True
    ) -> Optional[str]:
        return self._handler.decode(tokens, skip_special_tokens)

    def __del__(self):
        self._handler.reset()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.__del__()
        return False

    def __repr__(self) -> str:
        return f"VLM(model={self._model})"