From 1fc9dfa265352143a8754379a977d39428f6293f Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Thu, 18 Jan 2024 15:56:29 +0800
Subject: [PATCH] LLM: Update for Qwen n tokens inputs (#9931)

* update for n tokens inputs

* update style

* update
---
 .../llm/src/bigdl/llm/transformers/models/qwen.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index e0198c68..0517e9d2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -119,12 +119,13 @@ def qwen_attention_forward(
         seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    if key_size == kv_seq_len:
-        causal_mask = torch.tril(
-            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-        ).view(1, 1, key_size, key_size)
-    else:
-        causal_mask = None
+    causal_mask = torch.tril(
+        torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+    ).view(1, 1, key_size, key_size)
+
+    causal_mask = causal_mask[
+        :, :, key.size(1) - query.size(1): key.size(1), :key.size(1)
+    ]
 
     if quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
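
Editor's note on the change (an illustration, not part of the patch): the old code only built a causal mask when key_size == kv_seq_len, i.e. when the query covered the whole sequence, and otherwise left causal_mask as None, which is only correct when exactly one new token arrives per decoding step. The patched code always builds the full key_size x key_size lower-triangular mask and slices out the rows for the last query.size(1) positions, so an input of n new tokens against an existing KV cache still gets a correct mask. Below is a minimal standalone sketch of the slicing behavior; key_len and q_len are hypothetical sizes chosen for illustration and do not appear in the patch.

    import torch

    # Hypothetical sizes: 4 tokens already in the KV cache plus 2 new input tokens.
    key_len, q_len = 6, 2

    # Full lower-triangular mask over all key positions, as in the patched code.
    causal_mask = torch.tril(
        torch.ones((key_len, key_len), dtype=torch.bool)
    ).view(1, 1, key_len, key_len)

    # Keep only the rows belonging to the q_len newest query positions.
    causal_mask = causal_mask[:, :, key_len - q_len:key_len, :key_len]

    print(causal_mask.shape)  # torch.Size([1, 1, 2, 6])
    print(causal_mask[0, 0])
    # tensor([[ True,  True,  True,  True,  True, False],
    #         [ True,  True,  True,  True,  True,  True]])

With q_len == 1 the slice degenerates to a single all-True row, matching ordinary one-token decoding; with q_len > 1 each new token attends to all cached tokens plus the new tokens up to and including itself.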