add fuse moe optimization for moonlight (#12898)

commit be1f073866 (parent ad65e2b03a)
Author: Yishuo Wang
Date:   2025-02-27 09:15:24 +08:00, committed by GitHub

@@ -272,15 +272,37 @@ def deepseek_attention_forward(
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    idxs = topk_ids.flatten().tolist()
-    outputs = []
-    for i in idxs:
-        expert = self.experts[i]
-        expert_out = expert(x)
-        outputs.append(expert_out)
-    outs = torch.cat(outputs, dim=0)
-    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
-    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    # Fast path: fused MoE kernel for single-token decoding on XPU,
+    # limited to fp32/fp16 activations and a supported weight qtype.
+    if (
+        x.device.type == "xpu"
+        and x.dtype in [torch.float, torch.half]
+        and self.experts[0].down_proj.qtype == 2
+    ):
+        if getattr(self, "gates", None) is None:
+            # Cache raw device pointers to every expert's quantized
+            # gate/up/down weights once, so the fused kernel can gather
+            # the selected experts without Python-side dispatch.
+            gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
+            up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
+            down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
+            gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
+            ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
+            downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
+            self.register_buffer("gates", gates, persistent=False)
+            self.register_buffer("ups", ups, persistent=False)
+            self.register_buffer("downs", downs, persistent=False)
+        import xe_linear
+        final_out = xe_linear.moe_forward_vec(
+            x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
+            x.size(-1), self.experts[0].intermediate_size, 2
+        )
+    else:
+        # Fallback: run each selected expert eagerly and combine the
+        # outputs with the routing weights.
+        idxs = topk_ids.flatten().tolist()
+        outputs = []
+        for i in idxs:
+            expert = self.experts[i]
+            expert_out = expert(x)
+            outputs.append(expert_out)
+        outs = torch.cat(outputs, dim=0)
+        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
     return final_out
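
For context, the fused xe_linear.moe_forward_vec call is expected to produce the same result as the eager fallback branch above. Below is a minimal, standalone sketch of that reference computation for one decoded token, assuming DeepSeek-style SwiGLU experts (down_proj(silu(gate_proj(x)) * up_proj(x))); the function name and shapes are illustrative, not part of the commit:

import torch
import torch.nn.functional as F

def moe_decode_reference(x, topk_ids, topk_weight, experts):
    # x: (1, hidden_size) activation of the single decoded token
    # topk_ids / topk_weight: (1, top_k) routing decisions from the gate
    out = torch.zeros_like(x)
    for eid, w in zip(topk_ids.flatten().tolist(), topk_weight.flatten().tolist()):
        expert = experts[eid]
        # DeepSeek-style SwiGLU expert: down(silu(gate(x)) * up(x))
        h = F.silu(expert.gate_proj(x)) * expert.up_proj(x)
        out = out + w * expert.down_proj(h)
    return out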
@@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
     flat_topk_idx = topk_idx.view(-1)
     if not self.training:
         # IPEX-LLM OPT start : add special moe_infer implementation for decoding
-        if topk_idx.size(0) == 1:
+        if topk_idx.size(0) == 1 and self.ep_size == 1:
             y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
         else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight)
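
The fast path is gated on two conditions: topk_idx.size(0) == 1 holds only when a single token is being routed (a decode step, not prefill), and the newly added ep_size == 1 check ensures no expert parallelism is in use, so every expert lives on the local device and the cached weight pointers cover the full expert set. A toy sketch of the shape check, with made-up sizes:

import torch

hidden_size, num_experts, top_k = 16, 8, 2
router = torch.nn.Linear(hidden_size, num_experts, bias=False)

def route(tokens):
    scores = router(tokens).softmax(dim=-1)
    topk_weight, topk_idx = scores.topk(top_k, dim=-1)
    return topk_idx, topk_weight

# Decode step: one token -> topk_idx.size(0) == 1 -> fused fast path.
topk_idx, _ = route(torch.randn(1, hidden_size))
assert topk_idx.size(0) == 1

# Prefill: many tokens -> generic batched self.moe_infer path.
topk_idx, _ = route(torch.randn(7, hidden_size))
assert topk_idx.size(0) == 7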