diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 4cdfc7ee..34f79052 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -126,7 +126,13 @@ def qwen_attention_forward_vl(
     query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+    # TODO: speed up
+    # mixed_x_layer = self.c_attn(hidden_states)
+    # query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+    # query = self._split_heads(query, self.num_heads, self.head_dim)
+    # key = self._split_heads(key, self.num_heads, self.head_dim)
+    # value = self._split_heads(value, self.num_heads, self.head_dim)
 
     if rotary_pos_emb is not None:
         cur_len = query.shape[1]
         rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
@@ -183,7 +189,8 @@ def qwen_attention_forward_vl(
         present = None
 
     if decoding_fast_path:
-        query = query.transpose(1, 2)  # change to (bsz, q_len, num_heads, head_dim)
+        # change to (bsz, q_len, num_heads, head_dim)
+        query = query.transpose(1, 2)
 
     if self.use_logn_attn and not self.training:
         if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
@@ -196,7 +203,7 @@ def qwen_attention_forward_vl(
     query = query.permute(0, 2, 1, 3)
 
     if not self.training and not hidden_states.requires_grad and \
-        use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query,
                                                     key,
@@ -209,7 +216,6 @@ def qwen_attention_forward_vl(
         attn_output, attn_weight = self._attn(
             query, key, value, registered_causal_mask, attention_mask, head_mask
        )
-
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
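
Note on the first hunk: the commented-out block sketches a fused QKV projection (a single `c_attn` GEMM whose output is split into query/key/value) as a possible replacement for the three separate `q_proj`/`k_proj`/`v_proj` matmuls the patched forward currently performs. The standalone snippet below is a minimal sketch of that idea in plain PyTorch, not the ipex_llm or Qwen-VL implementation; the sizes and the bare nn.Linear modules are illustrative stand-ins.

# Minimal sketch, NOT the ipex_llm/Qwen-VL code: shapes and module names are illustrative.
import torch
import torch.nn as nn

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
hidden_size = num_heads * head_dim
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Path taken by the patched forward: three separate projections, one view each.
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)
query = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
key = k_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
value = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)

# Path sketched by the "TODO: speed up" comment: one fused GEMM producing the
# concatenated q/k/v, then a split along the last dim and a per-tensor reshape.
c_attn = nn.Linear(hidden_size, 3 * hidden_size)
split_size = hidden_size
mixed_x_layer = c_attn(hidden_states)
q_fused, k_fused, v_fused = mixed_x_layer.split(split_size, dim=2)
q_fused = q_fused.view(bsz, q_len, num_heads, head_dim)

# Both paths yield tensors of shape (bsz, q_len, num_heads, head_dim).
print(query.shape, q_fused.shape)

The fused path replaces three kernel launches with one larger GEMM, which is presumably why the TODO flags it as a potential speed-up; actually enabling it would require the loaded checkpoint to carry a fused `c_attn` weight rather than separate q/k/v projection weights.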