simple optimization for moonlight moe decoding forward (#12891)
parent ae9f5320da
commit 5faba06409

2 changed files with 34 additions and 0 deletions
@@ -2031,9 +2031,11 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.deepseek import deepseek_model_forward
         from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
+        from ipex_llm.transformers.models.deepseek import deepseek_moe_forward
         convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
         convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
         convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
+        convert_forward(model, module.DeepseekV3MoE, deepseek_moe_forward)

     return model
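The two added lines register the new MoE forward pass: convert_forward walks the loaded model and swaps the forward method of every DeepseekV3MoE module for deepseek_moe_forward. A minimal sketch of that rebinding pattern, assuming a helper that simply traverses the module tree (an illustration, not ipex_llm's actual convert_forward implementation):

# Hypothetical sketch for illustration only; the real convert_forward lives
# in ipex_llm.transformers and may differ in details.
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class: type, new_forward):
    for module in model.modules():
        if isinstance(module, target_class):
            # Bind the free function as an instance method so `self` resolves
            # to this module inside new_forward.
            module.forward = types.MethodType(new_forward, module)
    return model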
@@ -269,3 +269,35 @@ def deepseek_attention_forward(
         attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
+    idxs = topk_ids.flatten().tolist()
+    outputs = []
+    for i in idxs:
+        expert = self.experts[i]
+        expert_out = expert(x)
+        outputs.append(expert_out)
+    outs = torch.cat(outputs, dim=0)
+    reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
+    final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
+    return final_out
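moe_infer_decode is the decode-time fast path: with a single routed token, topk_ids is one row of expert indices, so the function just runs each selected expert on that token and reduces the stacked outputs with the gate weights, skipping the sorting and scatter/gather of the general moe_infer. A self-contained toy check of the same weighted-sum logic, using stand-in nn.Linear experts (the real experts are DeepSeek-V3 MLP blocks, not Linear layers):

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, n_experts, top_k = 8, 4, 2
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

x = torch.randn(1, hidden)                # one decoded token
topk_ids = torch.tensor([[1, 3]])         # (1, top_k) selected expert indices
topk_weight = torch.tensor([[0.7, 0.3]])  # (1, top_k) gate weights

# Same computation as moe_infer_decode: run each selected expert on the
# single token, stack the results, and reduce with the gate weights.
outs = torch.cat([experts[i](x) for i in topk_ids.flatten().tolist()], dim=0)
w = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)  # (top_k, 1)
y = (outs * w).sum(dim=0, keepdim=True)                  # (1, hidden)

# Reference: explicit weighted sum over the two chosen experts.
ref = 0.7 * experts[1](x) + 0.3 * experts[3](x)
assert torch.allclose(y, ref, atol=1e-6)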
+
+
+def deepseek_moe_forward(self, hidden_states: torch.Tensor):
+    identity = hidden_states
+    orig_shape = hidden_states.shape
+    topk_idx, topk_weight = self.gate(hidden_states)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    flat_topk_idx = topk_idx.view(-1)
+    if not self.training:
+        # IPEX-LLM OPT start : add special moe_infer implementation for decoding
+        if topk_idx.size(0) == 1:
+            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
+        else:
+            y = self.moe_infer(hidden_states, topk_idx, topk_weight)
+        y = y.view(*orig_shape)
+        # IPEX-LLM OPT end
+    if self.config.n_shared_experts is not None:
+        y = y + self.shared_experts(identity)
+    return y
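deepseek_moe_forward takes the fast path only when topk_idx.size(0) == 1, i.e. exactly one routed token (a decode step); multi-token prefill still goes through the model's original self.moe_infer. The training branch of the upstream implementation is dropped because this override is inference-only, which is also why flat_topk_idx ends up unused. A minimal mock that exercises the decode branch, assuming the two functions from the hunk above are in scope; every name below (MockGate, MockMoE, the SimpleNamespace config) is invented for illustration and only approximates the real DeepseekV3MoE attributes:

import types
import torch
import torch.nn as nn

class MockGate(nn.Module):
    # Stand-in for the DeepSeek-V3 gate: returns (topk_idx, topk_weight).
    def __init__(self, hidden, n_experts, top_k):
        super().__init__()
        self.proj = nn.Linear(hidden, n_experts, bias=False)
        self.top_k = top_k

    def forward(self, hidden_states):
        scores = self.proj(hidden_states.view(-1, hidden_states.shape[-1]))
        weight, idx = torch.softmax(scores, dim=-1).topk(self.top_k, dim=-1)
        return idx, weight

class MockMoE(nn.Module):
    def __init__(self, hidden=8, n_experts=4, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))
        self.gate = MockGate(hidden, n_experts, top_k)
        self.config = types.SimpleNamespace(n_shared_experts=None)

moe = MockMoE().eval()
moe.forward = types.MethodType(deepseek_moe_forward, moe)
out = moe(torch.randn(1, 1, 8))  # (batch=1, seq=1, hidden): one decode step
print(out.shape)                 # torch.Size([1, 1, 8])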