From 7244fd1ba5cc1ae7fb29df2f09117fc4bb6f477a Mon Sep 17 00:00:00 2001
From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Date: Wed, 28 Feb 2024 17:07:08 +0800
Subject: [PATCH] Fix Arc StarCoder wrong query_shape when input is long
 (#10268)

* Fix Arc StarCoder wrong query_shape when input is long

* Update gptbigcode.py
---
 .../llm/transformers/models/gptbigcode.py | 43 +++++++------------
 1 file changed, 15 insertions(+), 28 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
index 6f9895b1..8a38f22e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
@@ -42,13 +42,7 @@ def gptbigcode_attention_forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        **kwargs):
-    if "padding_mask" in kwargs:
-        logger.warning_once(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37." +
-            "Please make sure use `attention_mask` instead.`"
-        )
+        output_attentions: Optional[bool] = False):
 
     if encoder_hidden_states is not None:
         if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -60,6 +54,7 @@ def gptbigcode_attention_forward(
                 "Please make sure to instantiate class with " +
                 "`GPTBigCodeAttention(..., is_cross_attention=True)`."
             )
+
         query = self.q_attn(hidden_states)
         key_value = self.c_attn(encoder_hidden_states)
         attention_mask = encoder_attention_mask
@@ -67,10 +62,6 @@ def gptbigcode_attention_forward(
         query, key_value = self.c_attn(hidden_states).split(
             (self.embed_dim, 2 * self.kv_dim), dim=2)
     else:
-        # Note: We split as (self.num_heads, 3, self.head_dim)
-        # instead of (3, self.num_heads, self.head_dim),
-        # i.e., the memory layout is not the same as GPT2.
-        # This makes the concatenation with past_key_value more efficient.
         query, key_value = (
             self.c_attn(hidden_states)
             .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
@@ -78,27 +69,23 @@ def gptbigcode_attention_forward(
             .split((self.head_dim, 2 * self.head_dim), dim=3)
         )
 
-        if layer_past is not None:
-            if layer_past.shape[-2] == key_value.shape[-2]:
-                key_value = torch.cat((layer_past, key_value), dim=-2)
-            else:
-                fill_zeros = torch.zeros(layer_past.shape[0],
-                                         layer_past.shape[1],
-                                         key_value.shape[2] - layer_past.shape[2],
-                                         dtype=layer_past.dtype,
-                                         device=layer_past.device)
-                layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
-                key_value = torch.cat((layer_past, key_value), dim=-2)
-
+    if layer_past is not None:
+        if layer_past.shape[-2] == key_value.shape[-2]:
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        else:
+            fill_zeros = torch.zeros(layer_past.shape[0],
+                                     layer_past.shape[1],
+                                     key_value.shape[2] - layer_past.shape[2],
+                                     dtype=layer_past.dtype,
+                                     device=layer_past.device)
+            layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
+            key_value = torch.cat((layer_past, key_value), dim=-2)
     present = key_value if use_cache else None
 
     key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
 
-    attn_output, attn_weights = self._attn(query,
-                                           key.transpose(-1, -2),
-                                           value,
-                                           attention_mask,
-                                           head_mask)
+    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
+                                           value, attention_mask, head_mask)
 
     if not self.multi_query:
         attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
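
Reviewer note: for reference, below is a minimal, self-contained sketch (hypothetical helper name, not code from this repository) of the key/value-cache concatenation that the last hunk dedents out of the non-multi-query branch so it also runs on StarCoder's multi-query path: the cached tensor is zero-padded along its last dimension if it is narrower than the new slice, then the new positions are appended along the sequence dimension.

import torch

def append_kv_cache(layer_past, key_value):
    # layer_past: [batch, past_len, width] cached keys/values, or None on the
    # first decoding step; key_value: [batch, new_len, width] for this step.
    if layer_past is None:
        return key_value
    if layer_past.shape[-1] != key_value.shape[-1]:
        # Zero-pad the cache's last dimension so both tensors share the same
        # width before concatenating along the sequence dimension.
        fill_zeros = torch.zeros(layer_past.shape[0],
                                 layer_past.shape[1],
                                 key_value.shape[-1] - layer_past.shape[-1],
                                 dtype=layer_past.dtype,
                                 device=layer_past.device)
        layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
    return torch.cat((layer_past, key_value), dim=-2)

# Toy usage: a cache holding 5 past positions plus 1 newly decoded position.
past = torch.randn(2, 5, 128)   # [batch, past_len, 2 * head_dim]
new = torch.randn(2, 1, 128)    # [batch, 1, 2 * head_dim]
print(append_kv_cache(past, new).shape)  # torch.Size([2, 6, 128])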