[LLM] Fix Qwen registered_causal_mask is None (#9513)

* Add registered_causal_mask init based on 2abd8e5777.
Qiyuan Gong 2023-11-23 09:28:04 +08:00 committed by GitHub
parent 11fa5a8a0e
commit 0f0c6bb631


@@ -174,6 +174,9 @@ def qwen_attention_forward(
     context_layer = context_layer.flatten(2, 3).contiguous()
 else:
+    registered_causal_mask = torch.tril(
+        torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+    ).view(1, 1, key.size(1), key.size(1))
     query = query.permute(0, 2, 1, 3)
     if not self.use_cache_quantization:
         key = key.permute(0, 2, 1, 3)
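
For reference, the added lines build a standard lower-triangular causal mask sized to the key sequence length, so the non-flash-attention path no longer hits a None mask. Below is a minimal standalone sketch of the same construction; seq_len is an illustrative stand-in for key.size(1) and is not part of the patch:

import torch

seq_len = 4  # stand-in for key.size(1) in qwen_attention_forward
registered_causal_mask = torch.tril(
    torch.ones((seq_len, seq_len), dtype=torch.bool)
).view(1, 1, seq_len, seq_len)
# Position i may attend to positions 0..i only:
# tensor([[[[ True, False, False, False],
#           [ True,  True, False, False],
#           [ True,  True,  True, False],
#           [ True,  True,  True,  True]]]])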