From 65934c9f4f89ed334540b835f78a176ea6395489 Mon Sep 17 00:00:00 2001
From: Ziteng Zhang <87107332+Jasonzzt@users.noreply.github.com>
Date: Tue, 5 Dec 2023 15:15:54 +0800
Subject: [PATCH] [LLM] Fix Qwen causal_mask and attention_mask size
 mismatching (#9600)

* Fix #9582, caused by Qwen's changes to modeling_qwen.py:
  https://huggingface.co/Qwen/Qwen-7B-Chat/commit/7f62181c94bfae7552652d35d380f219c44d8efd#d2h-049182
---
 python/llm/src/bigdl/llm/transformers/models/qwen.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 0735b270..81feb4aa 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -151,7 +151,8 @@ def qwen_attention_forward(
     else:
         present = None
 
-    if self.use_logn_attn and not self.training:
+    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    if key_size > self.seq_length and self.use_logn_attn and not self.training:
         if self.use_cache_quantization:
             seq_start = key[0].size(2) - query.size(1)
             seq_end = key[0].size(2)
@@ -174,9 +175,10 @@ def qwen_attention_forward(
         context_layer = context_layer.flatten(2, 3).contiguous()
 
     else:
-        registered_causal_mask = torch.tril(
-            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
-        ).view(1, 1, key.size(1), key.size(1))
+        if query.size(1) == key_size:
+            registered_causal_mask = torch.tril(
+                torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
+            ).view(1, 1, key_size, key_size)
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
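
For reference, below is a minimal standalone sketch of the guard logic this patch introduces: size the causal mask from the KV-cache length (key_size) instead of key.size(1), and only rebuild it when the query covers the same number of tokens as the keys. It assumes the tensor layouts implied by the diff (unquantized keys shaped (batch, kv_len, heads, head_dim); a quantized cache keeping the sequence length at dim 2 of key[0]); the helper names key_seq_len and build_causal_mask are hypothetical and are not bigdl-llm APIs.

    # Illustrative sketch, not part of the patch.
    from typing import Optional

    import torch


    def key_seq_len(key, use_cache_quantization: bool) -> int:
        # Mirror of the patched expression: a quantized cache stores the
        # sequence length at dim 2 of key[0]; unquantized keys keep it at dim 1.
        return key[0].size(2) if use_cache_quantization else key.size(1)


    def build_causal_mask(query: torch.Tensor, key_size: int) -> Optional[torch.Tensor]:
        # Rebuild the mask only during prefill, when query and key cover the
        # same number of tokens. During incremental decoding
        # (query.size(1) == 1 < key_size) the caller-supplied attention_mask is
        # used instead, avoiding the size mismatch reported in #9582.
        if query.size(1) != key_size:
            return None
        return torch.tril(
            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
        ).view(1, 1, key_size, key_size)


    # Example with arbitrary head counts: prefilling 8 tokens builds an 8x8
    # mask; decoding 1 new token against a cache of 9 keys skips the rebuild.
    prefill_q = torch.zeros(1, 8, 32, 128)
    prefill_k = torch.zeros(1, 8, 32, 128)
    print(build_causal_mask(prefill_q, key_seq_len(prefill_k, False)).shape)  # (1, 1, 8, 8)
    decode_q = torch.zeros(1, 1, 32, 128)
    print(build_causal_mask(decode_q, 9))  # None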