Fix Starcoder issue on CPU on transformers 4.36+ (#11190)
* fix starcoder for sdpa * update * style
This commit is contained in:
		
							parent
							
								
									f93664147c
								
							
						
					
					
						commit
						bb83bc23fd
					
				
					 2 changed files with 112 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
 | 
			
		|||
 | 
			
		||||
def convert_forward(m, target_m, new_forward):
 | 
			
		||||
    for _, sub_m in m.named_children():
 | 
			
		||||
        if isinstance(sub_m, target_m):
 | 
			
		||||
        if sub_m.__class__ == target_m:
 | 
			
		||||
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
 | 
			
		||||
            setattr(sub_m, "forward", bound_method)
 | 
			
		||||
        convert_forward(sub_m, target_m, new_forward)
 | 
			
		||||
| 
						 | 
				
			
			@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m,  replace_embed):
 | 
			
		|||
 | 
			
		||||
def replace_func(m, target_m, func_name, new_func):
 | 
			
		||||
    for _, sub_m in m.named_children():
 | 
			
		||||
        if isinstance(sub_m, target_m):
 | 
			
		||||
        if sub_m.__class__ == target_m:
 | 
			
		||||
            bound_method = new_func.__get__(sub_m, sub_m.__class__)
 | 
			
		||||
            setattr(sub_m, func_name, bound_method)
 | 
			
		||||
        replace_func(sub_m, target_m, func_name, new_func)
 | 
			
		||||
| 
						 | 
				
			
			@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
 | 
			
		||||
        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
 | 
			
		||||
        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
 | 
			
		||||
            gptbigcode_sdpa_attention_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.GPTBigCodeAttention,
 | 
			
		||||
                        gptbigcode_attention_forward)
 | 
			
		||||
| 
						 | 
				
			
			@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                     module.GPTBigCodeAttention,
 | 
			
		||||
                     "_attn",
 | 
			
		||||
                     _attn)
 | 
			
		||||
        try:
 | 
			
		||||
            # for transformers 4.36+
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.GPTBigCodeSdpaAttention,
 | 
			
		||||
                            gptbigcode_sdpa_attention_forward)
 | 
			
		||||
            sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
 | 
			
		||||
            replace_func(model,
 | 
			
		||||
                         module.GPTBigCodeSdpaAttention,
 | 
			
		||||
                         "_attn",
 | 
			
		||||
                         sdpa_attn)
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            pass
 | 
			
		||||
    elif model.config.model_type == "starcoder2":
 | 
			
		||||
        # starcoder2
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
 | 
			
		|||
            outputs += (attn_weights,)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gptbigcode_sdpa_attention_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        layer_past: Optional[torch.Tensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        head_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        encoder_hidden_states: Optional[torch.Tensor] = None,
 | 
			
		||||
        encoder_attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
) -> Union[
 | 
			
		||||
    Tuple[torch.Tensor, Optional[torch.Tensor]],
 | 
			
		||||
    Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
 | 
			
		||||
]:
 | 
			
		||||
        if encoder_hidden_states is not None:
 | 
			
		||||
            if not hasattr(self, "q_attn") or not self.is_cross_attention:
 | 
			
		||||
                from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
                invalidInputError(
 | 
			
		||||
                    False,
 | 
			
		||||
                    "If class is used as cross attention," +
 | 
			
		||||
                    "the weights `q_attn` have to be defined. " +
 | 
			
		||||
                    "Please make sure to instantiate class with " +
 | 
			
		||||
                    "`GPTBigCodeAttention(..., is_cross_attention=True)`."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            query = self.q_attn(hidden_states)
 | 
			
		||||
            key_value = self.c_attn(encoder_hidden_states)
 | 
			
		||||
            attention_mask = encoder_attention_mask
 | 
			
		||||
        elif self.multi_query:
 | 
			
		||||
            query, key_value = self.c_attn(hidden_states).split(
 | 
			
		||||
                (self.embed_dim, 2 * self.kv_dim), dim=2)
 | 
			
		||||
        else:
 | 
			
		||||
            # Note: We split as (self.num_heads, 3, self.head_dim) instead of
 | 
			
		||||
            # (3, self.num_heads, self.head_dim),
 | 
			
		||||
            # i.e., the memory layout is not the same as GPT2.
 | 
			
		||||
            # This makes the concatenation with past_key_value more efficient.
 | 
			
		||||
            query, key_value = (
 | 
			
		||||
                self.c_attn(hidden_states)
 | 
			
		||||
                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
 | 
			
		||||
                .transpose(1, 2)
 | 
			
		||||
                .split((self.head_dim, 2 * self.head_dim), dim=3)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if layer_past is not None:
 | 
			
		||||
            if layer_past.shape[-2] == key_value.shape[-2]:
 | 
			
		||||
                key_value = torch.cat((layer_past, key_value), dim=-2)
 | 
			
		||||
            else:
 | 
			
		||||
                fill_zeros = torch.zeros(layer_past.shape[0],
 | 
			
		||||
                                         layer_past.shape[1],
 | 
			
		||||
                                         key_value.shape[2] - layer_past.shape[2],
 | 
			
		||||
                                         dtype=layer_past.dtype,
 | 
			
		||||
                                         device=layer_past.device)
 | 
			
		||||
                layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
 | 
			
		||||
                key_value = torch.cat((layer_past, key_value), dim=-2)
 | 
			
		||||
            # key_value = torch.cat((layer_past, key_value), dim=-2)
 | 
			
		||||
        present = key_value if use_cache else None
 | 
			
		||||
 | 
			
		||||
        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
 | 
			
		||||
 | 
			
		||||
        if not output_attentions and head_mask is None:
 | 
			
		||||
            # Difference with the original implementation: there is no need to
 | 
			
		||||
            # transpose the key here,
 | 
			
		||||
            # as SDPA expects seq_length to be at index -2 for the key as well
 | 
			
		||||
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
 | 
			
		||||
            # once this is implemented.
 | 
			
		||||
            logger.warning_once(
 | 
			
		||||
                "GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
 | 
			
		||||
                "but `torch.nn.functional.scaled_dot_product_attention` does not support "
 | 
			
		||||
                "`output_attentions=True` and `head_mask` not None."
 | 
			
		||||
                ' Falling back to the manual attention implementation, '
 | 
			
		||||
                'but specifying the manual implementation will be required from '
 | 
			
		||||
                'Transformers version v5.0.0 onwards. '
 | 
			
		||||
                'This warning can be removed using the argument `attn_implementation="eager"` '
 | 
			
		||||
                'when loading the model.'
 | 
			
		||||
            )
 | 
			
		||||
            attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2),
 | 
			
		||||
                                                      value, attention_mask, head_mask)
 | 
			
		||||
 | 
			
		||||
        if not self.multi_query:
 | 
			
		||||
            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
 | 
			
		||||
        attn_output = self.c_proj(attn_output)
 | 
			
		||||
        attn_output = self.resid_dropout(attn_output)
 | 
			
		||||
 | 
			
		||||
        outputs = (attn_output, present)
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            if self.multi_query:
 | 
			
		||||
                # Transpose to return weights in the usual format
 | 
			
		||||
                # (batch_size, num_heads, query_length, key_length)
 | 
			
		||||
                attn_weights = attn_weights.transpose(1, 2)
 | 
			
		||||
            outputs += (attn_weights,)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue