diff --git a/python/llm/src/ipex_llm/transformers/models/deepseek.py b/python/llm/src/ipex_llm/transformers/models/deepseek.py
index 0c876f52..6745f2dc 100644
--- a/python/llm/src/ipex_llm/transformers/models/deepseek.py
+++ b/python/llm/src/ipex_llm/transformers/models/deepseek.py
@@ -272,15 +272,37 @@ def deepseek_attention_forward(
 
 
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    idxs = topk_ids.flatten().tolist()
-    outputs = []
-    for i in idxs:
-        expert = self.experts[i]
-        expert_out = expert(x)
-        outputs.append(expert_out)
-    outs = torch.cat(outputs, dim=0)
-    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
-    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    if (
+        x.device.type == "xpu"
+        and x.dtype in [torch.float, torch.half]
+        and self.experts[0].down_proj.qtype == 2
+    ):
+        if getattr(self, "gates", None) is None:
+            gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
+            up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
+            down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
+            gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
+            ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
+            downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
+            self.register_buffer("gates", gates, persistent=False)
+            self.register_buffer("ups", ups, persistent=False)
+            self.register_buffer("downs", downs, persistent=False)
+
+        import xe_linear
+        final_out = xe_linear.moe_forward_vec(
+            x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
+            x.size(-1), self.experts[0].intermediate_size, 2
+        )
+    else:
+        idxs = topk_ids.flatten().tolist()
+        outputs = []
+        for i in idxs:
+            expert = self.experts[i]
+            expert_out = expert(x)
+            outputs.append(expert_out)
+        outs = torch.cat(outputs, dim=0)
+        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
     return final_out
 
 
@@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
     flat_topk_idx = topk_idx.view(-1)
     if not self.training:
         # IPEX-LLM OPT start : add special moe_infer implementation for decoding
-        if topk_idx.size(0) == 1:
+        if topk_idx.size(0) == 1 and self.ep_size == 1:
             y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
         else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight)