
Commit

Add ISQ topology feature (EricLBuehler#701)
* Add topology for isq

* Support single layer

* Add apis and connect to some public apis

* Use topology in isq quantization

* Works now

* Add demo topology

* Fixes

* Sorting a bit

* Add example

* Some error checking

* Add example and docs, add default

* Typos

* Update deps
EricLBuehler authored Aug 20, 2024
1 parent 3d84a05 commit 754bb6a
Showing 62 changed files with 874 additions and 278 deletions.
260 changes: 159 additions & 101 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions README.md
@@ -79,6 +79,7 @@ Mistral.rs supports several model categories:
- [ISQ](docs/ISQ.md) (In situ quantization): run `.safetensors` models directly from Hugging Face Hub by quantizing them after loading instead of creating a GGUF file.
- This loads the ISQ-able weights on CPU before quantizing with ISQ and then moving to the device to avoid memory spikes.
- Extremely fast due to working in parallel
- Use a [model topology](docs/TOPOLOGY.md) to configure ISQ types *per layer* with a single [YAML file](topologies/isq.yml)
**Easy**:
- Lightweight OpenAI API compatible HTTP server.
12 changes: 8 additions & 4 deletions docs/ISQ.md
@@ -1,20 +1,24 @@
# In situ quantization

In situ quantization works by quantizing non GGUF or GGML models in-place. This allows you to take advantage of flash attention, and reduces memory footprint when running the model. Currently, all layers which would be `Linear` are able to be quantized. An API is exposed on the Python and Rust APIs which provide the ability to dynamically re-ISQ models.
In situ quantization works by quantizing non GGUF or GGML models in-place. This allows you to take advantage of flash attention, and reduces memory footprint when running the model. Currently, all layers which would be `Linear` are able to be quantized.

Possible values for ISQ quantization:
The Python and Rust APIs provide the ability to dynamically re-ISQ models at runtime.

To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md).

## ISQ quantization types
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q8_1
- Q8_1 (*not available on CUDA*)
- Q2K
- Q3K
- Q4K
- Q5K
- Q6K
- Q8K
- Q8K (*not available on CUDA*)
- HQQ4
- HQQ8
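
For a concrete starting point, a minimal launch command selecting one of these types via `--isq` might look like the following; this is a sketch assuming the same CLI shape as the topology examples below, with the features and model as placeholders:

```
cargo run --features ... -- -i --isq Q4K plain -m microsoft/Phi-3-mini-128k-instruct -a phi3
```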

50 changes: 50 additions & 0 deletions docs/TOPOLOGY.md
@@ -0,0 +1,50 @@
# Model topology configuration

To support a per-layer mix of ISQ types, Mistral.rs supports loading a model topology YAML file. This YAML file is formatted as follows:

1) Top-level keys are either:
   - A range of layers (`start-end`) where `start < end`. `start` is inclusive and `end` is exclusive
   - A single layer number
2) The topology for the range or layer:
   - A single key (`isq`) which maps to a single value, which can be any [ISQ type](ISQ.md#isq-quantization-types)

Note that:
- The topology for a range is expanded to fill the range
- If ranges overlap, the range with the higher end layer takes precedence and overwrites the other (see the overlap sketch after the example below)
- Layers not covered by the topology have no topology mapping; they inherit any other ISQ setting (e.g. `--isq`/`in_situ_quant`)
- For layers that are covered, the topology value overrides any other ISQ setting (e.g. `--isq`/`in_situ_quant`)


```yml
0-8:
  isq: Q3K
8-16:
  isq: Q4K
16-24:
  isq: Q6K
# Skip 24-28
28-32:
  isq: Q8_0
```
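
For instance, given the precedence rule above, a hypothetical overlapping topology resolves in favor of the range with the higher end layer:

```yml
# Hypothetical overlap: both ranges cover layers 8-11.
# 8-16 has the higher end layer (16 > 12), so layers 8-11 get Q6K,
# while layers 4-7 keep Q4K.
4-12:
  isq: Q4K
8-16:
  isq: Q6K
```
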
Model topologies may be applied to the following model types:
- `plain`/`Plain`
- `xlora`/`XLora`
- `lora`/`Lora`
- `vision-plain`/`VisionPlain`

## CLI example
```
cargo run --features ... -- -i plain -m microsoft/Phi-3-mini-128k-instruct -a phi3 --topology topologies/isq.yml
```
## HTTP server example
```
cargo run --features ... -- --port 1234 plain -m microsoft/Phi-3-mini-128k-instruct -a phi3 --topology topologies/isq.yml
```
## Rust example
Example [here](../mistralrs/examples/topology/main.rs).
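
For orientation, here is a rough Rust sketch of what this commit wires together, using only names visible in the model_loader.rs diff above; the builder's trailing arguments and the model id are assumptions rather than the shipped example's exact contents:

```rust
use mistralrs::{NormalLoaderBuilder, NormalSpecificConfig, Topology};

fn build_loader() -> anyhow::Result<()> {
    // `Topology::from_option_path` parses and validates the YAML topology
    // file, as in model_loader.rs in this commit.
    let config = NormalSpecificConfig {
        use_flash_attn: false,
        prompt_batchsize: None,
        topology: Topology::from_option_path(Some("topologies/isq.yml".to_string()))?,
    };
    // Trailing arguments (chat template, tokenizer, model id) follow the
    // builder call shown above; the model id here is a placeholder.
    let _loader = NormalLoaderBuilder::new(
        config,
        None, // chat template
        None, // tokenizer.json override
        Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()),
    );
    Ok(())
}
```
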
## Python example
Example [here](../examples/python/topology.py).
25 changes: 25 additions & 0 deletions examples/python/topology.py
@@ -0,0 +1,25 @@
from mistralrs import Runner, Which, ChatCompletionRequest, Architecture

runner = Runner(
    which=Which.Plain(
        model_id="mistralai/Mistral-7B-Instruct-v0.1",
        arch=Architecture.Mistral,
        # Layers not covered by the topology fall back to the global
        # `in_situ_quant` setting below.
        topology="topologies/isq.yml",
    ),
    in_situ_quant="Q4K",
)

res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="mistral",
        messages=[
            {"role": "user", "content": "Tell me a story about the Rust type system."}
        ],
        max_tokens=256,
        presence_penalty=1.0,
        top_p=0.1,
        temperature=0.1,
    )
)
print(res.choices[0].message.content)
print(res.usage)
1 change: 1 addition & 0 deletions mistralrs-core/Cargo.toml
@@ -74,6 +74,7 @@ mistralrs-paged-attn = { version = "0.2.5", path = "../mistralrs-paged-attn", op
mistralrs-quant = { version = "0.2.0", path = "../mistralrs-quant" }
uuid = { version = "1.10.0", features = ["v4"] }
schemars = "0.8.21"
serde_yaml = "0.9.34"

[features]
default = ["plotly"]
2 changes: 2 additions & 0 deletions mistralrs-core/src/lib.rs
@@ -54,6 +54,7 @@ mod scheduler;
mod sequence;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;
@@ -83,6 +84,7 @@ use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{
    CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
10 changes: 9 additions & 1 deletion mistralrs-core/src/model_loader.rs
@@ -6,7 +6,7 @@ use std::{
use crate::{
    get_toml_selected_model_dtype,
    pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
    Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector,
    Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, Topology,
    VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
};

@@ -120,10 +120,12 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
            tokenizer_json,
            arch,
            dtype: _,
            topology,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_batchsize: args.prompt_batchsize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
@@ -138,10 +140,12 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
            tgt_non_granular_index,
            arch,
            dtype: _,
            topology,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_batchsize: args.prompt_batchsize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
@@ -164,10 +168,12 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
            order,
            arch,
            dtype: _,
            topology,
        } => NormalLoaderBuilder::new(
            NormalSpecificConfig {
                use_flash_attn,
                prompt_batchsize: args.prompt_batchsize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
@@ -327,10 +333,12 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
            tokenizer_json,
            arch,
            dtype: _,
            topology,
        } => VisionLoaderBuilder::new(
            VisionSpecificConfig {
                use_flash_attn,
                prompt_batchsize: args.prompt_batchsize,
                topology: Topology::from_option_path(topology)?,
            },
            args.chat_template,
            tokenizer_json,
16 changes: 16 additions & 0 deletions mistralrs-core/src/model_selected.rs
@@ -43,6 +43,10 @@ pub enum ModelSelected {
        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,
    },

    /// Select an X-LoRA architecture
@@ -75,6 +79,10 @@ pub enum ModelSelected {
        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,
    },

    /// Select a LoRA architecture
@@ -102,6 +110,10 @@ pub enum ModelSelected {
        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,
    },

    /// Select a GGUF model.
@@ -292,5 +304,9 @@ pub enum ModelSelected {
        /// Model data type. Defaults to `auto`.
        #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
        dtype: ModelDType,

        /// Path to a topology YAML file.
        #[arg(long)]
        topology: Option<String>,
    },
}
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/gemma.rs
@@ -470,9 +470,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = candle_nn::embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/gemma2.rs
@@ -508,9 +508,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = candle_nn::embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/llama.rs
@@ -450,9 +450,7 @@ impl Llama {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/mistral.rs
@@ -450,9 +450,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = candle_nn::embedding(
            cfg.vocab_size,
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/mixtral.rs
@@ -477,9 +477,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = candle_nn::embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/phi2.rs
@@ -426,9 +426,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/phi3.rs
@@ -418,9 +418,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = candle_nn::embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/qwen2.rs
@@ -410,9 +410,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = candle_nn::embedding(
4 changes: 1 addition & 3 deletions mistralrs-core/src/models/starcoder2.rs
@@ -410,9 +410,7 @@ impl Model {
                quant_cfg.bits
            );
        }
        let mapper = normal_loading_metadata
            .mapper
            .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = candle_nn::embedding(
1 change: 0 additions & 1 deletion mistralrs-core/src/paged_attention/scheduler.rs
@@ -338,7 +338,6 @@ impl PagedAttentionScheduler {

impl Scheduler for PagedAttentionScheduler {
    fn add_seq(&mut self, seq: Sequence) {
        println!("Adding sequence {}", seq.id());
        self.waiting.push_back(Arc::new(Mutex::new(seq)));
    }
    fn schedule(&mut self) -> SchedulerOutput<'_> {