diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 0ae1142a3..22ee05af5 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -37,7 +37,8 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group, tensor_model_parallel_all_reduce)
+                              get_tp_group, tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_all_gather)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -76,6 +77,10 @@
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
+IS_FC3 = False
+import os
+if os.getenv('IS_FC3') == "1":
+    IS_FC3 = True
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -150,12 +155,14 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+
         if not force_replicate:
             self.gate_up_proj = MergedColumnParallelLinear(
                 hidden_size, [intermediate_size] * 2,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.gate_up_proj")
+                prefix=f"{prefix}.gate_up_proj",
+                is_fc3=IS_FC3)
             self.down_proj = RowParallelLinear(intermediate_size,
                                                hidden_size,
                                                bias=False,
@@ -316,7 +323,6 @@ def forward(
         old_hidden_states = hidden_states
         use_separated_shared_experts = (self.shared_experts is not None
                                         and not self.enable_multistream_moe)
-
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
                     and not is_prefill) or not (self.torchair_graph_enabled or
@@ -330,6 +336,13 @@
             hidden_states = chunk_hidden_states[self.tp_rank]
 
         # router_logits: (num_tokens, n_experts)
+        # if IS_FC3 and self.tp_size > 1:
+        #     import torchair as tng
+        #     with tng.scope.npu_stream_switch('21'):
+        #         hidden_states = tng.scope.npu_wait_tensor(hidden_states, hidden_states)
+        #         router_logits, _ = self.gate(hidden_states)
+        # else:
+        #     router_logits, _ = self.gate(hidden_states)
         router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
@@ -340,6 +353,7 @@
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=(self.shared_experts
                             if not use_separated_shared_experts else None),
+            is_fc3=IS_FC3
         )
 
         if not isinstance(experts_hidden_states, tuple):
@@ -349,6 +363,8 @@
                 experts_hidden_states[0] * self.routed_scaling_factor +
                 experts_hidden_states[1])
 
+        if IS_FC3:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
                     and not is_prefill) or not (self.torchair_graph_enabled or
@@ -362,10 +378,20 @@
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if use_separated_shared_experts:
-            hidden_states = hidden_states + self.shared_experts(
-                old_hidden_states)
-
-        return hidden_states.view(num_tokens, hidden_size)
+            # if IS_FC3 and self.tp_size > 1:
+            #     import torchair as tng
+            #     with tng.scope.npu_stream_switch('fc1'):
+            #         hidden_states = tng.scope.npu_wait_tensor(hidden_states, old_hidden_states)
+            #         old_hidden_states = self.shared_experts(old_hidden_states)
+            # else:
+
+            old_hidden_states = self.shared_experts(old_hidden_states)
+            hidden_states = hidden_states + old_hidden_states
+        if IS_FC3:
+            output = hidden_states.view(num_tokens * self.tp_size, hidden_size)
+        else:
+            output = hidden_states.view(num_tokens, hidden_size)
+        return output
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
@@ -446,7 +472,8 @@ def __init__(
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj")
+            prefix=f"{prefix}.o_proj",
+            is_fc3=IS_FC3)
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -649,8 +676,11 @@ def forward(
             residual *= 1. / self.routed_scaling_factor
 
         # Fully Connected
+
         hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+            hidden_states, residual, is_fc3=IS_FC3)
+        if IS_FC3 and get_tensor_model_parallel_world_size() > 1:
+            residual = tensor_model_parallel_all_gather(residual, dim=0)
 
         if isinstance(self.mlp, CustomDeepseekV2MoE):
             hidden_states = self.mlp(hidden_states, attn_metadata)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index d6115d35c..1452a4534 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -923,7 +923,7 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
         )
-
+        topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -1111,13 +1111,15 @@ def __init__(
         self.ep_group = get_ep_group()
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
-                shared_experts: Optional[Any] = None):
+                shared_experts: Optional[Any] = None,
+                is_fc3=False):
         assert self.quant_method is not None
 
         if top_k:
@@ -1146,6 +1148,7 @@
             hidden_states, router_logits)
 
         # Matrix multiply.
+
         e_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
@@ -1165,6 +1168,7 @@
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
             shared_experts=shared_experts,
+            is_fc3=is_fc3
         )
 
         if shared_experts is not None:
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 8ff4c559e..9cfbde480 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -19,17 +19,27 @@
 
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
-
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              split_tensor_along_first_dim)
 
 def forward_oot(
     self,
     x: torch.Tensor,
     residual: Optional[torch.Tensor] = None,
+    is_fc3=False
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     import torch_npu
 
+    if is_fc3:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        if residual is not None:
+            residual_split = split_tensor_along_first_dim(residual, num_partitions=tp_size)[tp_rank].contiguous()
+
     if residual is not None:
-        x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight,
+        x, _, residual = torch_npu.npu_add_rms_norm(x, residual_split if is_fc3 else residual, self.weight,
                                                     self.variance_epsilon)
         return x, residual
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py
index 08a4b608e..e8e745c02 100644
--- a/vllm_ascend/patch/worker/patch_common/__init__.py
+++ b/vllm_ascend/patch/worker/patch_common/__init__.py
@@ -1,22 +1,3 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# patch_utils should be the first import, because it will be used by other
-# patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_eagle  # noqa
@@ -24,3 +5,6 @@
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
+# Register the patched linear layers (RowParallelLinear /
+# MergedColumnParallelLinear) that understand the is_fc3 flag.
+import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/patch_distributed.py b/vllm_ascend/patch/worker/patch_common/patch_distributed.py
index 846d82cec..008e7f6c5 100644
--- a/vllm_ascend/patch/worker/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/worker/patch_common/patch_distributed.py
@@ -19,8 +19,9 @@
 
 import torch
 import vllm
+from vllm.distributed import divide
 from vllm.distributed.parallel_state import GroupCoordinator
-
+from typing import Sequence
 
 class GroupCoordinatorPatch(GroupCoordinator):
 
@@ -46,4 +47,32 @@ def all_to_all(self,
                              gather_sizes)
 
 
-vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
\ No newline at end of file
+def split_tensor_along_first_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """Split a tensor along its first dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor into.
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors.
+    """
+    # Get the size and dimension.
+    first_dim = 0
+    first_dim_size = divide(tensor.size()[first_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, first_dim_size, dim=first_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
+vllm.distributed.split_tensor_along_first_dim = split_tensor_along_first_dim
diff --git a/vllm_ascend/patch/worker/patch_common/patch_linear.py b/vllm_ascend/patch/worker/patch_common/patch_linear.py
new file mode 100644
index 000000000..dad964978
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_common/patch_linear.py
@@ -0,0 +1,164 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import vllm
+from vllm.model_executor.layers.linear import RowParallelLinear, ColumnParallelLinear, MergedColumnParallelLinear
+import itertools
+from abc import abstractmethod
+from typing import Any, Literal, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter, UninitializedParameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
+# yapf: disable
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           BlockQuantScaleParameter,
+                                           PackedColumnParameter,
+                                           PackedvLLMParameter,
+                                           PerTensorScaleParameter,
+                                           RowvLLMParameter)
+# yapf: enable
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class RowParallelLinearPatch(RowParallelLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        is_fc3: bool = False,
+    ):
+        super().__init__(
+            input_size,
+            output_size,
+            bias,
+            input_is_parallel,
+            skip_bias_add,
+            params_dtype,
+            reduce_results,
+            quant_config,
+            prefix,
+            return_bias=return_bias)
+
+        self.is_fc3 = is_fc3
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+
+        if self.is_fc3 and self.tp_size > 1:
+            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
+        elif self.reduce_results and self.tp_size > 1:
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class MergedColumnParallelLinearPatch(MergedColumnParallelLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        is_fc3: bool = False,
+    ):
+        self.output_sizes = output_sizes
+        tp_size = get_tensor_model_parallel_world_size()
+        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        super().__init__(
+            input_size,
+            output_sizes,
+            bias,
+            gather_output,
+            skip_bias_add,
+            params_dtype,
+            quant_config,
+            prefix,
+            return_bias=return_bias,
+        )
+        self.is_fc3 = is_fc3
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_, bias, is_fc3=self.is_fc3)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+vllm.model_executor.layers.linear.RowParallelLinear = RowParallelLinearPatch
+vllm.model_executor.layers.linear.MergedColumnParallelLinear = MergedColumnParallelLinearPatch
diff --git a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py
index 77ecca2b1..1a48c9baf 100644
--- a/vllm_ascend/quantization/func_wrapper.py
+++ b/vllm_ascend/quantization/func_wrapper.py
@@ -43,6 +43,7 @@ def _rmsnorm_forward_oot(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
+        is_fc3=False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not self.ignore_anti:
             if residual is not None:
@@ -66,10 +67,11 @@
                 )
                 return out
 
+
         if residual is not None:
-            x, residual = func(self, x, residual)
-            return x.add_(self.bias), residual
+            x, residual = func(self, x, residual, is_fc3=is_fc3)
+            return x.add_(self.bias), residual
 
         return func(self, x).add_(self.bias)
 
     return _rmsnorm_forward_oot
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 3567dba35..f4d88dc71 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -23,12 +23,13 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional
 
 import torch
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+                                                  FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               UnquantizedLinearMethod,
+                                               MergedColumnParallelLinear)
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import (
@@ -203,10 +204,14 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
+        is_fc3: bool = False,
     ) -> torch.Tensor:
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
+        elif isinstance(layer, MergedColumnParallelLinear):
+            is_tp = get_tensor_model_parallel_world_size() > 1
+            return self.quant_method.apply(layer, x, bias, is_fc3=(is_fc3 and is_tp))
         return self.quant_method.apply(layer, x, bias)
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 66a0a302c..8dfdb5ce5 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -20,8 +20,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
-
+from vllm.distributed import GroupCoordinator, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -31,7 +30,6 @@
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
-
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -459,11 +457,9 @@ def fused_experts(hidden_states: torch.Tensor,
         final_hidden_states = final_hidden_states.view(original_shape)
 
     return final_hidden_states
 
-
 class AscendW8A8DynamicLinearMethod:
     """Linear method for Ascend W8A8_DYNAMIC.
     """
-
    def __init__(self):
        self.transpose_weight = True
@@ -499,6 +495,7 @@ def apply(
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
+       is_fc3: bool = False,
    ) -> torch.Tensor:
        config = getattr(layer, "_ascend_quant_config", {})
        if not isinstance(x, tuple):
@@ -510,6 +507,10 @@
                f"for pre-quantized input, got config [{config}]")
            output_dtype = config["output_dtype"]
            quantized_x, dynamic_scale = x
+           is_tp = get_tensor_model_parallel_world_size() > 1
+           if is_fc3 and is_tp:
+               quantized_x = tensor_model_parallel_all_gather(quantized_x, 0)
+               dynamic_scale = tensor_model_parallel_all_gather(dynamic_scale)
        pertoken_scale = (dynamic_scale
                          if config.get("pertoken_scale", True) else None)
@@ -652,6 +653,7 @@
            e_score_correction_bias=e_score_correction_bias,
        )
 
+
        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
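
Note on the token-dimension sharding this patch relies on: in the FC3 path, activations stay split along the first (token) dimension between the attention output projection (reduce-scatter in RowParallelLinearPatch) and the point where they are all-gathered back before the MoE gate and gate_up_proj. The standalone sketch below mimics that first-dimension split/gather round trip on a single process with plain torch ops; it is an illustration only, not part of the patch, and tp_size, the tensor shape, and split_along_first_dim are made-up stand-ins for the patched split_tensor_along_first_dim / tensor_model_parallel_all_gather(..., dim=0) pair.

# Illustrative sketch only -- not part of the patch. Each TP "rank" would hold
# one shard along dim 0; concatenating on dim 0 plays the role of the all-gather
# and restores the full [num_tokens, hidden_size] tensor.
import torch


def split_along_first_dim(tensor: torch.Tensor, num_partitions: int):
    # Equal chunks along dim 0, like the patched split_tensor_along_first_dim.
    first_dim_size = tensor.size(0) // num_partitions
    return torch.split(tensor, first_dim_size, dim=0)


tp_size = 4                     # hypothetical TP degree
hidden = torch.randn(8, 16)     # [num_tokens, hidden_size]; num_tokens % tp_size == 0

shards = [s.contiguous() for s in split_along_first_dim(hidden, tp_size)]
gathered = torch.cat(shards, dim=0)   # "all-gather" along dim 0
assert torch.equal(gathered, hidden)
print(gathered.shape)           # torch.Size([8, 16])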