From e4ceb722b6f82673c2b704ae65d85c3aabfb960d Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 10 Feb 2025 13:25:53 +0800
Subject: [PATCH] fix qwen2 vl (#12798)

---
 .../src/ipex_llm/transformers/models/qwen2_vl.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index d885e23a..512e39f9 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -200,15 +200,16 @@ def qwen2_vision_attention_forward(
     invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
                       "unexpected input")
 
-    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
+    head_dim = q.size(-1)
+    if use_sdp_non_causal(head_dim, q.device, q.dtype):
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
                                         dtype=cu_seqlens.dtype, device=cu_seqlens.device)
         if (guessed_seq_lens == cu_seqlens).all():
-            q = q.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-            k = k.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+            q = q.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
+            k = k.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
+            v = v.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
             # q, k, v: [image_num, num_heads, image_size, head_dim]
 
             attn_output = scaled_dot_product_attention(
@@ -216,7 +217,7 @@ def qwen2_vision_attention_forward(
                 None, False
             )
             attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
-            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
+            attn_output = attn_output.view(seq_length, self.num_heads, head_dim)
             # attn_output: [seq_length, num_heads, head_dim]
         else:
             q = q.transpose(0, 1).unsqueeze(0)
@@ -252,7 +253,7 @@ def qwen2_vision_attention_forward(
         v = v.transpose(0, 1)
         # q, k, v: [num_heads, seq_length, head_dim]
 
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
         attn_weights = attn_weights + attention_mask
         attn_weights = attention_softmax(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
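For context, the fix is mechanical: `qwen2_vision_attention_forward` now derives `head_dim` locally from `q.size(-1)` instead of reading `self.head_dim`, and threads that value through the `use_sdp_non_causal` check, the `view()` reshapes, and the `math.sqrt()` scaling, so the path no longer depends on the attention module exposing a `head_dim` attribute. The sketch below only illustrates the reshape round-trip performed in the SDP branch; it is not the patched function itself, and the concrete shapes plus the use of torch's built-in `F.scaled_dot_product_attention` in place of ipex-llm's `scaled_dot_product_attention` wrapper are assumptions made to keep the example self-contained.

```python
# Illustrative sketch, not the patched ipex-llm code. Shapes are made-up
# example values; torch's F.scaled_dot_product_attention stands in for
# ipex-llm's wrapper used in the patch.
import torch
import torch.nn.functional as F

image_num, image_size, num_heads = 2, 64, 16
seq_length = image_num * image_size

# packed projections: [seq_length, num_heads, head_dim]
q = torch.randn(seq_length, num_heads, 80)
k = torch.randn(seq_length, num_heads, 80)
v = torch.randn(seq_length, num_heads, 80)

# as in the patch, take head_dim from the tensor itself rather than a module attribute
head_dim = q.size(-1)

# regroup tokens per image and move heads ahead of the sequence dim,
# giving [image_num, num_heads, image_size, head_dim]
q = q.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
k = k.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
v = v.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)

attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

# back to [seq_length, num_heads, head_dim], matching the patched view() call
attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(seq_length, num_heads, head_dim)
print(attn_output.shape)  # torch.Size([128, 16, 80])
```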