From 39e360fe9da0201a3598e77a64d3b3ee3b40e853 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 28 Feb 2025 13:25:56 +0800 Subject: [PATCH] add grouped topk optimization for moonlight (#12903) --- .../ipex_llm/transformers/models/deepseek.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/deepseek.py b/python/llm/src/ipex_llm/transformers/models/deepseek.py index 6745f2dc..c5edf60e 100644 --- a/python/llm/src/ipex_llm/transformers/models/deepseek.py +++ b/python/llm/src/ipex_llm/transformers/models/deepseek.py @@ -271,6 +271,25 @@ def deepseek_attention_forward( return attn_output, attn_weights, past_key_value +def fuse_gate_forward(self, x: torch.Tensor): + if x.device.type == "xpu" and x.dtype in [torch.float, torch.half]: + x = x.view(-1, x.size(-1)) + logits = torch.nn.functional.linear( + x.type(torch.float32), self.weight.type(torch.float32), None + ) + scores = logits.sigmoid() + + import xe_addons + topk_idx, topk_weight = xe_addons.moe_group_topk( + scores, self.e_score_correction_bias, + self.n_group, 2, self.topk_group, self.top_k, + self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor + ) + else: + topk_idx, topk_weight = self(x) + return topk_idx, topk_weight.to(x.dtype) + + def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor): if ( x.device.type == "xpu" @@ -301,7 +320,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: expert_out = expert(x) outputs.append(expert_out) outs = torch.cat(outputs, dim=0) - reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype) + reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1) final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True) return final_out @@ -309,11 +328,13 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: def deepseek_moe_forward(self, hidden_states: torch.Tensor): identity = hidden_states orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) + # IPEX-LLM OPT start: fuse grouped topk in gate forward + topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states) + # IPEX-LLM OPT end hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) flat_topk_idx = topk_idx.view(-1) if not self.training: - # IPEX-LLM OPT start : add special moe_infer implementation for decoding + # IPEX-LLM OPT start: add special moe_infer implementation for decoding if topk_idx.size(0) == 1 and self.ep_size == 1: y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight) else: