optimize internvl2 vision model's attention (#12198)
This commit is contained in:
		
							parent
							
								
									f8d1adc573
								
							
						
					
					
						commit
						d5344587ab
					
				
					 2 changed files with 29 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -1580,8 +1580,12 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        model.batch_chat = MethodType(internvl_batch_chat, model)
 | 
			
		||||
        if model.vision_model.__class__.__name__ == "InternVisionModel":
 | 
			
		||||
            from ipex_llm.transformers.models.internvl import _get_pos_embed
 | 
			
		||||
            vision_embedding = model.vision_model.embeddings
 | 
			
		||||
            from ipex_llm.transformers.models.internvl import intern_attention_forward
 | 
			
		||||
            vision_model = model.vision_model
 | 
			
		||||
            vision_embedding = vision_model.embeddings
 | 
			
		||||
            vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
 | 
			
		||||
            vision_module = importlib.import_module(vision_model.__class__.__module__)
 | 
			
		||||
            convert_forward(vision_model, vision_module.InternAttention, intern_attention_forward)
 | 
			
		||||
        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
    elif model.config.model_type == "qwen":
 | 
			
		||||
        if hasattr(model.config, "visual"):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -163,3 +163,27 @@ def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_con
 | 
			
		|||
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
 | 
			
		||||
    responses = [response.split(template.sep)[0].strip() for response in responses]
 | 
			
		||||
    return responses
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    B, N, C = x.shape
 | 
			
		||||
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 | 
			
		||||
    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 | 
			
		||||
 | 
			
		||||
    if self.qk_normalization:
 | 
			
		||||
        B_, H_, N_, D_ = q.shape
 | 
			
		||||
        q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 | 
			
		||||
        k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    if x.device.type == "xpu":
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
 | 
			
		||||
    else:
 | 
			
		||||
        attn = ((q * self.scale) @ k.transpose(-2, -1))
 | 
			
		||||
        attn = attn.softmax(dim=-1)
 | 
			
		||||
        attn = self.attn_drop(attn)
 | 
			
		||||
        x = attn @ v
 | 
			
		||||
    x = x.transpose(1, 2).reshape(B, N, C)
 | 
			
		||||
    x = self.proj(x)
 | 
			
		||||
    x = self.proj_drop(x)
 | 
			
		||||
    return x
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue