Fix Starcoder issue on CPU on transformers 4.36+ (#11190)
* fix starcoder for sdpa
* update
* style
This commit is contained in:
parent f93664147c
commit bb83bc23fd
2 changed files with 112 additions and 3 deletions
@@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_forward.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)

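The patching mechanism in convert_forward relies on the descriptor protocol: a plain function is bound to an existing module instance via `__get__` and assigned as an instance attribute, which shadows the class-level `forward` for that instance only. A minimal, self-contained sketch of this idea; `ToyAttention` and the body of `new_forward` are made up purely for illustration:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):           # hypothetical stand-in for an attention layer
    def forward(self, x):
        return x

def new_forward(self, x):                # stand-in for an optimized replacement forward
    return x * 2

layer = ToyAttention()
bound_method = new_forward.__get__(layer, layer.__class__)
setattr(layer, "forward", bound_method)  # only this instance is patched
print(layer(torch.ones(2)))              # tensor([2., 2.])
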
@@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m, replace_embed):
 def replace_func(m, target_m, func_name, new_func):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_func.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, func_name, bound_method)
         replace_func(sub_m, target_m, func_name, new_func)

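The switch from `isinstance` to an exact `__class__` comparison matters because transformers 4.36+ adds `GPTBigCodeSdpaAttention` as a subclass of `GPTBigCodeAttention`: with `isinstance`, the SDPA layers would also match and be patched with the eager-attention replacement. A small sketch of the difference; the class bodies are placeholders, only the hierarchy mirrors transformers:

class GPTBigCodeAttention:                            # placeholder, not the real class
    pass

class GPTBigCodeSdpaAttention(GPTBigCodeAttention):   # subclass introduced in 4.36+
    pass

sdpa_layer = GPTBigCodeSdpaAttention()
print(isinstance(sdpa_layer, GPTBigCodeAttention))      # True  -> would be patched too
print(sdpa_layer.__class__ == GPTBigCodeAttention)      # False -> left for the SDPA patch
print(sdpa_layer.__class__ == GPTBigCodeSdpaAttention)  # True
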
@@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
-        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
+        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
+            gptbigcode_sdpa_attention_forward
         convert_forward(model,
                         module.GPTBigCodeAttention,
                         gptbigcode_attention_forward)

@@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
                      module.GPTBigCodeAttention,
                      "_attn",
                      _attn)
+        try:
+            # for transformers 4.36+
+            convert_forward(model,
+                            module.GPTBigCodeSdpaAttention,
+                            gptbigcode_sdpa_attention_forward)
+            sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
+            replace_func(model,
+                         module.GPTBigCodeSdpaAttention,
+                         "_attn",
+                         sdpa_attn)
+        except AttributeError:
+            pass
     elif model.config.model_type == "starcoder2":
         # starcoder2
         modeling_module_name = model.__class__.__module__

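The new block is wrapped in `try`/`except AttributeError` because `module.GPTBigCodeSdpaAttention` only exists on transformers 4.36+; on older versions the attribute lookup fails and the SDPA patch is simply skipped. A self-contained sketch of this version-guard pattern; `fake_module` is a hypothetical stand-in for the imported modeling module and is not part of the commit:

import types

# Stand-in for the modeling module of an older transformers release,
# which lacks the GPTBigCodeSdpaAttention class.
fake_module = types.SimpleNamespace(GPTBigCodeAttention=object)

try:
    sdpa_cls = fake_module.GPTBigCodeSdpaAttention   # raises on transformers < 4.36
    print("patching", sdpa_cls)
except AttributeError:
    print("SDPA class not found, skipping the 4.36+ patch")

The remaining hunk adds the new SDPA forward in the ipex_llm.transformers.models.gptbigcode module referenced by the imports above.
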
@@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
         outputs += (attn_weights,)
 
     return outputs
+
+
+def gptbigcode_sdpa_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    layer_past: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
+) -> Union[
+    Tuple[torch.Tensor, Optional[torch.Tensor]],
+    Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+]:
+    if encoder_hidden_states is not None:
+        if not hasattr(self, "q_attn") or not self.is_cross_attention:
+            from ipex_llm.utils.common import invalidInputError
+            invalidInputError(
+                False,
+                "If class is used as cross attention," +
+                "the weights `q_attn` have to be defined. " +
+                "Please make sure to instantiate class with " +
+                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
+            )
+
+        query = self.q_attn(hidden_states)
+        key_value = self.c_attn(encoder_hidden_states)
+        attention_mask = encoder_attention_mask
+    elif self.multi_query:
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.kv_dim), dim=2)
+    else:
+        # Note: We split as (self.num_heads, 3, self.head_dim) instead of
+        # (3, self.num_heads, self.head_dim),
+        # i.e., the memory layout is not the same as GPT2.
+        # This makes the concatenation with past_key_value more efficient.
+        query, key_value = (
+            self.c_attn(hidden_states)
+            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+            .transpose(1, 2)
+            .split((self.head_dim, 2 * self.head_dim), dim=3)
+        )
+
+    if layer_past is not None:
+        if layer_past.shape[-2] == key_value.shape[-2]:
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        else:
+            fill_zeros = torch.zeros(layer_past.shape[0],
+                                     layer_past.shape[1],
+                                     key_value.shape[2] - layer_past.shape[2],
+                                     dtype=layer_past.dtype,
+                                     device=layer_past.device)
+            layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        # key_value = torch.cat((layer_past, key_value), dim=-2)
+    present = key_value if use_cache else None
+
+    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+    if not output_attentions and head_mask is None:
+        # Difference with the original implementation: there is no need to
+        # transpose the key here,
+        # as SDPA expects seq_length to be at index -2 for the key as well
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    else:
+        # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
+        # once this is implemented.
+        logger.warning_once(
+            "GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
+            "but `torch.nn.functional.scaled_dot_product_attention` does not support "
+            "`output_attentions=True` and `head_mask` not None."
+            ' Falling back to the manual attention implementation, '
+            'but specifying the manual implementation will be required from '
+            'Transformers version v5.0.0 onwards. '
+            'This warning can be removed using the argument `attn_implementation="eager"` '
+            'when loading the model.'
+        )
+        attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2),
+                                                  value, attention_mask, head_mask)
+
+    if not self.multi_query:
+        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
+    attn_output = self.c_proj(attn_output)
+    attn_output = self.resid_dropout(attn_output)
+
+    outputs = (attn_output, present)
+    if output_attentions:
+        if self.multi_query:
+            # Transpose to return weights in the usual format
+            # (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
+        outputs += (attn_weights,)
+
+    return outputs
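In the `multi_query` branch of the new forward, the fused `c_attn` output packs the query alongside a single shared key/value head, so the split widths are `embed_dim` and `2 * kv_dim`. A shape sketch with made-up dimensions, not taken from any particular checkpoint:

import torch

batch, seq, embed_dim, head_dim = 2, 5, 768, 64
kv_dim = head_dim                                        # one shared kv head (multi-query)
fused = torch.randn(batch, seq, embed_dim + 2 * kv_dim)  # stand-in for c_attn(hidden_states)
query, key_value = fused.split((embed_dim, 2 * kv_dim), dim=2)
key, value = key_value.split((head_dim, head_dim), dim=-1)
print(query.shape, key.shape, value.shape)
# torch.Size([2, 5, 768]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])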