diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 24ac0767..c518e1a6 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -33,7 +33,6 @@ from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_sdp
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
 
 import os
 
@@ -91,21 +90,31 @@ def qwen_attention_forward_vl(
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
 
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                use_fuse_rope,
-                                                True,
-                                                bsz * q_len)
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k, cache_v = layer_past[0], layer_past[1]
-        cache_k = cache_k.transpose(1, 2)
-        cache_v = cache_v.transpose(1, 2)
+    mixed_x_layer = self.c_attn(hidden_states)
+    query, key, value = mixed_x_layer.split(self.split_size, dim=2)
 
-        kv_seq_len = cache_k.shape[-2]
-        self.position_ids = self.position_ids.to(device)
-        position_ids = self.position_ids[kv_seq_len]
-        base = self.rope_base
-        if is_enough_kv_cache_room(layer_past, kv_seq_len):
+    query = self._split_heads(query, self.num_heads, self.head_dim)
+    key = self._split_heads(key, self.num_heads, self.head_dim)
+    value = self._split_heads(value, self.num_heads, self.head_dim)
+    if rotary_pos_emb is not None:
+        cur_len = query.shape[1]
+        rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+        rotary_pos_emb = (rotary_pos_emb,) * 2
+        q_pos_emb, k_pos_emb = rotary_pos_emb
+        # Slice the pos emb for current inference
+        query = apply_rotary_pos_emb(query, q_pos_emb)
+        key = apply_rotary_pos_emb(key, k_pos_emb)
+    query_size, key_size = query.size(1), key.size(1)
+
+    if layer_past is not None:
+        kv_seq_len += layer_past[0].shape[1]
+        # past_key, past_value = layer_past[0], layer_past[1]
+        # key = torch.cat((past_key, key), dim=1)
+        # value = torch.cat((past_value, value), dim=1)
+        cache_k = layer_past[0].transpose(1, 2)
+        cache_v = layer_past[1].transpose(1, 2)
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
+            # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
@@ -118,61 +127,10 @@ def qwen_attention_forward_vl(
             cache_k = new_cache_k
             cache_v = new_cache_v
 
-            args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
-                    self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
-                    self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
-                    self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
-            import linear_q4_0
-            query, key, value = linear_q4_0.forward_qkv_bias(*args)
-        kv_seq_len += 1
-        query_size, key_size = 1, 1
-    else:
-        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        # TODO: speed up
-        # mixed_x_layer = self.c_attn(hidden_states)
-        # query, key, value = mixed_x_layer.split(self.split_size, dim=2)
-
-        # query = self._split_heads(query, self.num_heads, self.head_dim)
-        # key = self._split_heads(key, self.num_heads, self.head_dim)
-        # value = self._split_heads(value, self.num_heads, self.head_dim)
-        if rotary_pos_emb is not None:
-            cur_len = query.shape[1]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
-        query_size, key_size = query.size(1), key.size(1)
-
-    if layer_past is not None:
-        if not decoding_fast_path:
-            kv_seq_len += layer_past[0].shape[1]
-            # past_key, past_value = layer_past[0], layer_past[1]
-            # key = torch.cat((past_key, key), dim=1)
-            # value = torch.cat((past_value, value), dim=1)
-            cache_k = layer_past[0].transpose(1, 2)
-            cache_v = layer_past[1].transpose(1, 2)
-            if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
-                # allocate new
-                new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           cache_k.size(2),
-                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                           dtype=cache_k.dtype,
-                                                           device=hidden_states.device)
-                new_cache_k[:] = cache_k
-                new_cache_v[:] = cache_v
-                cache_k = new_cache_k
-                cache_v = new_cache_v
-
-            key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                       key.transpose(1, 2), value.transpose(1, 2))
-            key = key_states
-            value = value_states
+        key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                   key.transpose(1, 2), value.transpose(1, 2))
+        key = key_states
+        value = value_states
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(bsz,
@@ -192,10 +150,6 @@ def qwen_attention_forward_vl(
     else:
         present = None
 
-    if decoding_fast_path:
-        # change to (bsz, q_len, num_heads, head_dim)
-        query = query.transpose(1, 2)
-
     if self.use_logn_attn and not self.training:
         if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
             self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
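For reference, the retained branch grows a pre-allocated KV cache instead of concatenating past and current keys/values on every step; the `cache_k.stride()[1] < kv_seq_len * cache_k.size(3)` check works because the stride along the head dimension encodes the buffer's allocated sequence capacity times `head_dim`. Below is a minimal plain-PyTorch sketch of that pattern. It deliberately does not reproduce the ipex_llm helpers (`init_kv_cache`, `extend_kv_cache`, `append_kv_cache`); the `init_cache`/`append` helpers and the block-length value are illustrative stand-ins.

```python
# Sketch (assumptions: plain PyTorch, illustrative helpers) of the pre-allocated
# KV cache pattern: the cache is a narrowed view of a larger buffer, so its
# stride along the heads dimension reveals how many sequence slots were
# actually allocated per head.
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # stand-in for the constant used in the patch


def init_cache(bsz, num_heads, head_dim, cur_len, max_len, dtype, device):
    # Allocate (bsz, num_heads, max_len, head_dim) and expose only cur_len of it.
    buf = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
    return buf.narrow(2, 0, cur_len)


def append(cache, new_states):
    # Widen the narrowed view by the new length and copy the new states in place.
    bsz, num_heads, cur_len, head_dim = cache.shape
    extended = cache.as_strided(
        (bsz, num_heads, cur_len + new_states.size(2), head_dim),
        cache.stride(), storage_offset=cache.storage_offset())
    extended[:, :, cur_len:, :] = new_states
    return extended


bsz, num_heads, head_dim = 1, 2, 4
cache_k = init_cache(bsz, num_heads, head_dim, cur_len=3,
                     max_len=3 + KV_CACHE_ALLOC_BLOCK_LENGTH,
                     dtype=torch.float32, device="cpu")
kv_seq_len = cache_k.size(2) + 1
# Same test as in the patch: is there enough pre-allocated room for kv_seq_len
# tokens? Here there is, so no re-allocation is needed before appending.
assert cache_k.stride(1) >= kv_seq_len * cache_k.size(3)
cache_k = append(cache_k, torch.randn(bsz, num_heads, 1, head_dim))
print(cache_k.shape)  # torch.Size([1, 2, 4, 4])
```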
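The restored path also slices the rotary tables to the last `cur_len` positions before calling `apply_rotary_pos_emb`, so that during incremental decoding only the cos/sin entries for the tokens currently being processed are applied. The sketch below illustrates a rotate-half style application with made-up shapes; it is not the Qwen-VL implementation itself, and `rotate_half`/`apply_rotary` here are local stand-ins rather than the imported ipex_llm utilities.

```python
# Sketch (assumptions: illustrative shapes and helpers) of rotate-half rotary
# embeddings and of why the tables are sliced to the trailing positions.
import torch


def rotate_half(x):
    # Split the head dim in two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(t, cos, sin):
    # t: (bsz, seq, num_heads, head_dim); cos/sin broadcast over bsz and heads.
    return (t * cos) + (rotate_half(t) * sin)


bsz, seq, num_heads, head_dim = 1, 8, 2, 16
pos = torch.arange(seq, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(pos, inv_freq)               # (seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)          # (seq, head_dim)
cos, sin = emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]

query = torch.randn(bsz, seq, num_heads, head_dim)

# Decode step: only the newest token is processed, so take the trailing entry
# of the precomputed tables, mirroring `i[:, -cur_len:, :, :]` in the patch.
cur_len = 1
q_step = apply_rotary(query[:, -cur_len:], cos[:, -cur_len:], sin[:, -cur_len:])
print(q_step.shape)  # torch.Size([1, 1, 2, 16])
```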