fix qwen vl (#11090)

Xin Qiu authored 2024-05-21 18:40:29 +08:00, committed by GitHub
parent f654f7e08c
commit 71bcd18f44

@@ -33,7 +33,6 @@ from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_sdp
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
 import os
@@ -91,21 +90,31 @@ def qwen_attention_forward_vl(
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                use_fuse_rope,
-                                                True,
-                                                bsz * q_len)
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k, cache_v = layer_past[0], layer_past[1]
-        cache_k = cache_k.transpose(1, 2)
-        cache_v = cache_v.transpose(1, 2)
-        kv_seq_len = cache_k.shape[-2]
-        self.position_ids = self.position_ids.to(device)
-        position_ids = self.position_ids[kv_seq_len]
-        base = self.rope_base
-        if is_enough_kv_cache_room(layer_past, kv_seq_len):
+    mixed_x_layer = self.c_attn(hidden_states)
+    query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+    query = self._split_heads(query, self.num_heads, self.head_dim)
+    key = self._split_heads(key, self.num_heads, self.head_dim)
+    value = self._split_heads(value, self.num_heads, self.head_dim)
+    if rotary_pos_emb is not None:
+        cur_len = query.shape[1]
+        rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+        rotary_pos_emb = (rotary_pos_emb,) * 2
+        q_pos_emb, k_pos_emb = rotary_pos_emb
+        # Slice the pos emb for current inference
+        query = apply_rotary_pos_emb(query, q_pos_emb)
+        key = apply_rotary_pos_emb(key, k_pos_emb)
+    query_size, key_size = query.size(1), key.size(1)
+    if layer_past is not None:
+        kv_seq_len += layer_past[0].shape[1]
+        # past_key, past_value = layer_past[0], layer_past[1]
+        # key = torch.cat((past_key, key), dim=1)
+        # value = torch.cat((past_value, value), dim=1)
+        cache_k = layer_past[0].transpose(1, 2)
+        cache_v = layer_past[1].transpose(1, 2)
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
+            # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
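
For context, the block added above restores the stock Qwen-VL attention front end: one fused c_attn projection produces query, key and value, which are split into heads and rotated before the KV cache is touched. The snippet below is a minimal, self-contained sketch of that pattern in plain PyTorch; FusedQKVFrontEnd and apply_rotary are illustrative stand-ins written for this note, not the actual Qwen or ipex_llm code, and the RoPE formulation is the common cos/sin form rather than the model's exact implementation.

import torch
import torch.nn as nn


def rotate_half(x):
    # RoPE helper: split the last dim in half, swap the halves and negate one
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, cos, sin):
    # x: (bsz, seq, num_heads, head_dim); cos/sin: (seq, head_dim)
    return x * cos[None, :, None, :] + rotate_half(x) * sin[None, :, None, :]


class FusedQKVFrontEnd(nn.Module):
    # Toy version of the fused-projection path: c_attn -> split -> heads -> RoPE.

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.split_size = hidden_size
        # one projection that produces Q, K and V at once, like Qwen's c_attn
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)

    def _split_heads(self, x, num_heads, head_dim):
        # (bsz, seq, hidden) -> (bsz, seq, num_heads, head_dim)
        return x.view(*x.shape[:-1], num_heads, head_dim)

    def forward(self, hidden_states, cos, sin):
        mixed_x_layer = self.c_attn(hidden_states)
        query, key, value = mixed_x_layer.split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        # rotate queries and keys for the current positions only
        query = apply_rotary(query, cos, sin)
        key = apply_rotary(key, cos, sin)
        return query, key, value


# tiny smoke test: 1 sequence of 5 tokens, hidden size 32, 4 heads (head_dim 8)
frontend = FusedQKVFrontEnd(hidden_size=32, num_heads=4)
x = torch.randn(1, 5, 32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
angles = torch.outer(torch.arange(5).float(), inv_freq)   # (seq, head_dim / 2)
emb = torch.cat((angles, angles), dim=-1)                  # (seq, head_dim)
q, k, v = frontend(x, emb.cos(), emb.sin())
print(q.shape, k.shape, v.shape)   # torch.Size([1, 5, 4, 8]) each

Compared with the removed branch, there is no separate q_proj/k_proj/v_proj handling and no decoding fast path: the same fused projection runs for both prefill and decode.
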
@@ -118,61 +127,10 @@ def qwen_attention_forward_vl(
             cache_k = new_cache_k
             cache_v = new_cache_v
-        args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
-                self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
-                self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
-                self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
-        import linear_q4_0
-        query, key, value = linear_q4_0.forward_qkv_bias(*args)
-        kv_seq_len += 1
-        query_size, key_size = 1, 1
-    else:
-        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        # TODO: speed up
-        # mixed_x_layer = self.c_attn(hidden_states)
-        # query, key, value = mixed_x_layer.split(self.split_size, dim=2)
-        # query = self._split_heads(query, self.num_heads, self.head_dim)
-        # key = self._split_heads(key, self.num_heads, self.head_dim)
-        # value = self._split_heads(value, self.num_heads, self.head_dim)
-        if rotary_pos_emb is not None:
-            cur_len = query.shape[1]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
-        query_size, key_size = query.size(1), key.size(1)
-    if layer_past is not None:
-        if not decoding_fast_path:
-            kv_seq_len += layer_past[0].shape[1]
-            # past_key, past_value = layer_past[0], layer_past[1]
-            # key = torch.cat((past_key, key), dim=1)
-            # value = torch.cat((past_value, value), dim=1)
-            cache_k = layer_past[0].transpose(1, 2)
-            cache_v = layer_past[1].transpose(1, 2)
-            if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
-                # allocate new
-                new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           cache_k.size(2),
-                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                           dtype=cache_k.dtype,
-                                                           device=hidden_states.device)
-                new_cache_k[:] = cache_k
-                new_cache_v[:] = cache_v
-                cache_k = new_cache_k
-                cache_v = new_cache_v
-            key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                       key.transpose(1, 2), value.transpose(1, 2))
-            key = key_states
-            value = value_states
+        key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                   key.transpose(1, 2), value.transpose(1, 2))
+        key = key_states
+        value = value_states
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(bsz,
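
The branch that survives the diff still grows the cache with ipex_llm's extend_kv_cache / append_kv_cache helpers: the buffer carries KV_CACHE_ALLOC_BLOCK_LENGTH spare positions, and a larger buffer is allocated (with the old contents copied over) only when the stride check shows the current one has run out of room. Below is a rough, dependency-free sketch of that allocation pattern; BlockKVCache and ALLOC_BLOCK are hypothetical names for this illustration and do not match the real ipex_llm implementation.

import torch

# hypothetical block size standing in for KV_CACHE_ALLOC_BLOCK_LENGTH
ALLOC_BLOCK = 256


class BlockKVCache:
    # Toy per-layer KV cache that over-allocates in fixed blocks,
    # so most decode steps append in place instead of re-allocating.

    def __init__(self, bsz, num_heads, head_dim, dtype=torch.float32, device="cpu"):
        self.kv_len = 0
        shape = (bsz, num_heads, ALLOC_BLOCK, head_dim)
        self.k = torch.empty(shape, dtype=dtype, device=device)
        self.v = torch.empty(shape, dtype=dtype, device=device)

    def append(self, key, value):
        # key/value: (bsz, num_heads, new_len, head_dim)
        new_len = self.kv_len + key.size(2)
        if new_len > self.k.size(2):
            # not enough room: allocate a bigger buffer and copy the old content,
            # mirroring extend_kv_cache followed by new_cache[:] = cache
            capacity = new_len + ALLOC_BLOCK
            new_k = torch.empty(self.k.shape[0], self.k.shape[1], capacity,
                                self.k.shape[3], dtype=self.k.dtype, device=self.k.device)
            new_v = torch.empty_like(new_k)
            new_k[:, :, :self.kv_len] = self.k[:, :, :self.kv_len]
            new_v[:, :, :self.kv_len] = self.v[:, :, :self.kv_len]
            self.k, self.v = new_k, new_v
        # in-place append, like append_kv_cache
        self.k[:, :, self.kv_len:new_len] = key
        self.v[:, :, self.kv_len:new_len] = value
        self.kv_len = new_len
        # return views over the valid region only
        return self.k[:, :, :new_len], self.v[:, :, :new_len]


# decode-loop style usage: append one new token's K/V per step
cache = BlockKVCache(bsz=1, num_heads=4, head_dim=8)
for _ in range(3):
    k_step = torch.randn(1, 4, 1, 8)
    v_step = torch.randn(1, 4, 1, 8)
    k_all, v_all = cache.append(k_step, v_step)
print(k_all.shape)   # torch.Size([1, 4, 3, 8])

Keeping spare capacity means a typical decode step is a single in-place slice assignment rather than a reallocation plus copy; the cache only has to grow once every ALLOC_BLOCK tokens.
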
@@ -192,10 +150,6 @@ def qwen_attention_forward_vl(
     else:
         present = None
-    if decoding_fast_path:
-        # change to (bsz, q_len, num_heads, head_dim)
-        query = query.transpose(1, 2)
     if self.use_logn_attn and not self.training:
         if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
             self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)