refactor sd 1.5 and qwen2-vl and fix (#12590)

Yishuo Wang 2024-12-20 17:34:55 +08:00 committed by GitHub
parent b050368efc
commit 098eb335b2
4 changed files with 23 additions and 58 deletions


@@ -75,7 +75,7 @@ def siglip_attention_forward(
     attn_weights = None
     attn_output = scaled_dot_product_attention(
-        query_states, key_states, value_states,
+        query_states, key_states.contiguous(), value_states.contiguous(),
         attention_mask, False, 1 / math.sqrt(self.head_dim)
     )
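
Note: `scaled_dot_product_attention` here is the IPEX-LLM wrapper imported from `ipex_llm.transformers.models.common`, called with (query, key, value, mask, is_causal, scale) and presumably dispatching to fused XPU kernels. For readers unfamiliar with the operation, a minimal unfused reference of the same math (a sketch, not the wrapper's actual implementation):

    import math
    import torch

    def sdpa_reference(query, key, value, attn_mask=None, scale=None):
        # Plain softmax(q @ k^T * scale + mask) @ v over [batch, heads, seq, head_dim].
        scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask
        # Upcast to fp32 for the softmax, then cast back to the query dtype.
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        return torch.matmul(attn_weights, value)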


@@ -583,8 +583,7 @@ def qwen2_attention_forward(
                                                          self.layer_idx, None)
     attn_weights = None
-    if query_states.device.type == 'xpu' \
-            and use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
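
The context lines above repeat K/V heads for grouped-query attention before the flash path. For reference, a sketch of the standard `repeat_kv` pattern used by transformers-style models (the actual helper is imported by the file, not defined in this diff):

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand [batch, num_kv_heads, seq_len, head_dim] to
        # [batch, num_kv_heads * n_rep, seq_len, head_dim] so the grouped K/V heads
        # line up one-to-one with the query heads.
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)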


@@ -43,8 +43,9 @@ from typing import Optional, Tuple, Union, List
 import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
@@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
                           "unexpected input")
     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
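
`guessed_seq_lens` is an evenly spaced set of cumulative token boundaries; comparing it with the real `seq_lens` tells whether every image contributes the same number of tokens, which selects between the batched path and the per-image loop in the next hunks. An illustration with made-up lengths (values are invented for the example):

    import torch

    # Three images of 256 tokens each give evenly spaced cumulative boundaries.
    seq_lens = torch.tensor([0, 256, 512, 768], dtype=torch.int32)
    image_num = len(seq_lens) - 1
    image_size = int(seq_lens[1] - seq_lens[0])
    guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
                                    dtype=seq_lens.dtype)
    print(torch.equal(guessed_seq_lens, seq_lens))  # True -> all images equally sized
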
@@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
             v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
             # q, k, v: [image_num, num_heads, image_size, head_dim]
-            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+            attn_output = scaled_dot_product_attention(
+                q, k.contiguous(), v.contiguous(),
+                None, False
+            )
             attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
             attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
             # attn_output: [seq_length, num_heads, head_dim]
@@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
                 tmp_q = q[:, :, start_idx:end_idx, :]
                 tmp_k = k[:, :, start_idx:end_idx, :]
                 tmp_v = v[:, :, start_idx:end_idx, :]
-                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+                attn_output = scaled_dot_product_attention(
+                    tmp_q, tmp_k, tmp_v,
+                    None, False
+                )
                 attn_output = attn_output.permute(0, 2, 1, 3)
                 # attn_output: [1, seq_length, num_heads, head_dim]
                 attn_outputs.append(attn_output)
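
When image sizes differ, attention is computed slice by slice so tokens never attend across image boundaries. An illustrative standalone version of the same pattern, using PyTorch 2.x's built-in SDPA rather than the IPEX-LLM wrapper:

    import torch

    def per_image_attention(q, k, v, seq_lens):
        # q, k, v: [1, num_heads, total_tokens, head_dim]; seq_lens holds cumulative
        # token boundaries, one entry per image plus the leading zero.
        outputs = []
        for i in range(1, len(seq_lens)):
            start_idx, end_idx = int(seq_lens[i - 1]), int(seq_lens[i])
            tmp_q = q[:, :, start_idx:end_idx, :]
            tmp_k = k[:, :, start_idx:end_idx, :]
            tmp_v = v[:, :, start_idx:end_idx, :]
            # Non-causal attention restricted to one image's tokens.
            outputs.append(torch.nn.functional.scaled_dot_product_attention(
                tmp_q, tmp_k, tmp_v, attn_mask=None, is_causal=False))
        return torch.cat(outputs, dim=2)  # [1, num_heads, total_tokens, head_dim]
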
@@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, None)
-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
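
The `q_len == key_states.size(2)` argument makes the call causal only during prefill, when the query covers the whole KV cache; during decode q_len is 1, so only the attention mask applies. A minimal sketch of that behaviour with PyTorch 2.x's built-in SDPA (the IPEX-LLM wrapper additionally handles the FP8 KV cache and fused XPU kernels, which are not shown here):

    import torch

    def sdpa_with_prefill_causal(query_states, key_states, value_states, attention_mask):
        # Assumes K/V already carry as many heads as Q (repeat_kv done by the caller).
        # Prefill: q_len == kv_len, apply causal masking.
        # Decode: q_len == 1 < kv_len, rely on the (possibly None) attention mask.
        is_causal = query_states.size(2) == key_states.size(2)
        return torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=None if is_causal else attention_mask,
            is_causal=is_causal)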


@@ -37,8 +37,8 @@ import torch
 from typing import Optional
 from ipex_llm.transformers.utils import get_xpu_device_type
-from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
-from ipex_llm.transformers.models.utils import use_sdp_non_causal
+from ipex_llm.transformers.models.common import padding_qkv_hd
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from diffusers.models.attention_processor import Attention
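
The hunk below pads head_dim from 40 to 64 before the fused path and slices the result back with hidden_states[:, :, :, :head_dim]. A minimal sketch of that pad-and-slice idea (the helper name `pad_head_dim` is illustrative only, not the `padding_qkv_hd` implementation):

    import torch
    import torch.nn.functional as F

    def pad_head_dim(query, key, value, from_dim=40, to_dim=64):
        # Zero-pad the last (head_dim) axis so the fused kernel sees a supported size;
        # the caller slices the attention output back to from_dim afterwards, and the
        # explicit 1 / sqrt(from_dim) scale keeps the math identical.
        if query.size(-1) == from_dim:
            pad = (0, to_dim - from_dim)  # (left, right) padding of the last dimension
            query, key, value = F.pad(query, pad), F.pad(key, pad), F.pad(value, pad)
        return query, key, value
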
@@ -110,19 +110,10 @@ class AttnProcessor2_0:
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
             # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
-            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
-                import xe_addons
-                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                         value.contiguous(), attention_mask)
-            else:
-                scale = 1 / math.sqrt(head_dim)
-                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-                if attention_mask is not None:
-                    attn_weights = attn_weights + attention_mask
-                attn_weights = attention_softmax(attn_weights)
-                hidden_states = torch.matmul(attn_weights, value)
+            hidden_states = scaled_dot_product_attention(
+                query, key.contiguous(), value.contiguous(),
+                attention_mask, False, 1 / math.sqrt(head_dim)
+            )
             hidden_states = hidden_states[:, :, :, :head_dim]
         else:
             hidden_states = torch.nn.functional.scaled_dot_product_attention(