diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 3a50f9c6..0735b270 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -174,6 +174,9 @@ def qwen_attention_forward(
         context_layer = context_layer.flatten(2, 3).contiguous()
     else:
+        registered_causal_mask = torch.tril(
+            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+        ).view(1, 1, key.size(1), key.size(1))
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
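For reference, the added lines construct the causal mask inline in the non-flash-attention branch. A minimal standalone sketch of what `registered_causal_mask` ends up containing, using a toy sequence length in place of `key.size(1)` (the variable names below are illustrative, not taken from the patched file):

```python
import torch

# Illustrative only: reproduce the mask construction from the diff with a
# small fixed sequence length standing in for key.size(1).
seq_len = 4

# Lower-triangular boolean matrix: query position i may attend to key
# positions j <= i. The view makes it broadcastable over
# (batch, heads, q_len, k_len).
registered_causal_mask = torch.tril(
    torch.ones((seq_len, seq_len), dtype=torch.bool)
).view(1, 1, seq_len, seq_len)

print(registered_causal_mask[0, 0])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```

Downstream, a mask like this is typically sliced to the current query/key lengths and used to push masked attention scores to a large negative value before the softmax.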