From 2c8a9aaf0dfe231ba3e3824186b4ff84cd7d9301 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 23 Jan 2024 16:34:05 +0800
Subject: [PATCH] fix qwen causal mask when quantize_kv_cache=True (#9968)

---
 .../src/bigdl/llm/transformers/models/qwen.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index cf0eafce..a080847b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -135,10 +135,14 @@ def qwen_attention_forward(
         seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    if key_size == kv_seq_len:
+
+    if query_size > 1:
         causal_mask = torch.tril(
-            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-        ).view(1, 1, key_size, key_size)
+            torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
+        ).view(1, 1, kv_seq_len, kv_seq_len)
+        causal_mask = causal_mask[
+            :, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
+        ]
     else:
         causal_mask = None

@@ -222,14 +226,7 @@ def qwen_attention_forward(
         value = new_value_states

     query = query.transpose(1, 2)
-    # skip first init and only works for n tokens input
-    if causal_mask is None and query.size(2) > 1:
-        causal_mask = torch.tril(
-            torch.ones((key.size(2), key.size(2)), dtype=torch.bool, device=query.device)
-        ).view(1, 1, key.size(2), key.size(2))
-        causal_mask = causal_mask[
-            :, :, key.size(2) - query.size(2): key.size(2), :key.size(2)
-        ]
+
    attn_output, attn_weight = self._attn(
         query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
     )
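
For reference, the mask construction this patch settles on can be exercised standalone. The sketch below is illustrative only, not code from the repo: build_causal_mask is a hypothetical helper that reproduces the patched branch, building a full lower-triangular mask over kv_seq_len and keeping only the last query_size rows, so one code path covers both plain prefill (query_size == kv_seq_len) and multi-token steps on top of a cache (query_size < kv_seq_len).

import torch

def build_causal_mask(query_size, kv_seq_len, device="cpu"):
    # Mirrors the patched branch: single-token decode needs no mask.
    if query_size <= 1:
        return None
    # Full lower-triangular mask over every key/value position ...
    causal_mask = torch.tril(
        torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=device)
    ).view(1, 1, kv_seq_len, kv_seq_len)
    # ... then keep only the rows for the current query tokens, so each
    # new token attends to all cached positions plus itself.
    return causal_mask[:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len]

# Prefill without cache: query_size == kv_seq_len gives the usual 4x4 mask.
print(build_causal_mask(4, 4)[0, 0].int())
# A 3-token step on top of 5 cached positions yields a 3x8 slice.
print(build_causal_mask(3, 8)[0, 0].int())

Computing the slice up front from query_size and kv_seq_len is what lets the patch delete the second, duplicated mask rebuild later in the function, which previously never ran when quantize_kv_cache=True and left causal_mask as None for cached multi-token inputs.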