simplify qwen attention (#9747)

This commit is contained in:
Yishuo Wang 2023-12-21 17:53:29 +08:00 committed by GitHub
parent 984697afe2
commit 426660b88e

View file

@ -74,6 +74,9 @@ def qwen_attention_forward(
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
): ):
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
"flash attn and kv_cache quantization are not supported")
mixed_x_layer = self.c_attn(hidden_states) mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2) query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@ -119,12 +122,10 @@ def qwen_attention_forward(
bsz, _, n_heads, head_dim = key.size() bsz, _, n_heads, head_dim = key.size()
if layer_past is not None: if layer_past is not None:
kv_seq_len += layer_past[0].shape[1] cache_k, cache_v = layer_past[0], layer_past[1]
# past_key, past_value = layer_past[0], layer_past[1] cache_k = cache_k.transpose(1, 2)
# key = torch.cat((past_key, key), dim=1) cache_v = cache_v.transpose(1, 2)
# value = torch.cat((past_value, value), dim=1) kv_seq_len += cache_k.shape[2]
cache_k = layer_past[0].transpose(1, 2)
cache_v = layer_past[1].transpose(1, 2)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new # allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz, new_cache_k, new_cache_v = extend_kv_cache(bsz,
@ -141,8 +142,8 @@ def qwen_attention_forward(
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states = append_kv_cache(cache_k, cache_v,
key.transpose(1, 2), value.transpose(1, 2)) key.transpose(1, 2), value.transpose(1, 2))
key = key_states.transpose(1, 2) key = key_states
value = value_states.transpose(1, 2) value = value_states
elif use_cache: elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz, new_key_states, new_value_states = init_kv_cache(bsz,
@ -154,80 +155,38 @@ def qwen_attention_forward(
device=hidden_states.device) device=hidden_states.device)
new_key_states[:] = key.transpose(1, 2) new_key_states[:] = key.transpose(1, 2)
new_value_states[:] = value.transpose(1, 2) new_value_states[:] = value.transpose(1, 2)
key = new_key_states.transpose(1, 2) key = new_key_states
value = new_value_states.transpose(1, 2) value = new_value_states
if use_cache: query_size, key_size = query.size(1), key.size(2)
present = (key, value)
else:
present = None
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
if key_size > self.seq_length and self.use_logn_attn and not self.training: if key_size > self.seq_length and self.use_logn_attn and not self.training:
if self.use_cache_quantization: seq_start = key_size - query_size
seq_start = key[0].size(2) - query.size(1) seq_end = key_size
seq_end = key[0].size(2)
else:
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query) query = query * logn_tensor.expand_as(query)
if query_size == key_size:
if ( causal_mask = torch.tril(
self.use_flash_attn torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
and flash_attn_unpadded_func is not None ).view(1, 1, key_size, key_size)
and not self.is_fp32
and query.is_cuda
):
q, k, v = query, key, value
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
else: else:
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) causal_mask = None
if query.size(1) == key_size: query = query.transpose(1, 2)
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
).view(1, 1, key_size, key_size)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if (
causal_mask is None
and self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
and not query.is_cuda
):
invalidOperationError(False,
None,
None,
Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
attn_output, attn_weight = self._attn( attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask query, key, value, causal_mask, attention_mask, head_mask
) )
context_layer = self._merge_heads( context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim attn_output, self.num_heads, self.head_dim
) )
attn_output = self.c_proj(context_layer) attn_output = self.c_proj(context_layer)
outputs = (attn_output, present) if use_cache:
outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
else:
outputs = (attn_output, None)
if output_attentions: if output_attentions:
if ( outputs += (attn_weight,)
self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
):
invalidInputError(False,
f"Cannot output attentions while using flash-attn")
elif not self.use_cache_quantization and SUPPORT_TORCH2:
invalidInputError(False,
f"Cannot output attentions while using scaled_dot_product_attention")
else:
outputs += (attn_weight,)
return outputs return outputs