LLM: Update for Qwen n tokens inputs (#9931)

* update for n tokens inputs

* update style

* update
Wang, Jian4 2024-01-18 15:56:29 +08:00 committed by GitHub
parent 5184f400f9
commit 1fc9dfa265


@@ -119,12 +119,13 @@ def qwen_attention_forward(
         seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    causal_mask = torch.tril(
-        torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-    ).view(1, 1, key_size, key_size)
-    causal_mask = None
+    if key_size == kv_seq_len:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
+    else:
+        causal_mask = causal_mask[
+            :, :, key.size(1) - query.size(1): key.size(1), :key.size(1)
+        ]
     if quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
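The patch keeps the square lower-triangular mask when the whole sequence is in this forward pass (key_size == kv_seq_len), and otherwise, when n new tokens arrive on top of an existing KV cache, slices out just the mask rows for those n query positions rather than dropping the mask. A minimal sketch of that slicing, standalone rather than the patched function itself, with made-up sizes total_kv and n_new standing in for key.size(1) and query.size(1):

import torch

total_kv = 8   # stands in for key.size(1): cached tokens + new tokens
n_new = 3      # stands in for query.size(1): tokens in this forward pass

# Full lower-triangular causal mask over all kv positions.
full_mask = torch.tril(
    torch.ones((total_kv, total_kv), dtype=torch.bool)
).view(1, 1, total_kv, total_kv)

# Same slicing as the patch: keep only the rows for the n new query
# tokens, with columns covering every kv position.
sliced = full_mask[:, :, total_kv - n_new: total_kv, :total_kv]

print(sliced.shape)  # torch.Size([1, 1, 3, 8])
# Row i is True through column (total_kv - n_new + i): each new token
# sees the whole cache plus the new tokens up to and including itself.
print(sliced[0, 0])

With n_new == 1 this reduces to a single all-True row, matching ordinary token-by-token decoding; the new else branch matters exactly when a step pushes n > 1 tokens through with a cache already in place, which is what the commit title's "n tokens inputs" refers to.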