From 0f0c6bb6310cdafead2d6e054a4087715c5c9fd6 Mon Sep 17 00:00:00 2001
From: Qiyuan Gong
Date: Thu, 23 Nov 2023 09:28:04 +0800
Subject: [PATCH] [LLM] Fix Qwen registered_causal_mask is None (#9513)

* Add registered_causal_mask init based on
  https://huggingface.co/Qwen/Qwen-7B-Chat/commit/2abd8e5777bb4ce9c8ab4be7dbbd0fe4526db78d.

---
 python/llm/src/bigdl/llm/transformers/models/qwen.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 3a50f9c6..0735b270 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -174,6 +174,9 @@ def qwen_attention_forward(
         context_layer = context_layer.flatten(2, 3).contiguous()
 
     else:
+        registered_causal_mask = torch.tril(
+            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+        ).view(1, 1, key.size(1), key.size(1))
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
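
For context, here is a minimal standalone sketch (not part of the patch) of what the added initialization produces. The tensor shapes are made up for illustration; in the actual code, `key` is the attention key tensor and `key.size(1)` is the sequence length:

```python
import torch

# Hypothetical sizes for illustration only.
seq_len = 4
key = torch.randn(1, seq_len, 32, 128)  # (batch, seq, heads, head_dim), made-up dims

# Same construction as the patch: a lower-triangular boolean mask of shape
# (1, 1, seq, seq) that lets each query position attend only to itself and
# earlier key positions, which is what the causal attention path expects
# when registered_causal_mask would otherwise be None.
registered_causal_mask = torch.tril(
    torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
).view(1, 1, key.size(1), key.size(1))

print(registered_causal_mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```

Building the mask from `key.size(1)` at forward time (rather than registering a fixed-size buffer at init) keeps it sized to the current sequence length, at the cost of reconstructing it on each call through this branch.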