optimize glm4v vision attention (#12369)
This commit is contained in:
parent
2dfcc36825
commit
dc34e8c51f
2 changed files with 85 additions and 7 deletions
|
|
@ -51,6 +51,69 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
|
|||
del module.q_proj, module.k_proj, module.v_proj
|
||||
|
||||
|
||||
def padding_linear_hd(linear: torch.nn.Linear,
|
||||
old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
|
||||
in_features, out_features = linear.in_features, linear.out_features
|
||||
|
||||
weight = linear.weight.data
|
||||
weight = weight.view(-1, old_head_dim, in_features)
|
||||
new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
|
||||
dtype=weight.dtype, device=weight.device)
|
||||
new_weight[:, :old_head_dim, :] = weight
|
||||
new_weight[:, old_head_dim:, :] = 0
|
||||
new_weight = new_weight.view(-1, in_features)
|
||||
if linear.bias is not None:
|
||||
bias = linear.bias.data
|
||||
bias = bias.view(-1, old_head_dim)
|
||||
new_bias = torch.empty([bias.size(0), new_head_dim],
|
||||
dtype=bias.dtype, device=bias.device)
|
||||
new_bias[:, :old_head_dim] = bias
|
||||
new_bias[:, old_head_dim:] = 0
|
||||
new_bias = new_bias.flatten()
|
||||
|
||||
new_linear = torch.nn.Linear(0, 0, bias=True)
|
||||
new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||
else:
|
||||
new_linear = torch.nn.Linear(0, 0, bias=False)
|
||||
new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||
new_linear.in_features = new_weight.size(1)
|
||||
new_linear.out_features = new_weight.size(0)
|
||||
return new_linear
|
||||
|
||||
|
||||
def padding_attention_hd_base(module: torch.nn.Module, attention_class,
|
||||
old_head_dim: int, new_head_dim: int):
|
||||
if (
|
||||
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
|
||||
or not isinstance(attention_class, str) and isinstance(module, attention_class)
|
||||
) and module.head_dim == old_head_dim:
|
||||
module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
|
||||
module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
|
||||
module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
|
||||
module.head_dim = new_head_dim
|
||||
module.old_head_dim = old_head_dim
|
||||
|
||||
|
||||
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
|
||||
bsz, num_heads, seq_len, head_dim = states.size()
|
||||
if head_dim == old_head_dim and old_head_dim < new_head_dim:
|
||||
new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
|
||||
dtype=states.dtype, device=states.device)
|
||||
new_states[:, :, :, :old_head_dim] = states
|
||||
new_states[:, :, :, old_head_dim:] = 0
|
||||
return new_states
|
||||
return states
|
||||
|
||||
|
||||
def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
old_head_dim: int, new_head_dim: int):
|
||||
return (
|
||||
padding_states_hd(q, old_head_dim, new_head_dim),
|
||||
padding_states_hd(k, old_head_dim, new_head_dim),
|
||||
padding_states_hd(v, old_head_dim, new_head_dim),
|
||||
)
|
||||
|
||||
|
||||
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||
x_2d = x.view(-1, x.size(-1))
|
||||
|
|
|
|||
|
|
@ -26,8 +26,9 @@ import torch
|
|||
from threading import Thread
|
||||
from typing import Optional, List
|
||||
from torch.nn.functional import linear
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
|
||||
from ipex_llm.transformers.models.common import attention_softmax
|
||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||
from transformers import AutoProcessor, TextIteratorStreamer
|
||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||
|
||||
|
|
@ -52,14 +53,28 @@ def siglip_attention_forward(
|
|||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
||||
|
||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
query_states, key_states, value_states = padding_qkv_hd(
|
||||
query_states, key_states, value_states,
|
||||
72, 80
|
||||
)
|
||||
|
||||
attn_weights = attention_softmax(attn_weights)
|
||||
if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
|
||||
import xe_addons
|
||||
attn_weights = None
|
||||
attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
|
||||
value_states.contiguous(), attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_weights = attention_softmax(attn_weights)
|
||||
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights,
|
||||
p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output[:, :, :, :self.head_dim]
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
|
||||
|
|
|
|||
Loading…
Reference in a new issue