fix qwen2 vl (#12126)

This commit is contained in:
Yishuo Wang 2024-09-26 15:44:02 +08:00 committed by GitHub
parent 2ea13d502f
commit 66f419f8b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -42,7 +42,7 @@ from typing import Optional, Tuple, Union, List
import torch import torch
from ipex_llm.transformers.models.common import merge_qkv_base from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
@ -207,29 +207,29 @@ def qwen2_vl_attention_forward(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
) )
kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None) self.layer_idx, None)
kv_seq_len = key_states.shape[-2]
kv_seq_len = key_states.size(2)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :kv_seq_len]
attn_weights = None attn_weights = None
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons import xe_addons
if isinstance(past_key_value, DynamicFp8Cache): if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
attention_mask)
else: else:
attn_output = xe_addons.sdp(query_states, key_states, value_states, attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons import xe_addons
if isinstance(past_key_value, DynamicFp8Cache): if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask) value_states, causal_mask)
else: else:
attn_output = xe_addons.sdp_causal(query_states, key_states, attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask) value_states, causal_mask)
else: else:
if isinstance(past_key_value, DynamicFp8Cache): if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states, key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@ -241,15 +241,11 @@ def qwen2_vl_attention_forward(
attn_weights = torch.matmul(query_states, attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim) key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it if causal_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
# upcast attention to fp32 # upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, attn_weights = attention_softmax(attn_weights, self.training)
dtype=torch.float32).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()