fix qwen causal mask when quantize_kv_cache=True (#9968)
parent 5aa4b32c1b
commit 2c8a9aaf0d
1 changed file with 8 additions and 11 deletions
@@ -135,10 +135,14 @@ def qwen_attention_forward(
     seq_end = kv_seq_len
     logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
     query = query * logn_tensor.expand_as(query)
-    if key_size == kv_seq_len:
+    if query_size > 1:
         causal_mask = torch.tril(
-            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-        ).view(1, 1, key_size, key_size)
+            torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
+        ).view(1, 1, kv_seq_len, kv_seq_len)
+        causal_mask = causal_mask[
+            :, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
+        ]
     else:
         causal_mask = None
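Note (not part of the commit): the new mask is built over the full kv_seq_len and then sliced down to the last query_size rows, so it stays correct when the KV cache already holds earlier tokens, which is the situation quantize_kv_cache=True runs into. A minimal sketch of that construction, with toy sizes assumed purely for illustration:

import torch

kv_seq_len = 5    # assumed toy value: cached positions plus current tokens
query_size = 2    # assumed toy value: tokens in the current forward pass

# build a full lower-triangular causal mask over all kv positions
causal_mask = torch.tril(
    torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool)
).view(1, 1, kv_seq_len, kv_seq_len)
# keep only the rows that belong to the current query tokens
causal_mask = causal_mask[:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len]
print(causal_mask.shape)    # torch.Size([1, 1, 2, 5])
print(causal_mask[0, 0])
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])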
@@ -222,14 +226,7 @@ def qwen_attention_forward(
     value = new_value_states

     query = query.transpose(1, 2)
-    # skip first init and only works for n tokens input
-    if causal_mask is None and query.size(2) > 1:
-        causal_mask = torch.tril(
-            torch.ones((key.size(2), key.size(2)), dtype=torch.bool, device=query.device)
-        ).view(1, 1, key.size(2), key.size(2))
-        causal_mask = causal_mask[
-            :, :, key.size(2) - query.size(2): key.size(2), :key.size(2)
-        ]
     attn_output, attn_weight = self._attn(
         query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
     )
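Note (not part of the commit): with the mask now prepared unconditionally in the first hunk, the quantized-KV path no longer needs its own copy, so this duplicate construction is dropped and self._attn receives the same (1, 1, query_size, kv_seq_len) boolean mask in both paths. A rough sketch of how such a mask is applied to attention scores; the masked_fill step and the toy sizes are assumptions for illustration, not the actual body of self._attn:

import torch

num_heads, head_dim = 2, 4        # assumed toy sizes
query_size, kv_seq_len = 2, 5     # assumed toy sizes
query = torch.randn(1, num_heads, query_size, head_dim)
key = torch.randn(1, num_heads, kv_seq_len, head_dim)
value = torch.randn(1, num_heads, kv_seq_len, head_dim)

# same construction as in the first hunk: full tril mask, sliced to the query rows
causal_mask = torch.tril(
    torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool)
).view(1, 1, kv_seq_len, kv_seq_len)[:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len]

# scaled dot-product attention with positions outside the mask suppressed
scores = query @ key.transpose(-1, -2) / head_dim ** 0.5
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
attn_output = torch.softmax(scores, dim=-1) @ value
print(attn_output.shape)    # torch.Size([1, 2, 2, 4])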