optimize qwen2-vl vision (#12203)
parent d5344587ab
commit 9b81236a2e
3 changed files with 39 additions and 14 deletions

@@ -26,6 +26,7 @@
 import torch
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.utils import use_sdp_non_causal


 def _get_pos_embed(self, pos_embed, H, W):

@@ -175,9 +176,9 @@ def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
     q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
     k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

     if x.device.type == "xpu":
         if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
             import xe_addons
-            x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
+            x = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
         else:
             attn = ((q * self.scale) @ k.transpose(-2, -1))
             attn = attn.softmax(dim=-1)

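For reference, a minimal PyTorch-only sketch of the attention that the fused xe_addons.sdp_non_causal call stands in for here, assuming it computes standard non-causal scaled-dot-product attention with the usual 1/sqrt(head_dim) scaling; torch.nn.functional.scaled_dot_product_attention is used as the stand-in and the helper name below is illustrative, not part of this commit:

import torch
import torch.nn.functional as F

def sdp_non_causal_reference(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    # Same math as the manual fallback branch above (assuming self.scale == head_dim ** -0.5):
    # softmax(q @ k^T / sqrt(head_dim) + mask) @ v, with no causal masking.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
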
@@ -45,6 +45,7 @@ import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError

@@ -191,20 +192,35 @@ def qwen2_vision_attention_forward(
     q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
     k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

-    attention_mask = torch.full(
-        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
-    )
-    for i in range(1, len(cu_seqlens)):
-        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                       cu_seqlens[i - 1]:cu_seqlens[i]] = 0
+    if len(cu_seqlens) == 2 and cu_seqlens.tolist() == [0, seq_length]:
+        attention_mask = None
+    else:
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                           cu_seqlens[i - 1]:cu_seqlens[i]] = 0

     q = q.transpose(0, 1)
     k = k.transpose(0, 1)
     v = v.transpose(0, 1)
-    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-    attn_weights = attn_weights + attention_mask
-    attn_weights = attention_softmax(attn_weights, False)
-    attn_output = torch.matmul(attn_weights, v)
+
+    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
+        import xe_addons
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(0)
+        attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), attention_mask)
+        attn_output = attn_output.squeeze(0)
+    else:
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = attention_softmax(attn_weights, False)
+        attn_output = torch.matmul(attn_weights, v)
     attn_output = attn_output.transpose(0, 1)
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)

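An illustrative standalone sketch of the masking logic introduced above (the helper name and example values are hypothetical, not part of the commit): when cu_seqlens describes a single sequence covering every patch, the block-diagonal mask would be all zeros, so it can be skipped and None passed to the attention kernel instead:

import torch

def build_block_diagonal_mask(cu_seqlens, seq_length, dtype=torch.float16):
    # Fast path: one sequence spanning all patches -> mask would be all zeros.
    if len(cu_seqlens) == 2 and cu_seqlens.tolist() == [0, seq_length]:
        return None
    mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        # Allow attention only inside each image's own patch range.
        mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    return mask

# Example: two images of 4 and 6 patches -> attention limited to 4x4 and 6x6 blocks.
print(build_block_diagonal_mask(torch.tensor([0, 4, 10]), 10))
# A single image of 10 patches -> no mask needed at all.
print(build_block_diagonal_mask(torch.tensor([0, 10]), 10))
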
@@ -338,12 +338,20 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
         and head_dim in [-1, 64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
     )


+def use_sdp_non_causal(head_dim, device, dtype):
+    return (
+        head_dim in [40, 64, 80]
+        and device.type == "xpu"  # GPU
+        and dtype in [torch.float, torch.half]  # fp32/fp16
+    )
+
+
 def mlp_fusion_check(x, qtype, training):
     invalidInputError(x.dim() == 2,
                       "Here input x's dim should be 2.")
