optimize glm4v vision attention (#12369)
parent 2dfcc36825
commit dc34e8c51f

2 changed files with 85 additions and 7 deletions
@@ -51,6 +51,69 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             del module.q_proj, module.k_proj, module.v_proj
 
 
+def padding_linear_hd(linear: torch.nn.Linear,
+                      old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
+    in_features, out_features = linear.in_features, linear.out_features
+
+    weight = linear.weight.data
+    weight = weight.view(-1, old_head_dim, in_features)
+    new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
+                             dtype=weight.dtype, device=weight.device)
+    new_weight[:, :old_head_dim, :] = weight
+    new_weight[:, old_head_dim:, :] = 0
+    new_weight = new_weight.view(-1, in_features)
+    if linear.bias is not None:
+        bias = linear.bias.data
+        bias = bias.view(-1, old_head_dim)
+        new_bias = torch.empty([bias.size(0), new_head_dim],
+                               dtype=bias.dtype, device=bias.device)
+        new_bias[:, :old_head_dim] = bias
+        new_bias[:, old_head_dim:] = 0
+        new_bias = new_bias.flatten()
+
+        new_linear = torch.nn.Linear(0, 0, bias=True)
+        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    else:
+        new_linear = torch.nn.Linear(0, 0, bias=False)
+    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_linear.in_features = new_weight.size(1)
+    new_linear.out_features = new_weight.size(0)
+    return new_linear
+
+
+def padding_attention_hd_base(module: torch.nn.Module, attention_class,
+                              old_head_dim: int, new_head_dim: int):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ) and module.head_dim == old_head_dim:
+        module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
+        module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
+        module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
+        module.head_dim = new_head_dim
+        module.old_head_dim = old_head_dim
+
+
+def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
+    bsz, num_heads, seq_len, head_dim = states.size()
+    if head_dim == old_head_dim and old_head_dim < new_head_dim:
+        new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
+                                 dtype=states.dtype, device=states.device)
+        new_states[:, :, :, :old_head_dim] = states
+        new_states[:, :, :, old_head_dim:] = 0
+        return new_states
+    return states
+
+
+def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                   old_head_dim: int, new_head_dim: int):
+    return (
+        padding_states_hd(q, old_head_dim, new_head_dim),
+        padding_states_hd(k, old_head_dim, new_head_dim),
+        padding_states_hd(v, old_head_dim, new_head_dim),
+    )
+
+
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     from ipex_llm.transformers.models.utils import mlp_fusion_check
     x_2d = x.view(-1, x.size(-1))
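Not part of the diff: a minimal, self-contained sketch (toy sizes, hypothetical names w_pad/b_pad) of the zero-padding scheme padding_linear_hd applies per head; the commit itself pads head_dim from 72 to 80.

import torch

# Toy projection: 2 heads of size 3, padded to head size 4.
old_hd, new_hd, in_f, heads = 3, 4, 5, 2
proj = torch.nn.Linear(in_f, heads * old_hd)

# Zero-pad each head's slice of the weight and bias, as padding_linear_hd does.
w = proj.weight.data.view(heads, old_hd, in_f)
w_pad = torch.zeros(heads, new_hd, in_f)
w_pad[:, :old_hd, :] = w
b = proj.bias.data.view(heads, old_hd)
b_pad = torch.zeros(heads, new_hd)
b_pad[:, :old_hd] = b

x = torch.randn(7, in_f)
y = proj(x).view(7, heads, old_hd)
y_pad = (x @ w_pad.reshape(-1, in_f).t() + b_pad.flatten()).view(7, heads, new_hd)

# The first old_hd channels of every head are unchanged and the padded channels
# are exactly zero, so downstream code can slice them off without changing the result.
assert torch.allclose(y_pad[..., :old_hd], y)
assert torch.equal(y_pad[..., old_hd:], torch.zeros(7, heads, new_hd - old_hd))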
@@ -26,8 +26,9 @@ import torch
 from threading import Thread
 from typing import Optional, List
 from torch.nn.functional import linear
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -52,14 +53,28 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
-
-    attn_weights = attention_softmax(attn_weights)
-
-    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+    query_states, key_states, value_states = padding_qkv_hd(
+        query_states, key_states, value_states,
+        72, 80
+    )
+
+    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                               value_states.contiguous(), attention_mask)
+    else:
+        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = attention_softmax(attn_weights)
+
+        attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                   p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
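Not part of the diff: a short sketch (hypothetical toy tensors) of why the padded path above is safe. Zero-padding q/k leaves the attention scores unchanged, and zero-padding v only adds zero output channels, which the final attn_output[:, :, :, :self.head_dim] slice discards.

import torch

bsz, heads, seq, old_hd, new_hd = 1, 2, 6, 72, 80
q, k, v = (torch.randn(bsz, heads, seq, old_hd) for _ in range(3))
scale = old_hd ** -0.5

def pad(t):
    # Same zero-padding as padding_states_hd: extra channels are filled with 0.
    out = torch.zeros(bsz, heads, seq, new_hd, dtype=t.dtype)
    out[..., :old_hd] = t
    return out

# Reference: eager attention at the original head_dim (the else-branch above).
ref = torch.softmax(torch.matmul(q * scale, k.transpose(2, 3)), dim=-1) @ v

# Padded path, then slice the output back to the original head_dim.
qp, kp, vp = pad(q), pad(k), pad(v)
out = torch.softmax(torch.matmul(qp * scale, kp.transpose(2, 3)), dim=-1) @ vp
assert torch.allclose(out[..., :old_hd], ref, atol=1e-6)
assert torch.equal(out[..., old_hd:], torch.zeros(bsz, heads, seq, new_hd - old_hd))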