optimize glm4v vision attention (#12369)
This commit is contained in:
parent
2dfcc36825
commit
dc34e8c51f
2 changed files with 85 additions and 7 deletions
|
|
@ -51,6 +51,69 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
|
||||||
del module.q_proj, module.k_proj, module.v_proj
|
del module.q_proj, module.k_proj, module.v_proj
|
||||||
|
|
||||||
|
|
||||||
|
def padding_linear_hd(linear: torch.nn.Linear,
|
||||||
|
old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
|
||||||
|
in_features, out_features = linear.in_features, linear.out_features
|
||||||
|
|
||||||
|
weight = linear.weight.data
|
||||||
|
weight = weight.view(-1, old_head_dim, in_features)
|
||||||
|
new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
|
||||||
|
dtype=weight.dtype, device=weight.device)
|
||||||
|
new_weight[:, :old_head_dim, :] = weight
|
||||||
|
new_weight[:, old_head_dim:, :] = 0
|
||||||
|
new_weight = new_weight.view(-1, in_features)
|
||||||
|
if linear.bias is not None:
|
||||||
|
bias = linear.bias.data
|
||||||
|
bias = bias.view(-1, old_head_dim)
|
||||||
|
new_bias = torch.empty([bias.size(0), new_head_dim],
|
||||||
|
dtype=bias.dtype, device=bias.device)
|
||||||
|
new_bias[:, :old_head_dim] = bias
|
||||||
|
new_bias[:, old_head_dim:] = 0
|
||||||
|
new_bias = new_bias.flatten()
|
||||||
|
|
||||||
|
new_linear = torch.nn.Linear(0, 0, bias=True)
|
||||||
|
new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
else:
|
||||||
|
new_linear = torch.nn.Linear(0, 0, bias=False)
|
||||||
|
new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||||
|
new_linear.in_features = new_weight.size(1)
|
||||||
|
new_linear.out_features = new_weight.size(0)
|
||||||
|
return new_linear
|
||||||
|
|
||||||
|
|
||||||
|
def padding_attention_hd_base(module: torch.nn.Module, attention_class,
|
||||||
|
old_head_dim: int, new_head_dim: int):
|
||||||
|
if (
|
||||||
|
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
|
||||||
|
or not isinstance(attention_class, str) and isinstance(module, attention_class)
|
||||||
|
) and module.head_dim == old_head_dim:
|
||||||
|
module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
|
||||||
|
module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
|
||||||
|
module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
|
||||||
|
module.head_dim = new_head_dim
|
||||||
|
module.old_head_dim = old_head_dim
|
||||||
|
|
||||||
|
|
||||||
|
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
|
||||||
|
bsz, num_heads, seq_len, head_dim = states.size()
|
||||||
|
if head_dim == old_head_dim and old_head_dim < new_head_dim:
|
||||||
|
new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
|
||||||
|
dtype=states.dtype, device=states.device)
|
||||||
|
new_states[:, :, :, :old_head_dim] = states
|
||||||
|
new_states[:, :, :, old_head_dim:] = 0
|
||||||
|
return new_states
|
||||||
|
return states
|
||||||
|
|
||||||
|
|
||||||
|
def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
old_head_dim: int, new_head_dim: int):
|
||||||
|
return (
|
||||||
|
padding_states_hd(q, old_head_dim, new_head_dim),
|
||||||
|
padding_states_hd(k, old_head_dim, new_head_dim),
|
||||||
|
padding_states_hd(v, old_head_dim, new_head_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||||
x_2d = x.view(-1, x.size(-1))
|
x_2d = x.view(-1, x.size(-1))
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,9 @@ import torch
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from torch.nn.functional import linear
|
from torch.nn.functional import linear
|
||||||
from ipex_llm.transformers.models.common import merge_qkv_base
|
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
|
||||||
from ipex_llm.transformers.models.common import attention_softmax
|
from ipex_llm.transformers.models.common import attention_softmax
|
||||||
|
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||||
from transformers import AutoProcessor, TextIteratorStreamer
|
from transformers import AutoProcessor, TextIteratorStreamer
|
||||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
|
|
@ -52,14 +53,28 @@ def siglip_attention_forward(
|
||||||
qkv = qkv.transpose(1, 2)
|
qkv = qkv.transpose(1, 2)
|
||||||
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
query_states, key_states, value_states = padding_qkv_hd(
|
||||||
if attention_mask is not None:
|
query_states, key_states, value_states,
|
||||||
attn_weights = attn_weights + attention_mask
|
72, 80
|
||||||
|
)
|
||||||
|
|
||||||
attn_weights = attention_softmax(attn_weights)
|
if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
|
||||||
|
import xe_addons
|
||||||
|
attn_weights = None
|
||||||
|
attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
|
||||||
|
value_states.contiguous(), attention_mask)
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
attn_weights = attention_softmax(attn_weights)
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
attn_weights = torch.nn.functional.dropout(attn_weights,
|
||||||
|
p=self.dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
attn_output = attn_output[:, :, :, :self.head_dim]
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
|
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue