simple optimization for moonlight moe decoding forward (#12891)

Author: Yishuo Wang, 2025-02-25 16:18:27 +08:00 (committed by GitHub)
parent ae9f5320da
commit 5faba06409
2 changed files with 34 additions and 0 deletions

ipex_llm/transformers/convert.py

@@ -2031,9 +2031,11 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.deepseek import deepseek_model_forward
         from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
+        from ipex_llm.transformers.models.deepseek import deepseek_moe_forward
         convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
         convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
         convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
+        convert_forward(model, module.DeepseekV3MoE, deepseek_moe_forward)
         return model
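
For context, convert_forward here rebinds an optimized forward function onto every submodule of a given class. Below is a minimal sketch of that pattern, assuming a plain PyTorch module tree; the helper name patch_forward_by_type and the types.MethodType rebinding are illustrative assumptions, not necessarily how ipex-llm's convert_forward is implemented.

import types
import torch.nn as nn

def patch_forward_by_type(model: nn.Module, target_cls: type, new_forward) -> nn.Module:
    # Walk the module tree and bind `new_forward` as an instance method on every
    # submodule whose class matches `target_cls`, replacing its original forward.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)
    return model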

ipex_llm/transformers/models/deepseek.py

@@ -269,3 +269,35 @@ def deepseek_attention_forward(
         attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
+    idxs = topk_ids.flatten().tolist()
+    outputs = []
+    for i in idxs:
+        expert = self.experts[i]
+        expert_out = expert(x)
+        outputs.append(expert_out)
+    outs = torch.cat(outputs, dim=0)
+    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    return final_out
+
+
+def deepseek_moe_forward(self, hidden_states: torch.Tensor):
+    identity = hidden_states
+    orig_shape = hidden_states.shape
+    topk_idx, topk_weight = self.gate(hidden_states)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    flat_topk_idx = topk_idx.view(-1)
+    if not self.training:
+        # IPEX-LLM OPT start : add special moe_infer implementation for decoding
+        if topk_idx.size(0) == 1:
+            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
+        else:
+            y = self.moe_infer(hidden_states, topk_idx, topk_weight)
+        y = y.view(*orig_shape)
+        # IPEX-LLM OPT end
+    if self.config.n_shared_experts is not None:
+        y = y + self.shared_experts(identity)
+    return y
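
Design note: topk_idx.size(0) == 1 holds when the gate processed a single token, the typical shape during autoregressive decoding (batch 1, one new token). In that case moe_infer_decode simply runs the token through its selected experts and takes a weighted sum, avoiding the batched dispatch bookkeeping of the generic self.moe_infer path used for prefill.

To see the shapes: x is (1, hidden), each selected expert returns (1, hidden), torch.cat stacks them into (num_experts_per_tok, hidden), and topk_weight of shape (1, num_experts_per_tok) is reshaped to (num_experts_per_tok, 1) so the weighted sum collapses back to (1, hidden). A standalone sketch with toy nn.Linear experts (an assumption; the real experts are DeepseekV3 MLP blocks):

import torch
import torch.nn as nn

hidden, top_k = 8, 2
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(4))  # toy stand-ins

x = torch.randn(1, hidden)              # one decoding token
topk_ids = torch.tensor([[1, 3]])       # (1, top_k) expert indices from the gate
topk_weight = torch.rand(1, top_k)      # (1, top_k) routing weights

# Same combination as moe_infer_decode above
outs = torch.cat([experts[i](x) for i in topk_ids.flatten().tolist()], dim=0)  # (top_k, hidden)
w = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)                        # (top_k, 1)
final_out = (outs * w).sum(dim=0, keepdim=True)                                # (1, hidden)
print(final_out.shape)  # torch.Size([1, 8])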