Fix Starcoder issue on CPU on transformers 4.36+ (#11190)

* fix starcoder for sdpa

* update

* style

Jiao Wang 2024-06-04 10:05:40 -07:00 committed by GitHub
parent f93664147c
commit bb83bc23fd
2 changed files with 112 additions and 3 deletions

python/llm/src/ipex_llm/transformers/convert.py

@@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
 
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_forward.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)
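
Note: `convert_forward` (and `replace_func` in the next hunk) patch individual module instances through Python's descriptor protocol: calling `__get__` on a plain function with an instance yields a bound method, which is then assigned to the instance attribute. A standalone sketch with made-up names, not part of the patch:

# Hypothetical class and function, only to illustrate the binding pattern
class Attention:
    def forward(self, x):
        return "original"

def new_forward(self, x):
    return "patched"

m = Attention()
bound_method = new_forward.__get__(m, m.__class__)   # bind the free function to this instance
setattr(m, "forward", bound_method)
print(m.forward(None))                               # -> "patched"; other Attention instances are untouched
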
@@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m, replace_embed):
 
 def replace_func(m, target_m, func_name, new_func):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_func.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, func_name, bound_method)
         replace_func(sub_m, target_m, func_name, new_func)
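
Note: the switch from `isinstance` to an exact class comparison is what makes the SDPA fix possible. Since transformers 4.36 the SDPA attention classes subclass the eager ones (`GPTBigCodeSdpaAttention` derives from `GPTBigCodeAttention`), so an `isinstance` check would also match SDPA modules and patch them with the eager forward. A minimal sketch with hypothetical classes:

# Hypothetical stand-ins for the eager and SDPA attention classes
class EagerAttention:
    pass

class SdpaAttention(EagerAttention):
    pass

sub_m = SdpaAttention()
print(isinstance(sub_m, EagerAttention))   # True  -> the SDPA module would be (wrongly) matched
print(sub_m.__class__ == EagerAttention)   # False -> only exact matches are patched
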
@@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
-        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
+        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
+            gptbigcode_sdpa_attention_forward
         convert_forward(model,
                         module.GPTBigCodeAttention,
                         gptbigcode_attention_forward)
@@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
                      module.GPTBigCodeAttention,
                      "_attn",
                      _attn)
+        try:
+            # for transformers 4.36+
+            convert_forward(model,
+                            module.GPTBigCodeSdpaAttention,
+                            gptbigcode_sdpa_attention_forward)
+            sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
+            replace_func(model,
+                         module.GPTBigCodeSdpaAttention,
+                         "_attn",
+                         sdpa_attn)
+        except AttributeError:
+            pass
     elif model.config.model_type == "starcoder2":
         # starcoder2
         modeling_module_name = model.__class__.__module__
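
The `try`/`except AttributeError` keeps the branch backward compatible: `module.GPTBigCodeSdpaAttention` only exists in transformers 4.36+, so older versions simply skip the extra conversion. For context, a minimal CPU usage sketch that exercises this path; loading with `load_in_4bit=True` runs ipex_llm's model optimizations, which include `_optimize_post` above. The entry point is the usual ipex_llm API, but treat the model id, prompt, and flags as illustrative:

# Illustrative only: any GPTBigCode / StarCoder checkpoint works the same way
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "bigcode/starcoderbase-1b"
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,      # enables the ipex_llm optimizations
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("def fib(n):", return_tensors="pt")
output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
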

python/llm/src/ipex_llm/transformers/models/gptbigcode.py

@@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
         outputs += (attn_weights,)
 
     return outputs
+
+
+def gptbigcode_sdpa_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    layer_past: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
+) -> Union[
+    Tuple[torch.Tensor, Optional[torch.Tensor]],
+    Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+]:
+    if encoder_hidden_states is not None:
+        if not hasattr(self, "q_attn") or not self.is_cross_attention:
+            from ipex_llm.utils.common import invalidInputError
+            invalidInputError(
+                False,
+                "If class is used as cross attention," +
+                "the weights `q_attn` have to be defined. " +
+                "Please make sure to instantiate class with " +
+                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
+            )
+
+        query = self.q_attn(hidden_states)
+        key_value = self.c_attn(encoder_hidden_states)
+        attention_mask = encoder_attention_mask
+    elif self.multi_query:
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.kv_dim), dim=2)
+    else:
+        # Note: We split as (self.num_heads, 3, self.head_dim) instead of
+        # (3, self.num_heads, self.head_dim),
+        # i.e., the memory layout is not the same as GPT2.
+        # This makes the concatenation with past_key_value more efficient.
+        query, key_value = (
+            self.c_attn(hidden_states)
+            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+            .transpose(1, 2)
+            .split((self.head_dim, 2 * self.head_dim), dim=3)
+        )
+
+    if layer_past is not None:
+        if layer_past.shape[-2] == key_value.shape[-2]:
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        else:
+            fill_zeros = torch.zeros(layer_past.shape[0],
+                                     layer_past.shape[1],
+                                     key_value.shape[2] - layer_past.shape[2],
+                                     dtype=layer_past.dtype,
+                                     device=layer_past.device)
+            layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        # key_value = torch.cat((layer_past, key_value), dim=-2)
+    present = key_value if use_cache else None
+
+    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+    if not output_attentions and head_mask is None:
+        # Difference with the original implementation: there is no need to
+        # transpose the key here,
+        # as SDPA expects seq_length to be at index -2 for the key as well
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    else:
+        # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
+        # once this is implemented.
+        logger.warning_once(
+            "GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
+            "but `torch.nn.functional.scaled_dot_product_attention` does not support "
+            "`output_attentions=True` and `head_mask` not None."
+            ' Falling back to the manual attention implementation, '
+            'but specifying the manual implementation will be required from '
+            'Transformers version v5.0.0 onwards. '
+            'This warning can be removed using the argument `attn_implementation="eager"` '
+            'when loading the model.'
+        )
+        attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2),
+                                                  value, attention_mask, head_mask)
+
+    if not self.multi_query:
+        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
+    attn_output = self.c_proj(attn_output)
+    attn_output = self.resid_dropout(attn_output)
+
+    outputs = (attn_output, present)
+    if output_attentions:
+        if self.multi_query:
+            # Transpose to return weights in the usual format
+            # (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
+        outputs += (attn_weights,)
+
+    return outputs
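
A note on the cache handling above: StarCoder uses multi-query attention, so `key_value` carries a single key/value head with key and value concatenated on the last dimension, and the cache grows along dim -2. A toy-shape sketch of the normal concatenation path (sizes are made up):

import torch

batch, past_len, new_len, head_dim = 1, 8, 1, 64
layer_past = torch.zeros(batch, past_len, 2 * head_dim)     # cached [key | value] for earlier tokens
key_value = torch.zeros(batch, new_len, 2 * head_dim)       # [key | value] for the current token
key_value = torch.cat((layer_past, key_value), dim=-2)      # append along the sequence dimension
key, value = key_value.split((head_dim, head_dim), dim=-1)  # split back into key and value
print(key.shape, value.shape)                               # torch.Size([1, 9, 64]) torch.Size([1, 9, 64])
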