simple optimization for moonlight moe decoding forward (#12891)
parent ae9f5320da · commit 5faba06409
2 changed files with 34 additions and 0 deletions
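Context (not part of the diff): the two hunks below make Moonlight / DeepSeek-V3 MoE blocks take a lighter per-token expert loop during decoding. The `_optimize_post` patching runs when a model is loaded through ipex-llm's transformers API; the sketch below is an assumed usage example, not from this commit (the checkpoint path and generation settings are placeholders).

# Usage sketch (assumed): load a Moonlight-style checkpoint through ipex-llm,
# which applies _optimize_post with the default optimize_model=True, then decode
# a few tokens to exercise the new MoE decode path.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "path/to/moonlight-checkpoint"   # placeholder path, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,       # low-bit load path
                                             trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))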
@@ -2031,9 +2031,11 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.deepseek import deepseek_model_forward
         from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
+        from ipex_llm.transformers.models.deepseek import deepseek_moe_forward
         convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
         convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
         convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
+        convert_forward(model, module.DeepseekV3MoE, deepseek_moe_forward)

     return model
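The convert_forward calls above rely on ipex-llm's module-forward replacement mechanism: every submodule of the given class has its forward swapped for the optimized function. The helper below is only an illustrative sketch of that pattern, under an assumed name (convert_forward_sketch); it is not ipex-llm's actual implementation.

# Illustrative sketch only: a generic "replace the forward of every matching
# submodule" helper, to show the pattern the convert_forward calls rely on.
import types
import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk all submodules and bind new_forward as each matching instance's forward.
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            submodule.forward = types.MethodType(new_forward, submodule)


# After patching, every DeepseekV3MoE instance in the model would run the
# optimized deepseek_moe_forward during generation, e.g.:
# convert_forward_sketch(model, module.DeepseekV3MoE, deepseek_moe_forward)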
@@ -269,3 +269,35 @@ def deepseek_attention_forward(
     attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
+    idxs = topk_ids.flatten().tolist()
+    outputs = []
+    for i in idxs:
+        expert = self.experts[i]
+        expert_out = expert(x)
+        outputs.append(expert_out)
+    outs = torch.cat(outputs, dim=0)
+    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    return final_out
+
+
+def deepseek_moe_forward(self, hidden_states: torch.Tensor):
+    identity = hidden_states
+    orig_shape = hidden_states.shape
+    topk_idx, topk_weight = self.gate(hidden_states)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    flat_topk_idx = topk_idx.view(-1)
+    if not self.training:
+        # IPEX-LLM OPT start : add special moe_infer implementation for decoding
+        if topk_idx.size(0) == 1:
+            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
+        else:
+            y = self.moe_infer(hidden_states, topk_idx, topk_weight)
+        y = y.view(*orig_shape)
+        # IPEX-LLM OPT end
+    if self.config.n_shared_experts is not None:
+        y = y + self.shared_experts(identity)
+    return y
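During single-token decoding, topk_idx has shape (1, top_k), so the topk_idx.size(0) == 1 check routes to the new moe_infer_decode path: it runs only the few routed experts on the one-token hidden state and takes their weighted sum, skipping the heavier batched dispatch in self.moe_infer. The snippet below is a self-contained toy check, not part of the commit; TinyMoE, the hidden size, and the routing values are made-up assumptions used to show that the decode path reduces to an explicit weighted sum over the routed experts.

# Toy verification sketch (assumptions: TinyMoE, hidden size 8, hand-picked routing).
import torch
import torch.nn as nn


class TinyMoE(nn.Module):
    def __init__(self, n_experts=4, hidden=8):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_experts)])


def moe_infer_decode(self, x, topk_ids, topk_weight):
    # Same logic as the commit; x is (1, hidden) for the single decoded token.
    idxs = topk_ids.flatten().tolist()
    outputs = [self.experts[i](x) for i in idxs]              # one (1, hidden) output per routed expert
    outs = torch.cat(outputs, dim=0)                          # (top_k, hidden)
    w = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)   # (top_k, 1) routing weights
    return (outs * w).sum(dim=0, keepdim=True)                # weighted sum back to (1, hidden)


moe = TinyMoE()
x = torch.randn(1, 8)                       # hidden state of the token being decoded
topk_ids = torch.tensor([[1, 3]])           # the token is routed to experts 1 and 3
topk_weight = torch.tensor([[0.7, 0.3]])    # gate weights for those experts

y = moe_infer_decode(moe, x, topk_ids, topk_weight)

# Reference: explicit weighted sum over the routed experts.
ref = 0.7 * moe.experts[1](x) + 0.3 * moe.experts[3](x)
print(torch.allclose(y, ref, atol=1e-6))    # True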