[LLM] Fix Qwen causal_mask and attention_mask size mismatching (#9600)

* Fix #9582, caused by Qwen's modification of modeling_qwen.py in 7f62181c94 (d2h-049182)
Authored by Ziteng Zhang on 2023-12-05 15:15:54 +08:00, committed by GitHub
parent b721138132
commit 65934c9f4f

@@ -151,7 +151,8 @@ def qwen_attention_forward(
     else:
         present = None
 
-    if self.use_logn_attn and not self.training:
+    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    if key_size > self.seq_length and self.use_logn_attn and not self.training:
         if self.use_cache_quantization:
             seq_start = key[0].size(2) - query.size(1)
             seq_end = key[0].size(2)
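
A note on the new `key_size` expression: it has to read the cached key length from two different KV-cache layouts. Below is a minimal sketch of that lookup, under the assumption that a quantized cache stores the key as a tuple whose first element has shape (batch, num_heads, kv_seq_len, head_dim), while the unquantized key tensor is (batch, kv_seq_len, num_heads, head_dim); the helper name and the concrete shapes are illustrative, not part of the patch.

```python
import torch

def key_size_from_cache(key, use_cache_quantization: bool) -> int:
    # Hypothetical helper mirroring the key_size expression in the hunk above.
    # Assumed layouts:
    #   quantized cache:   key is a tuple, key[0] is
    #                      (batch, num_heads, kv_seq_len, head_dim) -> size(2)
    #   unquantized cache: key is
    #                      (batch, kv_seq_len, num_heads, head_dim) -> size(1)
    if use_cache_quantization:
        return key[0].size(2)
    return key.size(1)

# Illustrative shapes only: kv_seq_len = 40 in both layouts.
unquantized_key = torch.zeros(1, 40, 32, 128)
quantized_key = (torch.zeros(1, 32, 40, 128, dtype=torch.uint8),)
assert key_size_from_cache(unquantized_key, use_cache_quantization=False) == 40
assert key_size_from_cache(quantized_key, use_cache_quantization=True) == 40
```
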
@@ -174,9 +175,10 @@ def qwen_attention_forward(
         context_layer = context_layer.flatten(2, 3).contiguous()
 
     else:
-        registered_causal_mask = torch.tril(
-            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
-        ).view(1, 1, key.size(1), key.size(1))
+        if query.size(1) == key_size:
+            registered_causal_mask = torch.tril(
+                torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
+            ).view(1, 1, key_size, key_size)
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
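
The guard added above only rebuilds `registered_causal_mask` during prefill, when the query spans the same positions as the cached keys; during incremental decoding a fresh key_size x key_size lower-triangular mask would no longer line up with the attention_mask prepared for the current step, which is the kind of causal_mask / attention_mask size mismatch this commit addresses. Here is a minimal sketch of the same guard, with illustrative shapes and a hypothetical `build_causal_mask` helper standing in for the patched code:

```python
import torch

def build_causal_mask(query: torch.Tensor, key_size: int):
    # Mirror of the guarded construction above: only build the
    # (1, 1, key_size, key_size) lower-triangular mask when the query
    # covers the full key length (prefill); otherwise return None so the
    # caller falls back to the attention_mask for the current step.
    if query.size(1) == key_size:
        return torch.tril(
            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
        ).view(1, 1, key_size, key_size)
    return None

# Prefill: query and cached keys both cover 8 positions.
prefill_q = torch.zeros(1, 8, 32, 128)
assert build_causal_mask(prefill_q, key_size=8).shape == (1, 1, 8, 8)

# Decode step: one new query token against 9 cached keys -> no causal
# mask is built here; masking is left to the step's attention_mask.
decode_q = torch.zeros(1, 1, 32, 128)
assert build_causal_mask(decode_q, key_size=9) is None
```
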