diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6ab8fa7a..192e1f95 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
 
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_forward.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)
@@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m, replace_embed):
 
 def replace_func(m, target_m, func_name, new_func):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_func.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, func_name, bound_method)
         replace_func(sub_m, target_m, func_name, new_func)
@@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
-        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
+        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
+            gptbigcode_sdpa_attention_forward
         convert_forward(model,
                         module.GPTBigCodeAttention,
                         gptbigcode_attention_forward)
@@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
                      module.GPTBigCodeAttention,
                      "_attn",
                      _attn)
+        try:
+            # for transformers 4.36+
+            convert_forward(model,
+                            module.GPTBigCodeSdpaAttention,
+                            gptbigcode_sdpa_attention_forward)
+            sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
+            replace_func(model,
+                         module.GPTBigCodeSdpaAttention,
+                         "_attn",
+                         sdpa_attn)
+        except AttributeError:
+            pass
     elif model.config.model_type == "starcoder2":
         # starcoder2
         modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/gptbigcode.py b/python/llm/src/ipex_llm/transformers/models/gptbigcode.py
index 611b9fba..747cc26d 100644
--- a/python/llm/src/ipex_llm/transformers/models/gptbigcode.py
+++ b/python/llm/src/ipex_llm/transformers/models/gptbigcode.py
@@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
         outputs += (attn_weights,)
 
     return outputs
+
+
+def gptbigcode_sdpa_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    layer_past: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
+) -> Union[
+    Tuple[torch.Tensor, Optional[torch.Tensor]],
+    Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+]:
+    if encoder_hidden_states is not None:
+        if not hasattr(self, "q_attn") or not self.is_cross_attention:
+            from ipex_llm.utils.common import invalidInputError
+            invalidInputError(
+                False,
+                "If class is used as cross attention, "
+                + "the weights `q_attn` have to be defined. "
+                + "Please make sure to instantiate class with "
+                + "`GPTBigCodeAttention(..., is_cross_attention=True)`."
+            )
+
+        query = self.q_attn(hidden_states)
+        key_value = self.c_attn(encoder_hidden_states)
+        attention_mask = encoder_attention_mask
+    elif self.multi_query:
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.kv_dim), dim=2)
+    else:
+        # Note: We split as (self.num_heads, 3, self.head_dim) instead of
+        # (3, self.num_heads, self.head_dim),
+        # i.e., the memory layout is not the same as GPT2.
+        # This makes the concatenation with past_key_value more efficient.
+        query, key_value = (
+            self.c_attn(hidden_states)
+            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+            .transpose(1, 2)
+            .split((self.head_dim, 2 * self.head_dim), dim=3)
+        )
+
+    if layer_past is not None:
+        if layer_past.shape[-2] == key_value.shape[-2]:
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        else:
+            fill_zeros = torch.zeros(layer_past.shape[0],
+                                     layer_past.shape[1],
+                                     key_value.shape[2] - layer_past.shape[2],
+                                     dtype=layer_past.dtype,
+                                     device=layer_past.device)
+            layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        # key_value = torch.cat((layer_past, key_value), dim=-2)
+    present = key_value if use_cache else None
+
+    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+    if not output_attentions and head_mask is None:
+        # Difference with the original implementation: there is no need to
+        # transpose the key here,
+        # as SDPA expects seq_length to be at index -2 for the key as well
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    else:
+        # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
+        # once this is implemented.
+        logger.warning_once(
+            "GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
+            "but `torch.nn.functional.scaled_dot_product_attention` does not support "
+            "`output_attentions=True` and `head_mask` not None."
+            ' Falling back to the manual attention implementation, '
+            'but specifying the manual implementation will be required from '
+            'Transformers version v5.0.0 onwards. '
+            'This warning can be removed using the argument `attn_implementation="eager"` '
+            'when loading the model.'
+        )
+        # Zero-argument super() is unavailable here because this function is not
+        # defined inside the class body, so name the base explicitly.
+        attn_output, attn_weights = super(type(self), self)._attn(query, key.transpose(-1, -2),
+                                                                   value, attention_mask, head_mask)
+
+    if not self.multi_query:
+        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
+    attn_output = self.c_proj(attn_output)
+    attn_output = self.resid_dropout(attn_output)
+
+    outputs = (attn_output, present)
+    if output_attentions:
+        if self.multi_query:
+            # Transpose to return weights in the usual format
+            # (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
+        outputs += (attn_weights,)
+
+    return outputs
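
A note on the `isinstance(sub_m, target_m)` -> `sub_m.__class__ == target_m` change in convert_forward and replace_func: in transformers 4.36 the new GPTBigCodeSdpaAttention class derives from GPTBigCodeAttention, so isinstance-based matching would bind the same replacement to both attention variants and the final binding would depend on the order of the conversion calls. Exact-class matching keeps the eager and SDPA replacements independent, which is what allows the convert.py hunk above to register gptbigcode_sdpa_attention_forward and a separate _attn_wrapper result just for the SDPA class. Below is a minimal, self-contained sketch of that behaviour; the toy BaseAttention/SdpaAttention classes and the new_*_forward functions are illustrative stand-ins, not part of the patch:

    # toy_convert_forward.py -- illustrative stand-ins, not ipex-llm code
    import torch.nn as nn

    class BaseAttention(nn.Module):        # stands in for GPTBigCodeAttention
        def forward(self, x):
            return "eager"

    class SdpaAttention(BaseAttention):    # stands in for GPTBigCodeSdpaAttention
        def forward(self, x):
            return "sdpa"

    def convert_forward(m, target_m, new_forward):
        # same traversal as the patched helper: bind only on an exact class match
        for _, sub_m in m.named_children():
            if sub_m.__class__ == target_m:
                bound_method = new_forward.__get__(sub_m, sub_m.__class__)
                setattr(sub_m, "forward", bound_method)
            convert_forward(sub_m, target_m, new_forward)

    def new_eager_forward(self, x):
        return "optimized eager"

    def new_sdpa_forward(self, x):
        return "optimized sdpa"

    model = nn.Sequential(BaseAttention(), SdpaAttention())
    convert_forward(model, BaseAttention, new_eager_forward)
    convert_forward(model, SdpaAttention, new_sdpa_forward)
    print(model[0](None), model[1](None))  # -> optimized eager optimized sdpa
    # With isinstance() matching, the first call would also rebind the SdpaAttention
    # module, so the outcome would depend on the order of the two calls above.

The try/except AttributeError around the GPTBigCodeSdpaAttention registration covers the other direction: on transformers versions that predate the SDPA class, the attribute lookup fails and only the eager path is patched.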