Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TP support for bitsandbytes #798

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions aphrodite/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,6 @@ def verify_with_parallel_config(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")

if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitsAndBytes quantization with TP/PP is not supported yet.")

if self.quantization == "bitsandbytes" and self.enforce_eager is False:
raise ValueError(
"BitsAndBytes with enforce_eager=False is not supported yet.")
Expand Down
23 changes: 18 additions & 5 deletions aphrodite/modeling/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,8 @@ def weight_loader(self,
param, shard_size, shard_offset)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
Expand All @@ -547,8 +549,11 @@ def weight_loader(self,
loaded_weight.shape[output_dim], tp_rank, tp_size)
else:
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
Expand Down Expand Up @@ -894,6 +899,8 @@ def weight_loader(self,
param, shard_size, shard_offset)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
Expand Down Expand Up @@ -934,8 +941,11 @@ def weight_loader(self,
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
Expand Down Expand Up @@ -1044,6 +1054,7 @@ def __init__(self,
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
Expand All @@ -1058,7 +1069,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
if self.quant_config is None:
start_idx = get_current_tp_rank_partition_offset(
Expand Down
38 changes: 37 additions & 1 deletion aphrodite/modeling/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from aphrodite.modeling.model_loader.tensorizer import (
TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
serialize_aphrodite_model, tensorizer_weights_iterator)
Expand Down Expand Up @@ -661,6 +663,8 @@ def save_model(
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""

# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
Expand Down Expand Up @@ -846,13 +850,39 @@ def _parse_quant_state(param_name: str,
yield weight_name, weight_tensor

def generator() -> Generator:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[ \
start_index:end_index, ...]
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.cuda()
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()

with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
Expand All @@ -867,6 +897,12 @@ def generator() -> Generator:

if pre_quant:
return quantized_checkpoint(), quant_state_dict

if pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequanted Bitsandbytes models are not supported with "
"Tensor Parallel. Please try Pipeline Parallel instead.")

return generator(), quant_state_dict

def _load_weights(self, model_config: ModelConfig,
Expand Down
Loading