Fix Starcoder issue on CPU on transformers 4.36+ (#11190)
* fix starcoder for sdpa * update * style
This commit is contained in:
parent
f93664147c
commit
bb83bc23fd
2 changed files with 112 additions and 3 deletions
|
|
@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
|
|||
|
||||
def convert_forward(m, target_m, new_forward):
|
||||
for _, sub_m in m.named_children():
|
||||
if isinstance(sub_m, target_m):
|
||||
if sub_m.__class__ == target_m:
|
||||
bound_method = new_forward.__get__(sub_m, sub_m.__class__)
|
||||
setattr(sub_m, "forward", bound_method)
|
||||
convert_forward(sub_m, target_m, new_forward)
|
||||
|
|
@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m, replace_embed):
|
|||
|
||||
def replace_func(m, target_m, func_name, new_func):
|
||||
for _, sub_m in m.named_children():
|
||||
if isinstance(sub_m, target_m):
|
||||
if sub_m.__class__ == target_m:
|
||||
bound_method = new_func.__get__(sub_m, sub_m.__class__)
|
||||
setattr(sub_m, func_name, bound_method)
|
||||
replace_func(sub_m, target_m, func_name, new_func)
|
||||
|
|
@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
|
||||
from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
|
||||
from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
|
||||
gptbigcode_sdpa_attention_forward
|
||||
convert_forward(model,
|
||||
module.GPTBigCodeAttention,
|
||||
gptbigcode_attention_forward)
|
||||
|
|
@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module.GPTBigCodeAttention,
|
||||
"_attn",
|
||||
_attn)
|
||||
try:
|
||||
# for transformers 4.36+
|
||||
convert_forward(model,
|
||||
module.GPTBigCodeSdpaAttention,
|
||||
gptbigcode_sdpa_attention_forward)
|
||||
sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
|
||||
replace_func(model,
|
||||
module.GPTBigCodeSdpaAttention,
|
||||
"_attn",
|
||||
sdpa_attn)
|
||||
except AttributeError:
|
||||
pass
|
||||
elif model.config.model_type == "starcoder2":
|
||||
# starcoder2
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
|
|
@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
|
|||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def gptbigcode_sdpa_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
||||
]:
|
||||
if encoder_hidden_states is not None:
|
||||
if not hasattr(self, "q_attn") or not self.is_cross_attention:
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
invalidInputError(
|
||||
False,
|
||||
"If class is used as cross attention," +
|
||||
"the weights `q_attn` have to be defined. " +
|
||||
"Please make sure to instantiate class with " +
|
||||
"`GPTBigCodeAttention(..., is_cross_attention=True)`."
|
||||
)
|
||||
|
||||
query = self.q_attn(hidden_states)
|
||||
key_value = self.c_attn(encoder_hidden_states)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif self.multi_query:
|
||||
query, key_value = self.c_attn(hidden_states).split(
|
||||
(self.embed_dim, 2 * self.kv_dim), dim=2)
|
||||
else:
|
||||
# Note: We split as (self.num_heads, 3, self.head_dim) instead of
|
||||
# (3, self.num_heads, self.head_dim),
|
||||
# i.e., the memory layout is not the same as GPT2.
|
||||
# This makes the concatenation with past_key_value more efficient.
|
||||
query, key_value = (
|
||||
self.c_attn(hidden_states)
|
||||
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.split((self.head_dim, 2 * self.head_dim), dim=3)
|
||||
)
|
||||
|
||||
if layer_past is not None:
|
||||
if layer_past.shape[-2] == key_value.shape[-2]:
|
||||
key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||
else:
|
||||
fill_zeros = torch.zeros(layer_past.shape[0],
|
||||
layer_past.shape[1],
|
||||
key_value.shape[2] - layer_past.shape[2],
|
||||
dtype=layer_past.dtype,
|
||||
device=layer_past.device)
|
||||
layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
|
||||
key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||
# key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||
present = key_value if use_cache else None
|
||||
|
||||
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
if not output_attentions and head_mask is None:
|
||||
# Difference with the original implementation: there is no need to
|
||||
# transpose the key here,
|
||||
# as SDPA expects seq_length to be at index -2 for the key as well
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
else:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
|
||||
# once this is implemented.
|
||||
logger.warning_once(
|
||||
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
|
||||
"but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` and `head_mask` not None."
|
||||
' Falling back to the manual attention implementation, '
|
||||
'but specifying the manual implementation will be required from '
|
||||
'Transformers version v5.0.0 onwards. '
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` '
|
||||
'when loading the model.'
|
||||
)
|
||||
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2),
|
||||
value, attention_mask, head_mask)
|
||||
|
||||
if not self.multi_query:
|
||||
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
if self.multi_query:
|
||||
# Transpose to return weights in the usual format
|
||||
# (batch_size, num_heads, query_length, key_length)
|
||||
attn_weights = attn_weights.transpose(1, 2)
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
|
|
|||
Loading…
Reference in a new issue