optimize internvl2 4b performance (#11720)
This commit is contained in:
		
							parent
							
								
									f44b732aa8
								
							
						
					
					
						commit
						bbdff6edeb
					
				
					 1 changed file with 3 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -739,6 +739,8 @@ def _optimize_pre(model, qtype=None):
 | 
			
		|||
    if model.config.model_type == "internlmxcomposer2":
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
 | 
			
		||||
        model.apply(pre_process_attn_and_mlp)
 | 
			
		||||
    if model.config.model_type == "internvl_chat":
 | 
			
		||||
        _optimize_pre(model.language_model)
 | 
			
		||||
    if model.config.model_type == "gemma2":
 | 
			
		||||
        from ipex_llm.transformers.models.gemma2 import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
| 
						 | 
				
			
			@ -1268,6 +1270,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            from ipex_llm.transformers.models.internvl import _get_pos_embed
 | 
			
		||||
            vision_embedding = model.vision_model.embeddings
 | 
			
		||||
            vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
 | 
			
		||||
        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
    elif model.config.model_type == "qwen":
 | 
			
		||||
        if hasattr(model.config, "visual"):
 | 
			
		||||
            # for Qwen-VL-Chat
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue