def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor,
                     topk_weight: torch.Tensor) -> torch.Tensor:
    """Fast MoE expert mixing for single-token decoding.

    Runs each selected expert on the whole (single-row) input ``x`` and
    combines the expert outputs as a weighted sum, avoiding the generic
    scatter/gather machinery of ``moe_infer`` which is overkill for one token.

    Args:
        x: hidden states, assumed shape (1, hidden_size) — a single token.
        topk_ids: indices of the selected experts, assumed (1, top_k).
        topk_weight: routing weights for those experts, assumed (1, top_k).

    Returns:
        Mixed expert output of shape (1, hidden_size), dtype of the expert
        outputs.
    """
    # topk_ids is tiny during decoding (1 x top_k), so a Python-level loop
    # over the expert indices is cheap; each expert sees the full input row.
    expert_outputs = [self.experts[i](x) for i in topk_ids.flatten().tolist()]
    outs = torch.cat(expert_outputs, dim=0)  # (top_k, hidden_size)
    # (1, top_k) -> (top_k, 1) so the weights broadcast across hidden_size;
    # cast to the expert-output dtype before multiplying.
    weights = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
    return (outs * weights).sum(dim=0, keepdim=True)  # (1, hidden_size)


def deepseek_moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Inference-optimized replacement for ``DeepseekV3MoE.forward``.

    Routes ``hidden_states`` through the gate, dispatches to the experts
    (using the fast single-token path when decoding), and adds the shared
    experts' contribution when the model configures them.

    Args:
        hidden_states: input activations; flattened to
            (num_tokens, hidden_size) for expert dispatch and restored to
            the original shape afterwards.

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    identity = hidden_states
    orig_shape = hidden_states.shape
    topk_idx, topk_weight = self.gate(hidden_states)
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # NOTE(review): this forward is only installed for inference (the
    # converter swaps it in post-load); the upstream training branch is
    # intentionally not reproduced, so `y` is unbound if self.training is
    # True. The unused `flat_topk_idx` left over from that branch was removed.
    if not self.training:
        # IPEX-LLM OPT start : add special moe_infer implementation for decoding
        if topk_idx.size(0) == 1:
            # Decoding step: exactly one token — take the fast path.
            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight)
        y = y.view(*orig_shape)
        # IPEX-LLM OPT end
    if self.config.n_shared_experts is not None:
        # Shared experts always see the unflattened input.
        y = y + self.shared_experts(identity)
    return y