refactor sd 1.5 and qwen2-vl and fix (#12590)
This commit is contained in:
parent
b050368efc
commit
098eb335b2
4 changed files with 23 additions and 58 deletions
|
|
@ -75,7 +75,7 @@ def siglip_attention_forward(
|
||||||
|
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
attn_output = scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(
|
||||||
query_states, key_states, value_states,
|
query_states, key_states.contiguous(), value_states.contiguous(),
|
||||||
attention_mask, False, 1 / math.sqrt(self.head_dim)
|
attention_mask, False, 1 / math.sqrt(self.head_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -583,8 +583,7 @@ def qwen2_attention_forward(
|
||||||
self.layer_idx, None)
|
self.layer_idx, None)
|
||||||
|
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
if query_states.device.type == 'xpu' \
|
if use_flash_attention(query_states, key_states, attention_mask):
|
||||||
and use_flash_attention(query_states, key_states, attention_mask):
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,9 @@ from typing import Optional, Tuple, Union, List
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
|
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
|
||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
|
||||||
|
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
|
@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
|
||||||
"unexpected input")
|
"unexpected input")
|
||||||
|
|
||||||
if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
|
if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
|
||||||
import xe_addons
|
|
||||||
image_num = len(seq_lens) - 1
|
image_num = len(seq_lens) - 1
|
||||||
image_size = seq_lens[1] - seq_lens[0]
|
image_size = seq_lens[1] - seq_lens[0]
|
||||||
guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
|
guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
|
||||||
|
|
@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
|
||||||
v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||||
# q, k, v: [image_num, num_heads, image_size, head_dim]
|
# q, k, v: [image_num, num_heads, image_size, head_dim]
|
||||||
|
|
||||||
attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
|
attn_output = scaled_dot_product_attention(
|
||||||
|
q, k.contiguous(), v.contiguous(),
|
||||||
|
None, False
|
||||||
|
)
|
||||||
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
|
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
|
||||||
attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
|
attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
|
||||||
# attn_output: [seq_length, num_heads, head_dim]
|
# attn_output: [seq_length, num_heads, head_dim]
|
||||||
|
|
@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
|
||||||
tmp_q = q[:, :, start_idx:end_idx, :]
|
tmp_q = q[:, :, start_idx:end_idx, :]
|
||||||
tmp_k = k[:, :, start_idx:end_idx, :]
|
tmp_k = k[:, :, start_idx:end_idx, :]
|
||||||
tmp_v = v[:, :, start_idx:end_idx, :]
|
tmp_v = v[:, :, start_idx:end_idx, :]
|
||||||
attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
|
attn_output = scaled_dot_product_attention(
|
||||||
|
tmp_q, tmp_k, tmp_v,
|
||||||
|
None, False
|
||||||
|
)
|
||||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||||
# attn_output: [1, seq_length, num_heads, head_dim]
|
# attn_output: [1, seq_length, num_heads, head_dim]
|
||||||
attn_outputs.append(attn_output)
|
attn_outputs.append(attn_output)
|
||||||
|
|
@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
self.layer_idx, None)
|
self.layer_idx, None)
|
||||||
|
|
||||||
kv_seq_len = key_states.size(2)
|
|
||||||
if attention_mask is not None: # no matter the length, we just slice it
|
|
||||||
causal_mask = attention_mask[:, :, :, :kv_seq_len]
|
|
||||||
|
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
attn_output = scaled_dot_product_attention(
|
||||||
import xe_addons
|
query_states, key_states, value_states,
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
attention_mask, q_len == key_states.size(2)
|
||||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
|
)
|
||||||
else:
|
|
||||||
attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
|
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
|
||||||
import xe_addons
|
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
|
||||||
value_states, causal_mask)
|
|
||||||
else:
|
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
|
||||||
value_states, causal_mask)
|
|
||||||
else:
|
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
|
||||||
query_states.dtype)
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states,
|
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
if causal_mask is not None:
|
|
||||||
attn_weights = attn_weights + causal_mask
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = attention_softmax(attn_weights)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,8 @@ import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ipex_llm.transformers.utils import get_xpu_device_type
|
from ipex_llm.transformers.utils import get_xpu_device_type
|
||||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
|
from ipex_llm.transformers.models.common import padding_qkv_hd
|
||||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -110,19 +110,10 @@ class AttnProcessor2_0:
|
||||||
if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
|
if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
|
||||||
# padding head_dim 40 to 64
|
# padding head_dim 40 to 64
|
||||||
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
||||||
|
hidden_states = scaled_dot_product_attention(
|
||||||
if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
|
query, key.contiguous(), value.contiguous(),
|
||||||
import xe_addons
|
attention_mask, False, 1 / math.sqrt(head_dim)
|
||||||
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
)
|
||||||
value.contiguous(), attention_mask)
|
|
||||||
else:
|
|
||||||
scale = 1 / math.sqrt(head_dim)
|
|
||||||
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = attention_softmax(attn_weights)
|
|
||||||
hidden_states = torch.matmul(attn_weights, value)
|
|
||||||
|
|
||||||
hidden_states = hidden_states[:, :, :, :head_dim]
|
hidden_states = hidden_states[:, :, :, :head_dim]
|
||||||
else:
|
else:
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue