From d02387b6ac1029c56add853d6d742cd91a917a7f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 23 Jun 2025 16:20:01 +0000 Subject: [PATCH 1/9] initial comit --- vllm/model_executor/layers/fused_moe/layer.py | 13 +- .../compressed_tensors_moe.py | 252 +++++++++++++++++- 2 files changed, 257 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 133881fd049..39324506022 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1176,7 +1176,7 @@ def weight_loader(self, param: torch.nn.Parameter, full_load = len(loaded_weight.shape) == 3 if full_load: shard_dim += 1 - + """ # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): final_shape = list(loaded_weight.shape) @@ -1212,10 +1212,10 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return - - if "ModelOpt" in quant_method_name: - if ('weight_scale_2' in weight_name - or 'input_scale' in weight_name): + """ + if "ModelOpt" in quant_method_name or "compressed" in quant_method_name: + if ('weight_scale_2' in weight_name or 'input_scale' in weight_name + or "global" in weight_name): self._load_per_tensor_weight_scale(shard_id=shard_id, param=param, loaded_weight=loaded_weight, @@ -1228,7 +1228,7 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return - + """ # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): @@ -1283,6 +1283,7 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return + """ @staticmethod def select_experts(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f14131c5f05..c01c764e47b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -10,7 +10,8 @@ from compressed_tensors import CompressionFormat from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) - +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -29,7 +30,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types - +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) has_pplx = importlib.util.find_spec("pplx_kernels") is not None if current_platform.is_cuda_alike(): @@ -54,6 +57,7 @@ class GPTQMarlinState(Enum): "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod" ] @@ -86,6 +90,8 @@ def get_moe_method( else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4MoeMethod() elif 
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -97,6 +103,248 @@ def get_moe_method( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") +class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): + def __init__(self): + self.use_marlin = True + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + breakpoint() + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // + self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + + w13_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # GEMM 1 + if not torch.allclose(layer.w13_weight_scale_2[:, 0], + layer.w13_weight_scale_2[:, 1]): + logger.warning_once( + "w1_weight_scale_2 must match w3_weight_scale_2. 
" + "Accuracy may be affected.") + + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = torch.nn.Parameter(w13_weight_scale_2, + requires_grad=False) + + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32) + layer.g1_alphas = torch.nn.Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) + + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = self.swizzle_blockscale( + layer.w13_weight_scale) + + layer.w13_blockscale_swizzled = torch.nn.Parameter(w13_blockscale_swizzled, + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = torch.nn.Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) + + layer.w13_weight = torch.nn.Parameter(layer.w13_weight.data, + requires_grad=False) + + # GEMM 2 + layer.g2_alphas = torch.nn.Parameter( + (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = torch.nn.Parameter( + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + + layer.w2_blockscale_swizzled = torch.nn.Parameter(w2_blockscale_swizzled, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(layer.w2_weight.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + del layer.w13_blockscale_swizzled + del layer.w2_blockscale_swizzled + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." 
+ assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for ModelOptNvFp4FusedMoE.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( From eb69b46fd69a9373e08d8b6b7a2a5038d9e286ca Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 23 Jun 2025 19:53:54 +0000 Subject: [PATCH 2/9] fix all parameters Signed-off-by: Dipika Sikka --- vllm/model_executor/layers/fused_moe/layer.py | 17 +- .../compressed_tensors_moe.py | 215 +++++++++--------- 2 files changed, 120 insertions(+), 112 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 39324506022..7bd4311c86d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1176,7 +1176,7 @@ def weight_loader(self, param: torch.nn.Parameter, full_load = len(loaded_weight.shape) == 3 if full_load: shard_dim += 1 - """ + # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): final_shape = list(loaded_weight.shape) @@ -1186,6 +1186,7 @@ def weight_loader(self, param: torch.nn.Parameter, param.materialize(final_shape, dtype=loaded_weight.dtype) expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # this is needed for compressed-tensors only @@ -1212,10 +1213,11 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return - """ - if "ModelOpt" in quant_method_name or "compressed" in quant_method_name: - if ('weight_scale_2' in weight_name or 'input_scale' in weight_name - or "global" in weight_name): + + # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern + if "ModelOpt" in quant_method_name: + if ('weight_scale_2' in weight_name + or 'input_scale' in weight_name): self._load_per_tensor_weight_scale(shard_id=shard_id, param=param, loaded_weight=loaded_weight, @@ -1228,8 +1230,8 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return - """ - # Case weight scales, zero_points and offset + + # Case weight scales, zero_points and offset, weight/input global scales if ("scale" in weight_name or "zero" in 
weight_name or "offset" in weight_name): # load the weight scales and zp based on the quantization scheme @@ -1283,7 +1285,6 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data=expert_data, tp_rank=self.tp_rank) return - """ @staticmethod def select_experts(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c01c764e47b..e50dd1aca1f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -10,8 +10,7 @@ from compressed_tensors import CompressionFormat from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter) + import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -23,6 +22,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -30,9 +31,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, is_fp4_marlin_supported, - prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) + has_pplx = importlib.util.find_spec("pplx_kernels") is not None if current_platform.is_cuda_alike(): @@ -51,12 +50,10 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", "CompressedTensorsW4A4MoeMethod" ] @@ -104,157 +101,165 @@ def get_moe_method( class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): + def __init__(self): self.use_marlin = True + self.group_size = 16 def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") layer.num_experts = num_experts layer.params_dtype = params_dtype - layer.quant_config = self.quant_config - weight_dtype = torch.uint8 - weight_scale_dtype = torch.float8_e4m3fn - weight_loader = extra_weight_attrs.get("weight_loader") - breakpoint() - # GEMM 1 - w13_weight = ModelWeightParameter( - data=torch.empty( + + w13_weight = torch.nn.Parameter( + torch.empty( num_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, - dtype=weight_dtype), - input_dim=1, - output_dim=2, - weight_loader=weight_loader) - 
layer.register_parameter("w13_weight", w13_weight) + requires_grad=False, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) - # GEMM 2 - w2_weight = ModelWeightParameter( - data=torch.empty( + w2_weight = torch.nn.Parameter( + torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, - dtype=weight_dtype), - input_dim=1, - output_dim=2, - weight_loader=weight_loader) - layer.register_parameter("w2_weight", w2_weight) + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) - w13_weight_scale = ModelWeightParameter( - data=torch.empty( + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( num_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension - hidden_size // self.quant_config.group_size, - dtype=weight_scale_dtype), - input_dim=1, - output_dim=2, - weight_loader=weight_loader) + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w2_weight_scale = ModelWeightParameter( - data=torch.empty( + w2_weight_scale = torch.nn.Parameter( + torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // - self.quant_config.group_size, - dtype=weight_scale_dtype), - input_dim=1, - output_dim=2, - weight_loader=weight_loader) + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) - w13_weight_scale_2 = PerTensorScaleParameter( - data=torch.empty(num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, 2, dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) - w2_weight_scale_2 = PerTensorScaleParameter( - data=torch.empty(num_experts, dtype=torch.float32), - weight_loader=weight_loader) + w2_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, dtype=torch.float32), + requires_grad=False) layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) - w13_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_input_scale, extra_weight_attrs) - 
w2_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, dtype=torch.float32), - weight_loader=weight_loader) + w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 - if not torch.allclose(layer.w13_weight_scale_2[:, 0], - layer.w13_weight_scale_2[:, 1]): + # From packed to weight + layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, + requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data, + requires_grad=False) + + if not torch.allclose(layer.w13_weight_global_scale[:, 0], + layer.w13_weight_global_scale[:, 1]): logger.warning_once( - "w1_weight_scale_2 must match w3_weight_scale_2. " + "w1_weight_global_scale must match w3_weight_global_scale. " "Accuracy may be affected.") - w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + # Take inverse of global scale saved to disk + w13_weight_scale_2 = 1 / layer.w13_weight_global_scale[:, 0] + layer.w13_weight_scale_2 = torch.nn.Parameter(w13_weight_scale_2, - requires_grad=False) + requires_grad=False) - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( - torch.float32) - layer.g1_alphas = torch.nn.Parameter( - (w13_input_scale * w13_weight_scale_2).to(torch.float32), - requires_grad=False) + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False) - assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w13_blockscale_swizzled = self.swizzle_blockscale( - layer.w13_weight_scale) + #w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + # torch.float32) + #layer.g1_alphas = torch.nn.Parameter( + # (w13_input_scale * w13_weight_scale_2).to(torch.float32), + # requires_grad=False) - layer.w13_blockscale_swizzled = torch.nn.Parameter(w13_blockscale_swizzled, - requires_grad=False) + #w13_blockscale_swizzled = self.swizzle_blockscale( + # layer.w13_weight_scale) - # This is for quantization, so we need to invert it. - layer.w13_input_scale_quant = torch.nn.Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False) + #layer.w13_blockscale_swizzled = torch.nn.Parameter( + # w13_blockscale_swizzled, requires_grad=False) - layer.w13_weight = torch.nn.Parameter(layer.w13_weight.data, - requires_grad=False) + # This is for quantization, so we need to invert it. + #layer.w13_input_scale_quant = torch.nn.Parameter( + # (1 / w13_input_scale).to(torch.float32), requires_grad=False) # GEMM 2 - layer.g2_alphas = torch.nn.Parameter( - (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False) + #layer.g2_alphas = torch.nn.Parameter( + # (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + # requires_grad=False) # This is for quantization, so we need to invert it. 
- layer.w2_input_scale_quant = torch.nn.Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + #layer.w2_input_scale_quant = torch.nn.Parameter( + # (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) - assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + #w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) - layer.w2_blockscale_swizzled = torch.nn.Parameter(w2_blockscale_swizzled, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(layer.w2_weight.data, requires_grad=False) + #layer.w2_blockscale_swizzled = torch.nn.Parameter( + # w2_blockscale_swizzled, requires_grad=False) if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - del layer.g1_alphas - del layer.g2_alphas - del layer.w13_input_scale_quant - del layer.w2_input_scale_quant - del layer.w13_blockscale_swizzled - del layer.w2_blockscale_swizzled + #del layer.g1_alphas + #del layer.g2_alphas + #del layer.w13_input_scale_quant + #del layer.w2_input_scale_quant + #del layer.w13_blockscale_swizzled + #del layer.w2_blockscale_swizzled def apply( self, @@ -302,7 +307,7 @@ def apply( quant_type_id=scalar_types.float4_e2m1f.id, global_num_experts=global_num_experts, expert_map=expert_map) - + """ assert activation == "silu", "Only SiLU activation is supported." assert not apply_router_weight_on_input, ( "Router weight on input is not " @@ -344,6 +349,8 @@ def apply( a1_gscale=layer.w13_input_scale_quant, a2_gscale=layer.w2_input_scale_quant, device=x.device).to(x.dtype) + """ + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): From 186727b9393f55ecf01480efb282490c2f13b256 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 23 Jun 2025 20:39:03 +0000 Subject: [PATCH 3/9] enable cutlass Signed-off-by: Dipika Sikka --- .../compressed_tensors_moe.py | 93 ++++++++----------- 1 file changed, 37 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index e50dd1aca1f..dcf4cab9a15 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -222,44 +222,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_scale_2 = torch.nn.Parameter( 1 / layer.w2_weight_global_scale.data, requires_grad=False) - #w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( - # torch.float32) - #layer.g1_alphas = torch.nn.Parameter( - # (w13_input_scale * w13_weight_scale_2).to(torch.float32), - # requires_grad=False) + if not self.use_marlin: + # w13 + w13_input_scale = 1 / layer.w13_input_scale.max(dim=1).values.to( + torch.float32) - #w13_blockscale_swizzled = self.swizzle_blockscale( - # layer.w13_weight_scale) + layer.g1_alphas = torch.nn.Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) - #layer.w13_blockscale_swizzled = torch.nn.Parameter( - # w13_blockscale_swizzled, requires_grad=False) + layer.w13_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), + requires_grad=False) - # This is for 
quantization, so we need to invert it. - #layer.w13_input_scale_quant = torch.nn.Parameter( - # (1 / w13_input_scale).to(torch.float32), requires_grad=False) + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = torch.nn.Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) - # GEMM 2 - #layer.g2_alphas = torch.nn.Parameter( - # (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), - # requires_grad=False) + # w2 + layer.g2_alphas = torch.nn.Parameter( + ((1 / layer.w2_input_scale) * layer.w2_weight_scale_2).to( + torch.float32), + requires_grad=False) - # This is for quantization, so we need to invert it. - #layer.w2_input_scale_quant = torch.nn.Parameter( - # (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_scale).to(torch.float32), requires_grad=False) - #w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) - - #layer.w2_blockscale_swizzled = torch.nn.Parameter( - # w2_blockscale_swizzled, requires_grad=False) + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - #del layer.g1_alphas - #del layer.g2_alphas - #del layer.w13_input_scale_quant - #del layer.w2_input_scale_quant - #del layer.w13_blockscale_swizzled - #del layer.w2_blockscale_swizzled def apply( self, @@ -279,20 +273,20 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ): - if self.use_marlin: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + if self.use_marlin: return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -307,7 +301,7 @@ def apply( quant_type_id=scalar_types.float4_e2m1f.id, global_num_experts=global_num_experts, expert_map=expert_map) - """ + assert activation == "silu", "Only SiLU activation is supported." 
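# A hedged reading (sketch only, separate from the diff above) of the scale
# algebra in process_weights_after_loading(): the checkpoint stores per-expert
# *global* scales for weights and activations; the code keeps their inverses.
# weight_scale_2 = 1 / w_global, the activation global is passed through as
# the quantization scale (a1_gscale / a2_gscale for cutlass_moe_fp4), and the
# per-expert alpha handed to the kernel as w1_alphas / w2_alphas works out to
# 1 / (a_global * w_global). The helper name and the values below are
# illustrative only, not vLLM APIs.
import torch

def fold_nvfp4_global_scales(w_global: torch.Tensor, a_global: torch.Tensor):
    # For w13, the gate/up activation globals are reduced with .max(dim=1)
    # before reaching this point (see the patch above).
    weight_scale_2 = 1.0 / w_global            # per-expert weight dequant scale
    input_scale_quant = a_global               # used when quantizing activations
    alpha = (1.0 / a_global) * weight_scale_2  # == 1 / (a_global * w_global)
    return weight_scale_2, input_scale_quant, alpha

w_global = torch.tensor([448.0, 512.0])        # hypothetical per-expert values
a_global = torch.tensor([6.0, 5.5])
weight_scale_2, input_scale_quant, alpha = fold_nvfp4_global_scales(
    w_global, a_global)
assert torch.allclose(alpha, 1.0 / (a_global * w_global))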
assert not apply_router_weight_on_input, ( "Router weight on input is not " @@ -316,18 +310,6 @@ def apply( "is currently not supported for " "ModelOptNvFp4FusedMoE.") - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) @@ -349,7 +331,6 @@ def apply( a1_gscale=layer.w13_input_scale_quant, a2_gscale=layer.w2_input_scale_quant, device=x.device).to(x.dtype) - """ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): From 94d54059d3068e95f5648a0e6ac04299ed2a2bbc Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 23 Jun 2025 20:56:29 -0400 Subject: [PATCH 4/9] update Signed-off-by: Dipika --- .../compressed_tensors_moe.py | 66 ++++++++++++------- 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index dcf4cab9a15..a6d200db1bb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -198,6 +198,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) set_weight_attrs(w2_input_scale, extra_weight_attrs) + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # From packed to weight @@ -214,44 +237,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "Accuracy may be affected.") # Take inverse of global scale saved to disk - w13_weight_scale_2 = 1 / layer.w13_weight_global_scale[:, 0] - - layer.w13_weight_scale_2 = torch.nn.Parameter(w13_weight_scale_2, - requires_grad=False) + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) layer.w2_weight_scale_2 = torch.nn.Parameter( 1 / layer.w2_weight_global_scale.data, requires_grad=False) if not self.use_marlin: - # w13 - w13_input_scale = 1 / layer.w13_input_scale.max(dim=1).values.to( - torch.float32) - - layer.g1_alphas = torch.nn.Parameter( - (w13_input_scale * w13_weight_scale_2).to(torch.float32), - requires_grad=False) - + # swizzle weight scales 
layer.w13_blockscale_swizzled = torch.nn.Parameter( self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False) - # This is for quantization, so we need to invert it. - layer.w13_input_scale_quant = torch.nn.Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False) + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) + + # w13 + layer.g1_alphas = torch.nn.Parameter( + ((1 / (layer.w13_input_global_scale.max(dim=1).values.to( + torch.float32))) * layer.w13_weight_scale_2), + requires_grad=False) # w2 layer.g2_alphas = torch.nn.Parameter( - ((1 / layer.w2_input_scale) * layer.w2_weight_scale_2).to( - torch.float32), + ((1 / layer.w2_input_global_scale) * + layer.w2_weight_scale_2).to(torch.float32), requires_grad=False) - layer.w2_input_scale_quant = torch.nn.Parameter( - (layer.w2_input_scale).to(torch.float32), requires_grad=False) - - layer.w2_blockscale_swizzled = torch.nn.Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), + # Inverse Input Quant Scales + layer.w13_input_scale_quant = torch.nn.Parameter( + (layer.w13_input_global_scale.max(dim=1).values.to( + torch.float32)), requires_grad=False) + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_global_scale), requires_grad=False) + if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) From 8e9069185f3ce75ff31ccddac12cba0a6e65bbec Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 24 Jun 2025 14:04:39 -0400 Subject: [PATCH 5/9] temporarily enable sm100 build; fix condition for when to use marlin; move utils Signed-off-by: Dipika --- CMakeLists.txt | 2 +- .../compressed_tensors/compressed_tensors.py | 5 +- .../compressed_tensors_moe.py | 53 ++++++++++--------- .../schemes/compressed_tensors_w4a4_nvfp4.py | 13 +---- .../utils/nvfp4_emulation_utils.py | 15 +++++- 5 files changed, 47 insertions(+), 41 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 402131b7a1e..d63d831d082 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -545,7 +545,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled # if it's possible to compile MoE kernels that use its output. 
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d21abb2741a..9bb239d3cde 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -33,6 +33,8 @@ find_matched_target, is_activation_quantization_format, should_ignore_layer) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -375,8 +377,7 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if CompressedTensorsW4A4Fp4.cutlass_fp4_supported( - ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: + if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: logger.warning_once( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a6d200db1bb..9ba1feee382 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -26,6 +26,8 @@ prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -103,7 +105,7 @@ def get_moe_method( class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): def __init__(self): - self.use_marlin = True + self.use_marlin = not cutlass_fp4_supported() self.group_size = 16 def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -180,23 +182,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) - # Input Global Scales - w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_input_global_scale", w13_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w13_input_scale, extra_weight_attrs) + if not self.use_marlin: + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + 
set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_input_global_scale", w2_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w2_input_scale, extra_weight_attrs) + w2_input_scale = torch.nn.Parameter(torch.empty( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) def swizzle_blockscale(self, scale: torch.tensor): assert (scale.dtype == torch.float8_e4m3fn) @@ -254,23 +256,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False) # w13 + w13_input_global_scale = layer.w13_input_global_scale.max( + dim=1).values.to(torch.float32) + layer.g1_alphas = torch.nn.Parameter( - ((1 / (layer.w13_input_global_scale.max(dim=1).values.to( - torch.float32))) * layer.w13_weight_scale_2), + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), requires_grad=False) + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False) + + del w13_input_global_scale + # w2 layer.g2_alphas = torch.nn.Parameter( ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False) - # Inverse Input Quant Scales - layer.w13_input_scale_quant = torch.nn.Parameter( - (layer.w13_input_global_scale.max(dim=1).values.to( - torch.float32)), - requires_grad=False) - layer.w2_input_scale_quant = torch.nn.Parameter( (layer.w2_input_global_scale), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index ec1d4a6c0ef..65cbc49d264 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -5,8 +5,7 @@ from torch.nn.parameter import Parameter import vllm.envs as envs -from vllm._custom_ops import (cutlass_scaled_fp4_mm, - cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) @@ -15,7 +14,6 @@ from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -33,15 +31,6 @@ def get_min_capability(cls) -> int: return 80 return 100 - @classmethod - def cutlass_fp4_supported(cls) -> bool: - if not current_platform.is_cuda(): - return False - capability_tuple = current_platform.get_device_capability() - capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501 - ) - return cutlass_scaled_mm_supports_fp4(capability) - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 
d5ce6d7ad75..fb3287d3b89 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"] +__all__ = [ + "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", + "cutlass_fp4_supported" +] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() @@ -12,6 +17,14 @@ dtype=torch.float32) +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + def break_fp4_bytes(a, dtype): assert a.dtype == torch.uint8 m, n = a.shape From 2bae08eec1ccc2620131a39ce415598f7c111810 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 24 Jun 2025 14:22:57 -0400 Subject: [PATCH 6/9] remove if condition Signed-off-by: Dipika --- .../compressed_tensors_moe.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9ba1feee382..a6dd2af3a7e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -182,23 +182,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) - if not self.use_marlin: - # Input Global Scales - w13_input_scale = torch.nn.Parameter(torch.empty( - num_experts, 2, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_input_global_scale", w13_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w13_input_scale, extra_weight_attrs) + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.empty( - num_experts, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_input_global_scale", w2_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w2_input_scale, extra_weight_attrs) + w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) def swizzle_blockscale(self, scale: torch.tensor): assert (scale.dtype == torch.float8_e4m3fn) From f75fe6d595aee4dffecd59f81cdefa5cf57fcc48 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 
25 Jun 2025 09:24:26 -0400 Subject: [PATCH 7/9] Update CMakeLists.txt --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d63d831d082..402131b7a1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -545,7 +545,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled # if it's possible to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") From 766e2b4d5af4da83321eb1ae83f77165184d127b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 26 Jun 2025 13:55:07 +0000 Subject: [PATCH 8/9] PR comments Signed-off-by: Dipika Sikka --- .../compressed_tensors_moe.py | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a6dd2af3a7e..bd5957e15ea 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -245,40 +245,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_scale_2 = torch.nn.Parameter( 1 / layer.w2_weight_global_scale.data, requires_grad=False) - if not self.use_marlin: - # swizzle weight scales - layer.w13_blockscale_swizzled = torch.nn.Parameter( - self.swizzle_blockscale(layer.w13_weight_scale), - requires_grad=False) - - layer.w2_blockscale_swizzled = torch.nn.Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), - requires_grad=False) + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + return - # w13 - w13_input_global_scale = layer.w13_input_global_scale.max( - dim=1).values.to(torch.float32) + # swizzle weight scales + layer.w13_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), + requires_grad=False) - layer.g1_alphas = torch.nn.Parameter( - ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), - requires_grad=False) + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) - layer.w13_input_scale_quant = torch.nn.Parameter( - (w13_input_global_scale), requires_grad=False) + # w13 + w13_input_global_scale = layer.w13_input_global_scale.max( + dim=1).values.to(torch.float32) - del w13_input_global_scale + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False) - # w2 - layer.g2_alphas = torch.nn.Parameter( - ((1 / layer.w2_input_global_scale) * - layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False) + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False) - layer.w2_input_scale_quant = torch.nn.Parameter( - (layer.w2_input_global_scale), requires_grad=False) + # w2 + layer.g2_alphas = torch.nn.Parameter( + ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( + torch.float32), + requires_grad=False) - if 
self.use_marlin: - prepare_moe_fp4_layer_for_marlin(layer) + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_global_scale), requires_grad=False) def apply( self, @@ -330,10 +328,10 @@ def apply( assert activation == "silu", "Only SiLU activation is supported." assert not apply_router_weight_on_input, ( "Router weight on input is not " - "supported for ModelOptNvFp4FusedMoE.") + "supported for CompressedTensorsW4A4MoeMethod.") assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " - "ModelOptNvFp4FusedMoE.") + "CompressedTensorsW4A4MoeMethod.") from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) From 7f6b39b828d3e53633d4fb7163cef00ac934d585 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 26 Jun 2025 14:15:11 +0000 Subject: [PATCH 9/9] format Signed-off-by: Dipika Sikka --- .../quantization/compressed_tensors/compressed_tensors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 9bb239d3cde..4f87b2a44f0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -377,7 +377,8 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS: + if cutlass_fp4_supported( + ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: logger.warning_once(