add fuse moe optimization for moonlight (#12898)
This commit is contained in:
parent
ad65e2b03a
commit
be1f073866
1 changed files with 32 additions and 10 deletions
|
|
@ -272,15 +272,37 @@ def deepseek_attention_forward(
|
|||
|
||||
|
||||
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
||||
idxs = topk_ids.flatten().tolist()
|
||||
outputs = []
|
||||
for i in idxs:
|
||||
expert = self.experts[i]
|
||||
expert_out = expert(x)
|
||||
outputs.append(expert_out)
|
||||
outs = torch.cat(outputs, dim=0)
|
||||
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
|
||||
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
|
||||
if (
|
||||
x.device.type == "xpu"
|
||||
and x.dtype in [torch.float, torch.half]
|
||||
and self.experts[0].down_proj.qtype == 2
|
||||
):
|
||||
if getattr(self, "gates", None) is None:
|
||||
gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
|
||||
up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
|
||||
down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
|
||||
gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
|
||||
ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
|
||||
downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
|
||||
self.register_buffer("gates", gates, persistent=False)
|
||||
self.register_buffer("ups", ups, persistent=False)
|
||||
self.register_buffer("downs", downs, persistent=False)
|
||||
|
||||
import xe_linear
|
||||
final_out = xe_linear.moe_forward_vec(
|
||||
x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
|
||||
x.size(-1), self.experts[0].intermediate_size, 2
|
||||
)
|
||||
else:
|
||||
idxs = topk_ids.flatten().tolist()
|
||||
outputs = []
|
||||
for i in idxs:
|
||||
expert = self.experts[i]
|
||||
expert_out = expert(x)
|
||||
outputs.append(expert_out)
|
||||
outs = torch.cat(outputs, dim=0)
|
||||
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
|
||||
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
|
||||
return final_out
|
||||
|
||||
|
||||
|
|
@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
|
|||
flat_topk_idx = topk_idx.view(-1)
|
||||
if not self.training:
|
||||
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
|
||||
if topk_idx.size(0) == 1:
|
||||
if topk_idx.size(0) == 1 and self.ep_size == 1:
|
||||
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight)
|
||||
|
|
|
|||
Loading…
Reference in a new issue