fix qwen vl (#11090)

This commit is contained in:
Xin Qiu 2024-05-21 18:40:29 +08:00 committed by GitHub
parent f654f7e08c
commit 71bcd18f44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -33,7 +33,6 @@ from transformers.utils import logging
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import rotate_half from ipex_llm.transformers.models.utils import rotate_half
from ipex_llm.transformers.models.utils import use_sdp from ipex_llm.transformers.models.utils import use_sdp
from ipex_llm.transformers.models.utils import use_decoding_fast_path
import os import os
@ -91,52 +90,12 @@ def qwen_attention_forward_vl(
device = hidden_states.device device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states) use_fuse_rope = should_use_fuse_rope(self, hidden_states)
decoding_fast_path = use_decoding_fast_path(self.q_proj, mixed_x_layer = self.c_attn(hidden_states)
use_fuse_rope, query, key, value = mixed_x_layer.split(self.split_size, dim=2)
True,
bsz * q_len)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
kv_seq_len = cache_k.shape[-2] query = self._split_heads(query, self.num_heads, self.head_dim)
self.position_ids = self.position_ids.to(device) key = self._split_heads(key, self.num_heads, self.head_dim)
position_ids = self.position_ids[kv_seq_len] value = self._split_heads(value, self.num_heads, self.head_dim)
base = self.rope_base
if is_enough_kv_cache_room(layer_past, kv_seq_len):
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=hidden_states.device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
import linear_q4_0
query, key, value = linear_q4_0.forward_qkv_bias(*args)
kv_seq_len += 1
query_size, key_size = 1, 1
else:
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
# TODO: speed up
# mixed_x_layer = self.c_attn(hidden_states)
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
cur_len = query.shape[1] cur_len = query.shape[1]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
@ -148,7 +107,6 @@ def qwen_attention_forward_vl(
query_size, key_size = query.size(1), key.size(1) query_size, key_size = query.size(1), key.size(1)
if layer_past is not None: if layer_past is not None:
if not decoding_fast_path:
kv_seq_len += layer_past[0].shape[1] kv_seq_len += layer_past[0].shape[1]
# past_key, past_value = layer_past[0], layer_past[1] # past_key, past_value = layer_past[0], layer_past[1]
# key = torch.cat((past_key, key), dim=1) # key = torch.cat((past_key, key), dim=1)
@ -192,10 +150,6 @@ def qwen_attention_forward_vl(
else: else:
present = None present = None
if decoding_fast_path:
# change to (bsz, q_len, num_heads, head_dim)
query = query.transpose(1, 2)
if self.use_logn_attn and not self.training: if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)