LLM: Update for Qwen n tokens inputs (#9931)
* update for n tokens inputs * update style * update
This commit is contained in:
parent
5184f400f9
commit
1fc9dfa265
1 changed file with 7 additions and 6 deletions
|
|
@@ -119,12 +119,13 @@ def qwen_attention_forward(
|
||||||
seq_end = kv_seq_len
|
seq_end = kv_seq_len
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||||
query = query * logn_tensor.expand_as(query)
|
query = query * logn_tensor.expand_as(query)
|
||||||
if key_size == kv_seq_len:
|
causal_mask = torch.tril(
|
||||||
causal_mask = torch.tril(
|
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
||||||
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
).view(1, 1, key_size, key_size)
|
||||||
).view(1, 1, key_size, key_size)
|
|
||||||
else:
|
causal_mask = causal_mask[
|
||||||
causal_mask = None
|
:, :, key.size(1) - query.size(1): key.size(1), :key.size(1)
|
||||||
|
]
|
||||||
|
|
||||||
if quantize_kv_cache(self.c_attn, hidden_states):
|
if quantize_kv_cache(self.c_attn, hidden_states):
|
||||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue