add fuse moe optimization for moonlight (#12898)
parent ad65e2b03a
commit be1f073866

1 changed file with 32 additions and 10 deletions
@@ -272,15 +272,37 @@ def deepseek_attention_forward(
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    idxs = topk_ids.flatten().tolist()
-    outputs = []
-    for i in idxs:
-        expert = self.experts[i]
-        expert_out = expert(x)
-        outputs.append(expert_out)
-    outs = torch.cat(outputs, dim=0)
-    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
-    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    if (
+        x.device.type == "xpu"
+        and x.dtype in [torch.float, torch.half]
+        and self.experts[0].down_proj.qtype == 2
+    ):
+        if getattr(self, "gates", None) is None:
+            gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
+            up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
+            down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
+            gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
+            ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
+            downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
+            self.register_buffer("gates", gates, persistent=False)
+            self.register_buffer("ups", ups, persistent=False)
+            self.register_buffer("downs", downs, persistent=False)
+
+        import xe_linear
+        final_out = xe_linear.moe_forward_vec(
+            x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
+            x.size(-1), self.experts[0].intermediate_size, 2
+        )
+    else:
+        idxs = topk_ids.flatten().tolist()
+        outputs = []
+        for i in idxs:
+            expert = self.experts[i]
+            expert_out = expert(x)
+            outputs.append(expert_out)
+        outs = torch.cat(outputs, dim=0)
+        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
     return final_out
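For reference, the fallback branch is just a gather-and-weighted-sum over the routed experts. A minimal sketch of that math, with toy nn.Linear experts and made-up router outputs standing in for the real quantized modules:

import torch
import torch.nn as nn

hidden_size, num_experts = 8, 4
experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))

x = torch.randn(1, hidden_size)            # hidden state of the single decode token
topk_ids = torch.tensor([[1, 3]])          # router picked experts 1 and 3
topk_weight = torch.tensor([[0.7, 0.3]])   # router weights for those experts

# Run each routed expert on the token, then combine with the router weights,
# mirroring the else-branch above.
outs = torch.cat([experts[i](x) for i in topk_ids.flatten().tolist()], dim=0)
weights = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)  # (top_k, 1)
final_out = (outs * weights).sum(dim=0, keepdim=True)          # (1, hidden_size)
print(final_out.shape)  # torch.Size([1, 8])

Each expert output has shape (1, hidden_size); stacking the top-k outputs and reducing with the router weights yields the MoE output for the single decode token.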
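The fused branch replaces that per-expert Python loop with a single native call: xe_linear.moe_forward_vec receives uint64 tables of raw device pointers to each expert's gate/up/down weights, built lazily on the first decode step and cached as non-persistent buffers (so they are rebuilt rather than serialized, as addresses change between runs). A sketch of just the pointer-table construction, with plain nn.Linear weights standing in for the quantized projections; the qtype == 2 guard in the patch presumably pins the fast path to one specific quantization format:

import torch
import torch.nn as nn

# Hypothetical stand-in experts; the real modules are quantized projections.
experts = nn.ModuleList(nn.Linear(8, 16) for _ in range(4))

# One uint64 entry per expert: the device address of its weight tensor.
# A native kernel can then index this table by expert id instead of going
# through Python-side module dispatch for every expert.
gate_addrs = [e.weight.data_ptr() for e in experts]
gates = torch.tensor(gate_addrs, dtype=torch.uint64)  # needs a PyTorch build with uint64 support
print(gates.shape)  # torch.Size([4])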
@@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
     flat_topk_idx = topk_idx.view(-1)
     if not self.training:
         # IPEX-LLM OPT start : add special moe_infer implementation for decoding
-        if topk_idx.size(0) == 1:
+        if topk_idx.size(0) == 1 and self.ep_size == 1:
             y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
         else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight)
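The second hunk tightens the guard on the fast path: moe_infer_decode handles exactly one routed token, and the pointer table assumes every expert's weights are resident on the local device, so the added ep_size == 1 check presumably keeps expert-parallel runs (where only a slice of self.experts is materialized per rank) on the generic moe_infer path. A toy illustration of the dispatch, with made-up router tensors:

import torch

def pick_path(topk_idx: torch.Tensor, ep_size: int) -> str:
    # One row in topk_idx == one token being decoded; ep_size == 1 means
    # no expert parallelism, i.e. all experts live on this device.
    if topk_idx.size(0) == 1 and ep_size == 1:
        return "moe_infer_decode"  # fused single-token path
    return "moe_infer"             # generic batched path

decode_idx = torch.tensor([[1, 3]])           # 1 token, top-2 experts
prefill_idx = torch.tensor([[1, 3], [0, 2]])  # 2 tokens during prefill
print(pick_path(decode_idx, ep_size=1))   # moe_infer_decode
print(pick_path(prefill_idx, ep_size=1))  # moe_infer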