simplify qwen attention (#9747)
commit 426660b88e
parent 984697afe2
1 changed file with 28 additions and 69 deletions
@@ -74,6 +74,9 @@ def qwen_attention_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
 ):
+    invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+                      "flash attn and kv_cache quantization are not supported")
+
     mixed_x_layer = self.c_attn(hidden_states)
     query, key, value = mixed_x_layer.split(self.split_size, dim=2)

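Note: the guard added above limits this code path to plain attention (no flash-attn, no quantized KV cache) before the fused c_attn projection is split into query/key/value. A minimal sketch of that split, using toy sizes that are assumptions rather than Qwen's real configuration:

import torch
import torch.nn as nn

hidden_size, seq_len = 8, 4                       # toy sizes, not the real model config
c_attn = nn.Linear(hidden_size, 3 * hidden_size)  # fused query/key/value projection
hidden_states = torch.randn(1, seq_len, hidden_size)

mixed_x_layer = c_attn(hidden_states)             # (1, seq_len, 3 * hidden_size)
# split_size equals hidden_size here, so each chunk is one projection
query, key, value = mixed_x_layer.split(hidden_size, dim=2)
print(query.shape, key.shape, value.shape)        # each: (1, seq_len, hidden_size)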
@@ -119,12 +122,10 @@ def qwen_attention_forward(
     bsz, _, n_heads, head_dim = key.size()

     if layer_past is not None:
-        kv_seq_len += layer_past[0].shape[1]
-        # past_key, past_value = layer_past[0], layer_past[1]
-        # key = torch.cat((past_key, key), dim=1)
-        # value = torch.cat((past_value, value), dim=1)
-        cache_k = layer_past[0].transpose(1, 2)
-        cache_v = layer_past[1].transpose(1, 2)
+        cache_k, cache_v = layer_past[0], layer_past[1]
+        cache_k = cache_k.transpose(1, 2)
+        cache_v = cache_v.transpose(1, 2)
+        kv_seq_len += cache_k.shape[2]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
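Note: the rewritten branch keeps the cached key/value in (batch, num_heads, seq_len, head_dim) layout, so the sequence length now comes from dim 2 of the cache. A rough sketch of the append step with made-up shapes; the actual code writes into a buffer pre-allocated by extend_kv_cache/append_kv_cache rather than calling torch.cat:

import torch

bsz, n_heads, head_dim = 1, 2, 4
cache_k = torch.zeros(bsz, n_heads, 3, head_dim)   # 3 positions already cached
new_key = torch.randn(bsz, 1, n_heads, head_dim)   # model layout: (b, seq, heads, dim)

kv_seq_len = cache_k.shape[2] + new_key.shape[1]   # 3 cached + 1 new = 4
# transpose the new token into cache layout, then append along the seq axis
key_states = torch.cat([cache_k, new_key.transpose(1, 2)], dim=2)
print(kv_seq_len, key_states.shape)                # 4 torch.Size([1, 2, 4, 4])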
@@ -141,8 +142,8 @@ def qwen_attention_forward(
         key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                    key.transpose(1, 2), value.transpose(1, 2))
-        key = key_states.transpose(1, 2)
-        value = value_states.transpose(1, 2)
+        key = key_states
+        value = value_states
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(bsz,
@@ -154,80 +155,38 @@ def qwen_attention_forward(
                                                          device=hidden_states.device)
         new_key_states[:] = key.transpose(1, 2)
         new_value_states[:] = value.transpose(1, 2)
-        key = new_key_states.transpose(1, 2)
-        value = new_value_states.transpose(1, 2)
+        key = new_key_states
+        value = new_value_states

-    if use_cache:
-        present = (key, value)
-    else:
-        present = None
-
-    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    query_size, key_size = query.size(1), key.size(2)
     if key_size > self.seq_length and self.use_logn_attn and not self.training:
-        if self.use_cache_quantization:
-            seq_start = key[0].size(2) - query.size(1)
-            seq_end = key[0].size(2)
-        else:
-            seq_start = key.size(1) - query.size(1)
-            seq_end = key.size(1)
+        seq_start = key_size - query_size
+        seq_end = key_size
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)

-    if (
-        self.use_flash_attn
-        and flash_attn_unpadded_func is not None
-        and not self.is_fp32
-        and query.is_cuda
-    ):
-        q, k, v = query, key, value
-        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+    if query_size == key_size:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
     else:
-        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(
-                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-            ).view(1, 1, key_size, key_size)
-        else:
-            causal_mask = None
-        query = query.permute(0, 2, 1, 3)
-        if not self.use_cache_quantization:
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
-        if (
-            causal_mask is None
-            and self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and not query.is_cuda
-        ):
-            invalidOperationError(False,
-                                  None,
-                                  None,
-                                  Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
-
+        causal_mask = None
+    query = query.transpose(1, 2)

-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
+    attn_output, attn_weight = self._attn(
+        query, key, value, causal_mask, attention_mask, head_mask
+    )
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )

     attn_output = self.c_proj(context_layer)

-    outputs = (attn_output, present)
+    if use_cache:
+        outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
+    else:
+        outputs = (attn_output, None)
     if output_attentions:
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            invalidInputError(False,
-                              f"Cannot output attentions while using flash-attn")
-        elif not self.use_cache_quantization and SUPPORT_TORCH2:
-            invalidInputError(False,
-                              f"Cannot output attentions while using scaled_dot_product_attention")
-        else:
-            outputs += (attn_weight,)
+        outputs += (attn_weight,)

     return outputs
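Note: with the flash-attn branch gone, the causal mask is only built for prefill, when the query spans the same positions as the key; during single-token decode it stays None because the new token may attend to the entire cache. A small sketch of that logic with assumed sizes, not the actual helper code:

import torch

def make_causal_mask(query_size, key_size, device="cpu"):
    if query_size == key_size:   # prefill: hide future positions
        return torch.tril(
            torch.ones((key_size, key_size), dtype=torch.bool, device=device)
        ).view(1, 1, key_size, key_size)
    return None                  # decode: one new token attends to everything cached

print(make_causal_mask(4, 4).shape)   # torch.Size([1, 1, 4, 4])
print(make_causal_mask(1, 5))         # None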