From bcaeb05272b2f4e2ce0c4ad5178adc8832198fbc Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Fri, 19 Jan 2024 16:54:59 +0800
Subject: [PATCH] Update optimize qwen (#9943)

* update for n tokens input

* fix dtype

* update
---
 .../src/bigdl/llm/transformers/models/qwen.py | 24 ++++++++++++-------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 0517e9d2..4de49ae3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -119,13 +119,12 @@ def qwen_attention_forward(
         seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    causal_mask = torch.tril(
-        torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-    ).view(1, 1, key_size, key_size)
-
-    causal_mask = causal_mask[
-        :, :, key.size(1) - query.size(1): key.size(1), :key.size(1)
-    ]
+    if key_size == kv_seq_len:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
+    else:
+        causal_mask = None
 
     if quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
@@ -207,9 +206,16 @@ def qwen_attention_forward(
         value = new_value_states
 
     query = query.transpose(1, 2)
-
+    # skip first init and only works for n tokens input
+    if causal_mask is None and query.size(2) > 1:
+        causal_mask = torch.tril(
+            torch.ones((key.size(2), key.size(2)), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key.size(2), key.size(2))
+        causal_mask = causal_mask[
+            :, :, key.size(2) - query.size(2): key.size(2), :key.size(2)
+        ]
     attn_output, attn_weight = self._attn(
-        query, key, value, causal_mask, attention_mask, head_mask
+        query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
    )

    context_layer = self._merge_heads(
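
The patch builds the causal mask lazily: a full lower-triangular mask is only created when the key length equals the whole sequence length (the first prefill), and when a multi-token query arrives later with an existing KV cache, the mask is rebuilt over the full key length and sliced down to the rows belonging to the new query tokens; the added `.to(key.dtype)` cast keeps the query dtype aligned with the key before `_attn`. The snippet below is a minimal standalone sketch of that mask construction only, not BigDL code: the function name build_causal_mask and the (batch, heads, seq, head_dim) tensor layout are assumptions made for illustration.

# Illustrative sketch only -- not the BigDL implementation. It mirrors the
# mask logic this patch adds for n-token inputs, with assumed tensor shapes.
import torch

def build_causal_mask(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    """Build the boolean causal mask for the last query.size(2) positions.

    Assumed (hypothetical) layout: query and key are (batch, heads, seq, head_dim),
    so query.size(2) is the number of new tokens and key.size(2) is the total
    cached-plus-new key length.
    """
    q_len, k_len = query.size(2), key.size(2)
    # Lower-triangular mask over the full key length ...
    causal_mask = torch.tril(
        torch.ones((k_len, k_len), dtype=torch.bool, device=query.device)
    ).view(1, 1, k_len, k_len)
    # ... then keep only the rows that correspond to the new query tokens.
    return causal_mask[:, :, k_len - q_len:k_len, :k_len]

if __name__ == "__main__":
    # 3 cached tokens + 2 new tokens: the query covers the last 2 of 5 key positions.
    q = torch.zeros(1, 2, 2, 8)
    k = torch.zeros(1, 2, 5, 8)
    mask = build_causal_mask(q, k)
    print(mask.shape)   # torch.Size([1, 1, 2, 5])
    print(mask[0, 0])   # row i may attend to key positions 0..(3 + i)

Note that the patched code guards this with `query.size(2) > 1`, so for single-token decode steps causal_mask stays None and no (key_len, key_len) boolean tensor is allocated per generated token.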