refactor qwen2 forward to enable XPU (#10409)
* refactor qwen2 forward to enable XPU
* Update qwen2.py
parent f36224aac4
commit 7d29765092

2 changed files with 6 additions and 10 deletions
```diff
@@ -1075,6 +1075,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
+        from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1084,16 +1085,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2MLP,
                         llama_mlp_forward)
-        if model.device.type == 'cpu':
-            from bigdl.llm.transformers.models.qwen2 import qwen2_sdpa_attention_forward
-            convert_forward(model,
-                            module.Qwen2SdpaAttention,
-                            qwen2_sdpa_attention_forward)
-        else:
-            from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
-            convert_forward(model,
-                            module.Qwen2Attention,
-                            qwen2_attention_forward)
+        convert_forward(model,
+                        module.Qwen2Attention,
+                        qwen2_attention_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
```
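With the device check removed from `_optimize_post`, the same `qwen2_attention_forward` is registered for `Qwen2Attention` regardless of device, and the CPU/XPU decision moves into the forward function itself (second file below). For context, `convert_forward` rebinds the `forward` method on matching submodules; a minimal sketch of that patching pattern (illustrative only, with a hypothetical `patch_forward` name, not the library's exact implementation):

```python
import torch.nn as nn

def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of target_cls.
    # `new_forward.__get__(m, ...)` turns the plain function into a bound method,
    # so `self` inside new_forward refers to the patched submodule.
    for m in model.modules():
        if isinstance(m, target_cls):
            m.forward = new_forward.__get__(m, m.__class__)
```

Because the binding is per instance and device-agnostic, a single registration now covers both CPU and XPU models; the runtime check below picks the concrete implementation on each call.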
```diff
@@ -106,6 +106,8 @@ def qwen2_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
         forward_function = qwen2_attention_forward_quantized
+    elif hidden_states.device.type == "cpu":
+        forward_function = qwen2_sdpa_attention_forward
     else:
         forward_function = qwen2_attention_forward_origin
     return forward_function(
```
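The dispatch above happens at call time: a quantized KV cache selects the quantized path, CPU tensors fall back to the SDPA-based forward, and anything else (XPU included) uses the original optimized forward. A self-contained sketch of that dispatch-by-device pattern; `_cpu_path` and `_xpu_path` are illustrative placeholders, not BigDL functions:

```python
import torch

def _cpu_path(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the SDPA-based CPU implementation
    return torch.softmax(x, dim=-1)

def _xpu_path(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the XPU-optimized implementation
    return torch.softmax(x, dim=-1)

def dispatch_forward(x: torch.Tensor) -> torch.Tensor:
    # Choose the implementation from the input tensor's device at call time,
    # mirroring the `elif hidden_states.device.type == "cpu"` branch above.
    if x.device.type == "cpu":
        return _cpu_path(x)
    return _xpu_path(x)

print(dispatch_forward(torch.randn(2, 4)).shape)  # CPU tensors take the SDPA stand-in; XPU tensors take the other branch
```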