Fix Arc StarCoder wrong query_shape when input is long (#10268)

* Fix Arc StarCoder wrong query_shape when input is long

* Update gptbigcode.py
Author: Heyang Sun (committed via GitHub)
Date: 2024-02-28 17:07:08 +08:00
Parent: a4de3095f3
Commit: 7244fd1ba5
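For orientation before reading the diff, here is a minimal, self-contained sketch of the multi-query projection split that the patched forward below operates on. It is not the ipex-llm code itself: the batch/sequence/head sizes are hypothetical and the module is a bare torch.nn.Linear. It only illustrates the GPTBigCode shape convention visible in the hunks, where c_attn produces embed_dim query channels plus 2 * kv_dim shared key/value channels (with kv_dim equal to head_dim in the multi-query case, as in the upstream GPTBigCode model).

import torch

# Hypothetical sizes, chosen only to make the shapes concrete.
batch, seq_len = 2, 16
num_heads, head_dim = 8, 64
embed_dim = num_heads * head_dim
kv_dim = head_dim  # multi-query attention: a single shared key/value head

# Stand-in for the fused q/k/v projection used by GPTBigCode.
c_attn = torch.nn.Linear(embed_dim, embed_dim + 2 * kv_dim)
hidden_states = torch.randn(batch, seq_len, embed_dim)

# Split the fused projection into query and key_value, as in the diff's context lines.
query, key_value = c_attn(hidden_states).split((embed_dim, 2 * kv_dim), dim=2)
key, value = key_value.split((head_dim, head_dim), dim=-1)

print(query.shape)  # torch.Size([2, 16, 512])
print(key.shape)    # torch.Size([2, 16, 64])
print(value.shape)  # torch.Size([2, 16, 64])

With multi-query attention there is only one key/value head, so key and value stay at head_dim width while query keeps all num_heads * head_dim channels; that query tensor's shape is what the commit title refers to.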

gptbigcode.py

@@ -42,13 +42,7 @@ def gptbigcode_attention_forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        **kwargs):
-    if "padding_mask" in kwargs:
-        logger.warning_once(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37." +
-            "Please make sure use `attention_mask` instead.`"
-        )
+        output_attentions: Optional[bool] = False):

     if encoder_hidden_states is not None:
         if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -60,6 +54,7 @@ def gptbigcode_attention_forward(
                 "Please make sure to instantiate class with " +
                 "`GPTBigCodeAttention(..., is_cross_attention=True)`."
             )
         query = self.q_attn(hidden_states)
         key_value = self.c_attn(encoder_hidden_states)
         attention_mask = encoder_attention_mask
@@ -67,10 +62,6 @@ def gptbigcode_attention_forward(
         query, key_value = self.c_attn(hidden_states).split(
             (self.embed_dim, 2 * self.kv_dim), dim=2)
     else:
-        # Note: We split as (self.num_heads, 3, self.head_dim)
-        # instead of (3, self.num_heads, self.head_dim),
-        # i.e., the memory layout is not the same as GPT2.
-        # This makes the concatenation with past_key_value more efficient.
         query, key_value = (
             self.c_attn(hidden_states)
             .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
@@ -89,16 +80,12 @@ def gptbigcode_attention_forward(
                                  device=layer_past.device)
         layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
         key_value = torch.cat((layer_past, key_value), dim=-2)
     present = key_value if use_cache else None

     key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

-    attn_output, attn_weights = self._attn(query,
-                                           key.transpose(-1, -2),
-                                           value,
-                                           attention_mask,
-                                           head_mask)
+    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
+                                           value, attention_mask, head_mask)

     if not self.multi_query:
         attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
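For completeness, a hedged sketch of the cache-append pattern shown in the last hunk's context lines: the past key/value states are concatenated with the current step along the sequence dimension before being split into key and value. The tensor sizes and the zero-initialised cache below are illustrative assumptions, not the ipex-llm cache logic.

import torch

head_dim = 64
batch, past_len, new_len = 2, 10, 1  # hypothetical sizes

# Cached key/value from previous steps and the projection for the new token(s).
layer_past = torch.zeros(batch, past_len, 2 * head_dim)
key_value = torch.randn(batch, new_len, 2 * head_dim)

# Append along the sequence dimension, as in the hunk above.
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value  # returned as the new cache when use_cache is True

key, value = key_value.split((head_dim, head_dim), dim=-1)
print(key.shape, value.shape)  # torch.Size([2, 11, 64]) torch.Size([2, 11, 64])

The hunk's fill_zeros concatenation additionally pads the cached tensor along its last dimension before this append; that ipex-llm-specific step is left out of the sketch.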