add qwen3 local chat #1280

Status: Open · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions ktransformers/local_chat.py
@@ -25,6 +25,7 @@
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
@@ -37,6 +38,7 @@
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
     "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+    "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "MixtralForCausalLM": MixtralForCausalLM,
 }
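The effect of these two added lines is that a checkpoint whose `config.json` declares the `Qwen3MoeForCausalLM` architecture can now be resolved by `local_chat`. A minimal sketch of that lookup, assuming the table shown above is the module-level `custom_models` dict in `ktransformers/local_chat.py` (its name is not visible in this hunk) and using a placeholder checkpoint path:

```python
# Sketch only: `custom_models` is assumed to be the module-level dict shown
# above; the checkpoint path is a placeholder.
from transformers import AutoConfig
from ktransformers.local_chat import custom_models

config = AutoConfig.from_pretrained("/path/to/Qwen3-MoE-checkpoint", trust_remote_code=True)
arch = config.architectures[0]    # "Qwen3MoeForCausalLM" for Qwen3 MoE checkpoints
model_cls = custom_models[arch]   # resolves only after this PR adds the entry
```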
6 changes: 4 additions & 2 deletions ktransformers/models/modeling_qwen3_moe.py
@@ -185,9 +185,10 @@ def forward(
         hidden_states: torch.Tensor,
         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
+        position_ids: Optional[torch.LongTensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         # **kwargs: Unpack[FlashAttentionKwargs],
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -196,7 +197,8 @@ def forward(
         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        cos, sin = position_embeddings
+        # cos, sin = position_embeddings
+        cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
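The second hunk switches the attention layer from consuming precomputed `position_embeddings` to recomputing `cos`/`sin` from the new `position_ids` argument via `self.rotary_emb`. A minimal sketch of what such a rotary call produces, following the standard Hugging Face rotary-embedding recipe (the `head_dim` and `base` values here are illustrative, not read from the Qwen3 config):

```python
import torch

# Illustrative sizes; a real model reads these from its config.
head_dim, base = 64, 10_000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

position_ids = torch.arange(8).unsqueeze(0)              # [1, seq]
freqs = position_ids[..., None].float() * inv_freq       # [1, seq, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                  # [1, seq, head_dim]
cos, sin = emb.cos(), emb.sin()                          # what apply_rotary_pos_emb consumes
```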
106 changes: 105 additions & 1 deletion ktransformers/operators/experts.py
@@ -1494,4 +1494,108 @@ def moe_infer(self, x, topk_ids, topk_weight):
             .sum(dim=1)
             .type(new_x.dtype)
         )
-        return final_out
+        return final_out
+
+
+class KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Route each token to its top-k experts and return (hidden_states, router_logits)."""
+        orig_shape = hidden_states.shape
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        # Unlike Qwen2-MoE, Qwen3-MoE has no shared expert, so the
+        # shared-expert path of KQwen2MoeSparseMoeBlock is dropped here.
+        # Single-token decode fast path: submit the token to the expert
+        # backend asynchronously, then block on its result.
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
+            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])
+            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
+            y.resize_(*orig_shape)
+            return y, router_logits
+
+        # Move the routing tensors to wherever the expert backend lives.
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()
+
+        if isinstance(self.experts, KExpertsBase):
+            y = (
+                self.moe_kexperts(
+                    hidden_states_expert, selected_experts_expert, routing_weights_expert
+                )
+                .view(*orig_shape)
+                .to(device=hidden_states.device)
+            )
+        elif hidden_states_expert.size(0) > 10:
+            # Batched dispatch amortizes better once there are enough tokens.
+            y = self.moe_infer(
+                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape
+            ).to(device=hidden_states.device)
+        else:
+            y = self.moe_infer_simple(
+                hidden_states_expert, selected_experts_expert, routing_weights_expert
+            ).to(device=hidden_states.device)
+        y.resize_(*orig_shape)
+        return y, router_logits
+
+    @torch.no_grad()
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
+        outs = self.experts(x, topk_ids, topk_weight)
+        return outs
+
+    # TODO: there may be bugs here
+    @torch.no_grad()
+    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
+        '''
+        Per-token fallback for small batches.
+        hidden_states_cpu: [num_tokens, hidden_size]
+        selected_experts_cpu, routing_weights_cpu: [num_tokens, num_selected_experts]
+        '''
+        outs = torch.zeros_like(hidden_states_cpu)
+        for token_idx in range(selected_experts_cpu.size(0)):
+            for expert_idx in range(selected_experts_cpu.size(1)):
+                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]
+                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]
+        return outs
+
+    # TODO: there may be bugs here
+    @torch.no_grad()
+    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = orig_shape
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
+        )
+
+        # One-hot encode the selected experts to create an expert mask,
+        # used to index which expert is going to be solicited.
+        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts and run each one only on the tokens
+        # that selected it.
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the hidden states routed to the current expert and weight
+            # its output by the corresponding routing weights.
+            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]
+
+            # `index_add_` only supports torch tensors for indexing, so we use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
+
+        return final_hidden_states
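The batched path in `moe_infer` follows the familiar Hugging Face MoE dispatch pattern: one-hot encode the selected expert ids, gather each expert's token rows, and scatter the weighted outputs back with `index_add_`. A self-contained sketch of that pattern with toy sizes and identity "experts", checked against the per-token loop used by `moe_infer_simple` (all names and sizes here are illustrative, not taken from the PR):

```python
import torch

num_tokens, hidden, num_experts, top_k = 5, 4, 3, 2
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(num_experts, (num_tokens, top_k))
topk_w = torch.rand(num_tokens, top_k)

# Batched dispatch, as in moe_infer above (experts are identity functions here).
out = torch.zeros_like(x)
mask = torch.nn.functional.one_hot(topk_ids, num_classes=num_experts).permute(2, 1, 0)
for e in range(num_experts):
    idx, top_x = torch.where(mask[e])        # (slot, token) pairs that chose expert e
    out.index_add_(0, top_x, x[top_x] * topk_w[top_x, idx, None])

# Reference: the per-token loop, as in moe_infer_simple.
ref = torch.zeros_like(x)
for t in range(num_tokens):
    for k in range(top_k):
        ref[t] += x[t] * topk_w[t, k]

torch.testing.assert_close(out, ref)         # both dispatch strategies agree
```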