add grouped topk optimization for moonlight (#12903)
This commit is contained in:
		
							parent
							
								
									e946127613
								
							
						
					
					
						commit
						39e360fe9d
					
				
					 1 changed files with 24 additions and 3 deletions
				
			
		| 
						 | 
					@ -271,6 +271,25 @@ def deepseek_attention_forward(
 | 
				
			||||||
    return attn_output, attn_weights, past_key_value
 | 
					    return attn_output, attn_weights, past_key_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def fuse_gate_forward(self, x: torch.Tensor):
					    """Run the MoE gate, using the fused grouped top-k kernel when possible.

					    When ``x`` lives on an XPU device and is float32 or float16, the gate
					    logits are computed here in float32 (``x`` flattened to 2-D, multiplied
					    by ``self.weight`` with no bias), passed through a sigmoid, and handed
					    to ``xe_addons.moe_group_topk``. In every other case the gate module is
					    simply called as ``self(x)``.

					    Args:
					        self: The gate module. The fused path reads ``weight``,
					            ``e_score_correction_bias``, ``n_group``, ``topk_group``,
					            ``top_k``, ``norm_topk_prob`` and ``routed_scaling_factor``
					            from it; the fallback path only requires it to be callable.
					        x: Input hidden states.

					    Returns:
					        A ``(topk_idx, topk_weight)`` tuple, with ``topk_weight`` cast to
					        the dtype of ``x``.
					    """
					    fused_ok = x.device.type == "xpu" and x.dtype in (torch.float, torch.half)
					    if not fused_ok:
					        # Fallback: defer to the gate module's own forward.
					        expert_ids, expert_weights = self(x)
					        return expert_ids, expert_weights.to(x.dtype)

					    # Collapse all leading dimensions so each row is one token.
					    x = x.view(-1, x.shape[-1])

					    # Gate scores are always computed in float32, whatever dtype x has.
					    tokens_fp32 = x.to(torch.float32)
					    weight_fp32 = self.weight.to(torch.float32)
					    scores = torch.sigmoid(torch.nn.functional.linear(tokens_fp32, weight_fp32))

					    # Imported lazily so the module is only required on the XPU path.
					    import xe_addons

					    # Normalisation is only requested when more than one expert is picked.
					    normalize = self.top_k > 1 and self.norm_topk_prob
					    # NOTE(review): the literal 2 and 1e-20 are positional arguments of the
					    # native kernel; presumably a per-group top-k count and an epsilon for
					    # normalisation -- confirm against the xe_addons signature.
					    expert_ids, expert_weights = xe_addons.moe_group_topk(
					        scores,
					        self.e_score_correction_bias,
					        self.n_group,
					        2,
					        self.topk_group,
					        self.top_k,
					        normalize,
					        1e-20,
					        self.routed_scaling_factor,
					    )
					    return expert_ids, expert_weights.to(x.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
 | 
					def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
 | 
				
			||||||
    if (
 | 
					    if (
 | 
				
			||||||
        x.device.type == "xpu"
 | 
					        x.device.type == "xpu"
 | 
				
			||||||
| 
						 | 
					@ -301,7 +320,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
 | 
				
			||||||
            expert_out = expert(x)
 | 
					            expert_out = expert(x)
 | 
				
			||||||
            outputs.append(expert_out)
 | 
					            outputs.append(expert_out)
 | 
				
			||||||
        outs = torch.cat(outputs, dim=0)
 | 
					        outs = torch.cat(outputs, dim=0)
 | 
				
			||||||
        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
 | 
					        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1)
 | 
				
			||||||
        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
 | 
					        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
 | 
				
			||||||
    return final_out
 | 
					    return final_out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -309,11 +328,13 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
 | 
				
			||||||
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
 | 
					def deepseek_moe_forward(self, hidden_states: torch.Tensor):
 | 
				
			||||||
    identity = hidden_states
 | 
					    identity = hidden_states
 | 
				
			||||||
    orig_shape = hidden_states.shape
 | 
					    orig_shape = hidden_states.shape
 | 
				
			||||||
    topk_idx, topk_weight = self.gate(hidden_states)
 | 
					    # IPEX-LLM OPT start: fuse grouped topk in gate forward
 | 
				
			||||||
 | 
					    topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states)
 | 
				
			||||||
 | 
					    # IPEX-LLM OPT end
 | 
				
			||||||
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 | 
					    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 | 
				
			||||||
    flat_topk_idx = topk_idx.view(-1)
 | 
					    flat_topk_idx = topk_idx.view(-1)
 | 
				
			||||||
    if not self.training:
 | 
					    if not self.training:
 | 
				
			||||||
        # IPEX-LLM OPT start : add special moe_infer implementation for decoding
 | 
					        # IPEX-LLM OPT start: add special moe_infer implementation for decoding
 | 
				
			||||||
        if topk_idx.size(0) == 1 and self.ep_size == 1:
 | 
					        if topk_idx.size(0) == 1 and self.ep_size == 1:
 | 
				
			||||||
            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
 | 
					            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue