[Quantization] Add compressed-tensors NVFP4 MoE Support #19990

Draft · wants to merge 4 commits into main
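
This PR adds a CompressedTensorsW4A4MoeMethod so compressed-tensors checkpoints quantized to NVFP4 (FP4 weights and activations, fp8 block scales over 16-element groups) can run through the fused MoE layer, via either the Marlin kernels or the CUTLASS FP4 MoE path. As a rough usage sketch (the model id below is a placeholder, not a checkpoint referenced by this PR; the quantization method is normally auto-detected from the checkpoint config):

from vllm import LLM, SamplingParams

# Placeholder id: any MoE model quantized to NVFP4 with compressed-tensors.
llm = LLM(model="org/nvfp4-moe-model", quantization="compressed-tensors")
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)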
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -1178,6 +1178,7 @@ def weight_loader(self, param: torch.nn.Parameter,
param.materialize(final_shape, dtype=loaded_weight.dtype)

expert_data = param.data if full_load else param.data[expert_id]

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
@@ -1205,6 +1206,7 @@ def weight_loader(self, param: torch.nn.Parameter,
tp_rank=self.tp_rank)
return

# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
@@ -1221,7 +1223,7 @@ def weight_loader(self, param: torch.nn.Parameter,
tp_rank=self.tp_rank)
return

# Case weight scales, zero_points and offset
# Case weight scales, zero_points and offset, weight/input global scales
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -22,6 +22,8 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -48,12 +50,11 @@ class GPTQMarlinState(Enum):


__all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod"
]


@@ -86,6 +87,8 @@ def get_moe_method(
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -97,6 +100,261 @@
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")


class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

def __init__(self):
self.use_marlin = True
self.group_size = 16
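        # NVFP4 uses one fp8 (e4m3) scale per 16-element group of FP4 values.
        # The Marlin kernel path is hard-coded on here; the CUTLASS path in
        # this class is only taken when use_marlin is False.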

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

layer.num_experts = num_experts
layer.params_dtype = params_dtype

w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# Weight Scales
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
                # one fp8 scale per group of `group_size` elements in the input dim
hidden_size // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)

w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
                # one fp8 scale per group of `group_size` elements in the input dim
intermediate_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

# Input Global Scales
w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_input_scale, extra_weight_attrs)

w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_global_scale", w2_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_input_scale, extra_weight_attrs)

    def swizzle_blockscale(self, scale: torch.Tensor):
assert (scale.dtype == torch.float8_e4m3fn)
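        # Only the CUTLASS path calls this; the Marlin path keeps the plain
        # group scales. The CUTLASS NVFP4 kernels read block scales in an
        # interleaved 128x4 tile layout, produced by the pad/reshape/permute
        # below.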
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# From packed to weight
layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
requires_grad=False)

layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
requires_grad=False)

if not torch.allclose(layer.w13_weight_global_scale[:, 0],
layer.w13_weight_global_scale[:, 1]):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected.")

# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter(
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)

layer.w2_weight_scale_2 = torch.nn.Parameter(
1 / layer.w2_weight_global_scale.data, requires_grad=False)

if not self.use_marlin:
# swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w13_weight_scale),
requires_grad=False)

layer.w2_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w2_weight_scale),
requires_grad=False)

# w13
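            # The alphas fold both global scales into a single factor applied
            # to the FP4 GEMM output:
            #   alpha = 1 / (input_global_scale * weight_global_scale)
            # (w13_weight_scale_2 / w2_weight_scale_2 already hold the
            # reciprocals of the weight global scales, see above).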
layer.g1_alphas = torch.nn.Parameter(
((1 / (layer.w13_input_global_scale.max(dim=1).values.to(
torch.float32))) * layer.w13_weight_scale_2),
requires_grad=False)

# w2
layer.g2_alphas = torch.nn.Parameter(
((1 / layer.w2_input_global_scale) *
layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False)

            # Input global scales used by the kernels to quantize activations to fp4
layer.w13_input_scale_quant = torch.nn.Parameter(
(layer.w13_input_global_scale.max(dim=1).values.to(
torch.float32)),
requires_grad=False)

layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)

if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
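
        # Two execution paths: the fused Marlin MoE kernel (default here),
        # which takes the repacked FP4 weights plus group and global scales,
        # or the CUTLASS FP4 MoE below, which consumes the swizzled block
        # scales and fused alphas prepared in process_weights_after_loading.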

if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)

assert activation == "silu", "Only SiLU activation is supported."
        assert not apply_router_weight_on_input, (
            "Router weight on input is not "
            "supported for CompressedTensorsW4A4MoeMethod.")
        assert expert_map is None, ("Expert Parallelism / expert_map "
                                    "is currently not supported for "
                                    "CompressedTensorsW4A4MoeMethod.")

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)

# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device).to(x.dtype)


class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

def __init__(
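
A note on the NVFP4 parameter shapes in create_weights above, as a quick sanity-check sketch (the concrete sizes are illustrative, not taken from any particular checkpoint): each uint8 packs two FP4 values, so the packed weights have half the input dimension, and each group of 16 input elements shares one fp8 (e4m3) scale, so the block-scale tensors have 1/16 of it.

# Illustrative NVFP4 shape check; sizes are made up for the example.
num_experts, hidden_size, intermediate = 8, 4096, 14336
group_size = 16

w13_packed_shape = (num_experts, 2 * intermediate, hidden_size // 2)           # two FP4 values per byte
w13_scale_shape = (num_experts, 2 * intermediate, hidden_size // group_size)   # one fp8 scale per group

assert w13_packed_shape == (8, 28672, 2048)
assert w13_scale_shape == (8, 28672, 256)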