add grouped topk optimization for moonlight (#12903)
This commit is contained in:
parent
e946127613
commit
39e360fe9d
1 changed files with 24 additions and 3 deletions
|
|
@ -271,6 +271,25 @@ def deepseek_attention_forward(
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def fuse_gate_forward(self, x: torch.Tensor):
|
||||||
|
if x.device.type == "xpu" and x.dtype in [torch.float, torch.half]:
|
||||||
|
x = x.view(-1, x.size(-1))
|
||||||
|
logits = torch.nn.functional.linear(
|
||||||
|
x.type(torch.float32), self.weight.type(torch.float32), None
|
||||||
|
)
|
||||||
|
scores = logits.sigmoid()
|
||||||
|
|
||||||
|
import xe_addons
|
||||||
|
topk_idx, topk_weight = xe_addons.moe_group_topk(
|
||||||
|
scores, self.e_score_correction_bias,
|
||||||
|
self.n_group, 2, self.topk_group, self.top_k,
|
||||||
|
self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_idx, topk_weight = self(x)
|
||||||
|
return topk_idx, topk_weight.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
||||||
if (
|
if (
|
||||||
x.device.type == "xpu"
|
x.device.type == "xpu"
|
||||||
|
|
@ -301,7 +320,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
|
||||||
expert_out = expert(x)
|
expert_out = expert(x)
|
||||||
outputs.append(expert_out)
|
outputs.append(expert_out)
|
||||||
outs = torch.cat(outputs, dim=0)
|
outs = torch.cat(outputs, dim=0)
|
||||||
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
|
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1)
|
||||||
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
|
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
|
||||||
return final_out
|
return final_out
|
||||||
|
|
||||||
|
|
@ -309,11 +328,13 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
|
||||||
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
|
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
|
||||||
identity = hidden_states
|
identity = hidden_states
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
topk_idx, topk_weight = self.gate(hidden_states)
|
# IPEX-LLM OPT start: fuse grouped topk in gate forward
|
||||||
|
topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states)
|
||||||
|
# IPEX-LLM OPT end
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
flat_topk_idx = topk_idx.view(-1)
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
|
# IPEX-LLM OPT start: add special moe_infer implementation for decoding
|
||||||
if topk_idx.size(0) == 1 and self.ep_size == 1:
|
if topk_idx.size(0) == 1 and self.ep_size == 1:
|
||||||
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
|
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue