fix phi-2 and phi-3 convert (#11116)
This commit is contained in:
		
							parent
							
								
									37b98a531f
								
							
						
					
					
						commit
						797dbc48b8
					
				
					 1 changed files with 2 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -1522,7 +1522,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        from ipex_llm.transformers.models.starcoder2 import model_forward
 | 
			
		||||
        convert_forward(model, module.Starcoder2Attention, attention_forward)
 | 
			
		||||
        convert_forward(model, module.Starcoder2Model, model_forward)
 | 
			
		||||
    elif model.config.model_type in ["phi3", "phi3_v"]:
 | 
			
		||||
    elif model.config.model_type == "phi":
 | 
			
		||||
        # for phi-2
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			@ -1530,7 +1530,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        from ipex_llm.transformers.models.phi import model_forward
 | 
			
		||||
        convert_forward(model, module.PhiAttention, attention_forward)
 | 
			
		||||
        convert_forward(model, module.PhiModel, model_forward)
 | 
			
		||||
    elif model.config.model_type == "phi3":
 | 
			
		||||
    elif model.config.model_type in ["phi3", "phi3_v"]:
 | 
			
		||||
        # for phi-3
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue