Add fused MoE optimization for Moonlight (#12898)
This commit is contained in:
parent
ad65e2b03a
commit
be1f073866
1 changed file with 32 additions and 10 deletions
|
|
@@ -272,6 +272,28 @@ def deepseek_attention_forward(
|
||||||
|
|
||||||
|
|
||||||
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
||||||
|
if (
|
||||||
|
x.device.type == "xpu"
|
||||||
|
and x.dtype in [torch.float, torch.half]
|
||||||
|
and self.experts[0].down_proj.qtype == 2
|
||||||
|
):
|
||||||
|
if getattr(self, "gates", None) is None:
|
||||||
|
gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
|
||||||
|
up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
|
||||||
|
down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
|
||||||
|
gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
|
||||||
|
ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
|
||||||
|
downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
|
||||||
|
self.register_buffer("gates", gates, persistent=False)
|
||||||
|
self.register_buffer("ups", ups, persistent=False)
|
||||||
|
self.register_buffer("downs", downs, persistent=False)
|
||||||
|
|
||||||
|
import xe_linear
|
||||||
|
final_out = xe_linear.moe_forward_vec(
|
||||||
|
x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
|
||||||
|
x.size(-1), self.experts[0].intermediate_size, 2
|
||||||
|
)
|
||||||
|
else:
|
||||||
idxs = topk_ids.flatten().tolist()
|
idxs = topk_ids.flatten().tolist()
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in idxs:
|
for i in idxs:
|
||||||
|
|
@@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
|
||||||
flat_topk_idx = topk_idx.view(-1)
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
|
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
|
||||||
if topk_idx.size(0) == 1:
|
if topk_idx.size(0) == 1 and self.ep_size == 1:
|
||||||
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
|
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
|
||||||
else:
|
else:
|
||||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight)
|
y = self.moe_infer(hidden_states, topk_idx, topk_weight)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue