Fix Arc StarCoder wrong query_shape when input is long (#10268)

* Fix Arc StarCoder wrong query_shape when input is long
* Update gptbigcode.py
parent a4de3095f3
commit 7244fd1ba5
1 changed file with 15 additions and 28 deletions
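For orientation, here is a sketch of how the patched forward signature reads once the first hunk below lands: the `**kwargs` catch-all and the `padding_mask` deprecation warning are removed, and the argument list ends at `output_attentions`. The parameters above `encoder_hidden_states` are not visible in the hunks on this page and are assumed from the standard Hugging Face GPTBigCodeAttention forward; this is a reconstruction, not the repository's exact code.

from typing import Optional

import torch


def gptbigcode_attention_forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,       # assumed, not shown in the hunks
        attention_mask: Optional[torch.Tensor] = None,    # assumed, not shown in the hunks
        head_mask: Optional[torch.Tensor] = None,         # assumed, not shown in the hunks
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False):
    ...  # body as in the hunks below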
gptbigcode.py

@@ -42,13 +42,7 @@ def gptbigcode_attention_forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        **kwargs):
-    if "padding_mask" in kwargs:
-        logger.warning_once(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37." +
-            "Please make sure use `attention_mask` instead.`"
-        )
+        output_attentions: Optional[bool] = False):

     if encoder_hidden_states is not None:
         if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -60,6 +54,7 @@ def gptbigcode_attention_forward(
                 "Please make sure to instantiate class with " +
                 "`GPTBigCodeAttention(..., is_cross_attention=True)`."
             )
+
         query = self.q_attn(hidden_states)
         key_value = self.c_attn(encoder_hidden_states)
         attention_mask = encoder_attention_mask
@@ -67,10 +62,6 @@ def gptbigcode_attention_forward(
         query, key_value = self.c_attn(hidden_states).split(
             (self.embed_dim, 2 * self.kv_dim), dim=2)
     else:
-        # Note: We split as (self.num_heads, 3, self.head_dim)
-        # instead of (3, self.num_heads, self.head_dim),
-        # i.e., the memory layout is not the same as GPT2.
-        # This makes the concatenation with past_key_value more efficient.
         query, key_value = (
             self.c_attn(hidden_states)
             .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
@@ -78,27 +69,23 @@ def gptbigcode_attention_forward(
             .split((self.head_dim, 2 * self.head_dim), dim=3)
         )

     if layer_past is not None:
         if layer_past.shape[-2] == key_value.shape[-2]:
             key_value = torch.cat((layer_past, key_value), dim=-2)
         else:
             fill_zeros = torch.zeros(layer_past.shape[0],
                                      layer_past.shape[1],
                                      key_value.shape[2] - layer_past.shape[2],
                                      dtype=layer_past.dtype,
                                      device=layer_past.device)
             layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
             key_value = torch.cat((layer_past, key_value), dim=-2)

     present = key_value if use_cache else None

     key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-
-    attn_output, attn_weights = self._attn(query,
-                                           key.transpose(-1, -2),
-                                           value,
-                                           attention_mask,
-                                           head_mask)
+    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
+                                           value, attention_mask, head_mask)

     if not self.multi_query:
         attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
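The context lines around `layer_past` in the last hunk show how the cached key/value states are grown when their last dimension no longer matches the freshly computed `key_value`: the cache is zero-padded along its last dimension and then concatenated along the sequence dimension. Below is a minimal standalone sketch of that padding-and-concat step. Only `layer_past`, `key_value`, and `fill_zeros` come from the diff; the shape values and the other names are illustrative assumptions, not values from the repository.

import torch

# Illustrative shapes only; none of these numbers come from the repository.
batch, past_len, new_len = 1, 8, 4
past_width, new_width = 96, 128          # cached width smaller than the new states' width

layer_past = torch.randn(batch, past_len, past_width)   # previously cached key/value states
key_value = torch.randn(batch, new_len, new_width)      # key/value states for the new tokens

# Zero-pad the cache along the last dim until it matches the new states,
# mirroring the else-branch in the hunk above.
fill_zeros = torch.zeros(layer_past.shape[0],
                         layer_past.shape[1],
                         key_value.shape[2] - layer_past.shape[2],
                         dtype=layer_past.dtype,
                         device=layer_past.device)
layer_past = torch.cat([layer_past, fill_zeros], dim=-1)

# The tensors now agree on every dim except the sequence dim (-2), so they can
# be stacked into a single cache covering past_len + new_len positions.
key_value = torch.cat((layer_past, key_value), dim=-2)
print(key_value.shape)  # torch.Size([1, 12, 128])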