Update optimize qwen (#9943)
* update for n tokens input
* fix dtype
* update
This commit is contained in:
parent db8e90796a
commit bcaeb05272

1 changed file with 15 additions and 9 deletions
@@ -119,13 +119,12 @@ def qwen_attention_forward(
         seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    causal_mask = torch.tril(
-        torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-    ).view(1, 1, key_size, key_size)
-
-    causal_mask = causal_mask[
-        :, :, key.size(1) - query.size(1): key.size(1), :key.size(1)
-    ]
+    if key_size == kv_seq_len:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
+    else:
+        causal_mask = None

     if quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
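With this change the full causal mask is only materialized when key_size == kv_seq_len; otherwise it is left as None and dealt with after the KV-cache update. Assuming key_size here counts only the newly projected tokens while kv_seq_len includes the cached prefix (which is how the condition reads in this hunk), the mask is built once on the cache-free first pass and skipped for cached decoding steps. A minimal standalone sketch of that gating, where build_causal_mask and its arguments are illustrative rather than the repo's API:

import torch

def build_causal_mask(key_size, kv_seq_len, device="cpu"):
    # First pass, no KV cache yet: key_size == kv_seq_len, so materialize
    # the full lower-triangular (causal) mask once.
    if key_size == kv_seq_len:
        return torch.tril(
            torch.ones((key_size, key_size), dtype=torch.bool, device=device)
        ).view(1, 1, key_size, key_size)
    # Cached steps: leave the mask as None; the second hunk rebuilds it later,
    # and only when more than one new token arrives.
    return None

print(build_causal_mask(8, 8).shape)   # torch.Size([1, 1, 8, 8]) -- prefill
print(build_causal_mask(4, 12))        # None -- 4 new tokens on a cached prefix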
@@ -207,9 +206,16 @@ def qwen_attention_forward(
         value = new_value_states

     query = query.transpose(1, 2)

+    # skip first init and only works for n tokens input
+    if causal_mask is None and query.size(2) > 1:
+        causal_mask = torch.tril(
+            torch.ones((key.size(2), key.size(2)), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key.size(2), key.size(2))
+        causal_mask = causal_mask[
+            :, :, key.size(2) - query.size(2): key.size(2), :key.size(2)
+        ]
     attn_output, attn_weight = self._attn(
-        query, key, value, causal_mask, attention_mask, head_mask
+        query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
     )

     context_layer = self._merge_heads(
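The second hunk defers mask construction until after the KV-cache handling: when the mask is still None (a cache exists) and the step carries more than one new token, it keeps the last query.size(2) rows of a full key.size(2)-square lower-triangular mask, so each new token attends to the whole cached prefix plus the earlier tokens of its own chunk. The query.to(key.dtype) change in the _attn call is the "fix dtype" item from the commit message, presumably guarding against a query/key dtype mismatch after the preceding scaling and cache steps. A self-contained sketch of the slicing, with purely illustrative sizes:

import torch

def deferred_causal_mask(q_len, k_len, device="cpu"):
    # Full k_len x k_len lower-triangular mask ...
    full = torch.tril(
        torch.ones((k_len, k_len), dtype=torch.bool, device=device)
    ).view(1, 1, k_len, k_len)
    # ... keep only the rows of the q_len newest queries: each may attend to
    # the cached prefix and to earlier tokens within its own chunk.
    return full[:, :, k_len - q_len: k_len, :k_len]

mask = deferred_causal_mask(q_len=4, k_len=12)
print(mask.shape)     # torch.Size([1, 1, 4, 12])
print(mask[0, 0, 0])  # first of the 4 new tokens: True for the 8 cached tokens + itself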