[WIP]FC3 #1377

Open · wants to merge 9 commits into main
48 changes: 39 additions & 9 deletions vllm_ascend/models/deepseek_v2.py
@@ -37,7 +37,8 @@
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
get_tp_group, tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
@@ -76,6 +77,10 @@

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

import os

IS_FC3 = os.getenv("IS_FC3") == "1"

class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -150,12 +155,14 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

if not force_replicate:
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
is_fc3 = IS_FC3)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
@@ -316,7 +323,6 @@ def forward(
old_hidden_states = hidden_states
use_separated_shared_experts = (self.shared_experts is not None
and not self.enable_multistream_moe)

if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
@@ -330,6 +336,13 @@
hidden_states = chunk_hidden_states[self.tp_rank]

# router_logits: (num_tokens, n_experts)
# if IS_FC3 and self.tp_size > 1:
# import torchair as tng
# with tng.scope.npu_stream_switch('21'):
# hidden_states = tng.scope.npu_wait_tensor(hidden_states,hidden_states)
# router_logits, _ = self.gate(hidden_states)
# else:
# router_logits, _ = self.gate(hidden_states)
router_logits, _ = self.gate(hidden_states)

experts_hidden_states = self.experts(
@@ -340,6 +353,7 @@
enable_force_load_balance=enable_force_load_balance,
shared_experts=(self.shared_experts
if not use_separated_shared_experts else None),
is_fc3=IS_FC3
)

if not isinstance(experts_hidden_states, tuple):
@@ -349,6 +363,8 @@
experts_hidden_states[0] * self.routed_scaling_factor +
experts_hidden_states[1])

if IS_FC3:
hidden_states = tensor_model_parallel_all_gather(hidden_states,0)
if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
@@ -362,10 +378,20 @@
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if use_separated_shared_experts:
hidden_states = hidden_states + self.shared_experts(
old_hidden_states)

return hidden_states.view(num_tokens, hidden_size)
# if IS_FC3 and self.tp_size > 1:
# import torchair as tng
# with tng.scope.npu_stream_switch('fc1'):
# hidden_states = tng.scope.npu_wait_tensor(hidden_states,old_hidden_states)
# old_hidden_states = self.shared_experts(old_hidden_states)
# else:

old_hidden_states = self.shared_experts(old_hidden_states)
hidden_states = hidden_states + old_hidden_states
if IS_FC3:
output = hidden_states.view(num_tokens * self.tp_size, hidden_size)
else:
output = hidden_states.view(num_tokens, hidden_size)
return output


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
@@ -446,7 +472,8 @@ def __init__(
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
prefix=f"{prefix}.o_proj",
is_fc3=IS_FC3)

if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -649,8 +676,11 @@ def forward(
residual *= 1. / self.routed_scaling_factor

# Fully Connected

hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual,is_fc3=IS_FC3)
if IS_FC3 and get_tensor_model_parallel_world_size()>1:
residual = tensor_model_parallel_all_gather(residual, dim=0)

if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, attn_metadata)
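A minimal single-process sketch of the token-split round trip this file assumes when `IS_FC3=1` is set in the environment: hidden states reach the MoE already split along the first (token) dimension, each TP rank computes on its own slice, and `tensor_model_parallel_all_gather(..., 0)` restores the full token count, which is why the output view uses `num_tokens * self.tp_size`. The `tp_size` value, shapes, and the identity per-rank compute below are illustrative only, not vLLM APIs.

```python
import torch

def fc3_round_trip(full_hidden: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Simulate split -> per-rank compute -> gather on one process (no real TP group)."""
    total_tokens, hidden_size = full_hidden.shape
    assert total_tokens % tp_size == 0, "token count must divide evenly across ranks"

    # 1) each rank holds total_tokens // tp_size rows (the layernorm patch does this split)
    per_rank = torch.chunk(full_hidden, tp_size, dim=0)

    # 2) per-rank compute on the local slice (identity stands in for gate/experts)
    per_rank_out = [chunk.clone() for chunk in per_rank]

    # 3) gather along dim 0, mirroring tensor_model_parallel_all_gather(hidden_states, 0)
    gathered = torch.cat(per_rank_out, dim=0)

    # per-rank num_tokens * tp_size == total_tokens, hence the final view in the diff
    return gathered.view(total_tokens, hidden_size)

x = torch.randn(8, 16)
assert torch.equal(fc3_round_trip(x, tp_size=4), x)
```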
8 changes: 6 additions & 2 deletions vllm_ascend/ops/fused_moe.py
@@ -923,7 +923,7 @@ def apply(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
@@ -1111,13 +1111,15 @@ def __init__(
self.ep_group = get_ep_group()
self.quant_method.create_weights(layer=self, **moe_quant_params)


def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None):
shared_experts: Optional[Any] = None,
is_fc3=False):
assert self.quant_method is not None

if top_k:
Expand Down Expand Up @@ -1146,6 +1148,7 @@ def forward(self,
hidden_states, router_logits)

# Matrix multiply.

e_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
@@ -1165,6 +1168,7 @@
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts,
is_fc3=is_fc3
)

if shared_experts is not None:
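A sketch of the routing step that the `topk_weights.to(x.dtype)` cast above belongs to: router logits are typically softmaxed in float32, so the selected weights are cast back to the activation dtype before the weighted expert combine. The softmax-then-topk order, renormalization, and shapes are assumptions for illustration, not the exact expert-selection code used here.

```python
import torch

def select_experts(router_logits: torch.Tensor, top_k: int, x_dtype: torch.dtype):
    # router probabilities in float32 for numerical stability
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # match the activation dtype, as the diff does with topk_weights.to(x.dtype)
    return topk_weights.to(x_dtype), topk_ids

w, ids = select_experts(torch.randn(4, 8), top_k=2, x_dtype=torch.bfloat16)
assert w.dtype == torch.bfloat16 and ids.shape == (4, 2)
```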
14 changes: 12 additions & 2 deletions vllm_ascend/ops/layernorm.py
@@ -19,17 +19,27 @@

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
split_tensor_along_first_dim)

def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
is_fc3=False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

if is_fc3:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
if residual is not None:
residual_split = split_tensor_along_first_dim(residual, num_partitions=tp_size)[tp_rank].contiguous()

if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight,
x, _, residual = torch_npu.npu_add_rms_norm(x, residual_split if is_fc3 else residual, self.weight,
self.variance_epsilon)
return x, residual

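A plain-PyTorch reference for the FC3 branch of `forward_oot` above, assuming `torch_npu.npu_add_rms_norm` fuses `residual_out = x + residual` with RMSNorm and returns both. Under FC3 the full residual is first narrowed to this rank's token slice so its first dimension matches the already-split hidden states; the `tp_size`/`tp_rank` values and shapes below are illustrative.

```python
import torch

def add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                 weight: torch.Tensor, eps: float = 1e-6):
    # fused kernel semantics assumed: add the residual, then RMS-normalize
    residual_out = x + residual
    variance = residual_out.to(torch.float32).pow(2).mean(-1, keepdim=True)
    y = residual_out.to(torch.float32) * torch.rsqrt(variance + eps)
    return y.to(x.dtype) * weight, residual_out

tp_size, tp_rank = 4, 1
x = torch.randn(2, 16)                          # per-rank hidden states (already FC3-split)
residual = torch.randn(8, 16)                   # full residual: 8 tokens
residual_split = residual.chunk(tp_size, dim=0)[tp_rank].contiguous()
y, new_residual = add_rms_norm(x, residual_split, torch.ones(16))
assert y.shape == new_residual.shape == (2, 16)
```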
22 changes: 3 additions & 19 deletions vllm_ascend/patch/worker/patch_common/__init__.py
@@ -1,26 +1,10 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# patch_utils should be the first import, because it will be used by other
# patch files.
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_eagle # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker
# FC3: register the patched linear layers (is_fc3 support)
import vllm_ascend.patch.worker.patch_common.patch_linear
33 changes: 31 additions & 2 deletions vllm_ascend/patch/worker/patch_common/patch_distributed.py
Original file line number Diff line number Diff line change
@@ -19,8 +19,9 @@

import torch
import vllm
from vllm.distributed import divide
from vllm.distributed.parallel_state import GroupCoordinator

from typing import Any, Deque, Dict, Optional, Sequence, Tuple

class GroupCoordinatorPatch(GroupCoordinator):

@@ -46,4 +47,32 @@ def all_to_all(self,
gather_sizes)


vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
def split_tensor_along_first_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.

Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.

Returns:
A list of Tensors
"""
# Get the size and dimension.
first_dim = 0
first_dim_size = divide(tensor.size()[first_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, first_dim_size, dim=first_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)

return tensor_list

vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
vllm.distributed.split_tensor_along_first_dim = split_tensor_along_first_dim
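A quick shape check for `split_tensor_along_first_dim`, using `torch.split` directly (which is what the helper wraps); `divide()` requires the first dimension to be divisible by `num_partitions`, and `contiguous_split_chunks=True` corresponds to the `.contiguous()` pass below.

```python
import torch

t = torch.randn(8, 16)
num_partitions = 4
chunk_size = t.size(0) // num_partitions        # divide() also asserts divisibility
chunks = torch.split(t, chunk_size, dim=0)
chunks = tuple(c.contiguous() for c in chunks)  # contiguous_split_chunks=True
assert len(chunks) == num_partitions
assert all(c.shape == (2, 16) for c in chunks)
```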