Fix Starcoder issue on CPU on transformers 4.36+ (#11190)
* fix starcoder for sdpa
* update
* style

parent f93664147c
commit bb83bc23fd

2 changed files with 112 additions and 3 deletions

ipex_llm/transformers/convert.py

@@ -855,7 +855,7 @@ def convert_bigdl_other_module(model, dtype):
 
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_forward.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)
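
Note: both helpers patch modules per instance through the descriptor protocol. A minimal sketch of what new_forward.__get__(sub_m, sub_m.__class__) does (the Toy class and new_forward below are illustrative, not part of this patch):

    class Toy:
        def forward(self, x):
            return x

    def new_forward(self, x):
        # replacement behaviour for one specific instance
        return x * 2

    t1, t2 = Toy(), Toy()
    # descriptor protocol: bind the plain function to t1 only
    t1.forward = new_forward.__get__(t1, Toy)
    print(t1.forward(3), t2.forward(3))  # 6 3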

@@ -872,7 +872,7 @@ def replace_RotaryEmbed(m, target_m,  replace_embed):
 
 def replace_func(m, target_m, func_name, new_func):
     for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
+        if sub_m.__class__ == target_m:
             bound_method = new_func.__get__(sub_m, sub_m.__class__)
             setattr(sub_m, func_name, bound_method)
         replace_func(sub_m, target_m, func_name, new_func)
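
Note: the check changes from isinstance() to an exact class comparison because, on transformers 4.36+, GPTBigCodeSdpaAttention subclasses GPTBigCodeAttention, so an isinstance() match against the base class would also overwrite the SDPA variant that now gets its own forward. A minimal sketch with stand-in classes (not the real transformers ones):

    import torch.nn as nn

    class BaseAttn(nn.Module):        # stands in for GPTBigCodeAttention
        def forward(self, x):
            return x

    class SdpaAttn(BaseAttn):         # stands in for GPTBigCodeSdpaAttention (a subclass)
        pass

    model = nn.ModuleList([BaseAttn(), SdpaAttn()])
    for name, sub_m in model.named_children():
        # isinstance() would match both modules; the exact comparison matches only BaseAttn,
        # leaving the SDPA subclass free to receive its own dedicated replacement.
        print(name, isinstance(sub_m, BaseAttn), sub_m.__class__ == BaseAttn)
    # 0 True True
    # 1 True False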

@@ -1529,7 +1529,8 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gptbigcode import _attn_wrapper
-        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward
+        from ipex_llm.transformers.models.gptbigcode import gptbigcode_attention_forward, \
+            gptbigcode_sdpa_attention_forward
         convert_forward(model,
                         module.GPTBigCodeAttention,
                         gptbigcode_attention_forward)

@@ -1538,6 +1539,18 @@ def _optimize_post(model, lightweight_bmm=False):
                      module.GPTBigCodeAttention,
                      "_attn",
                      _attn)
+        try:
+            # for transformers 4.36+
+            convert_forward(model,
+                            module.GPTBigCodeSdpaAttention,
+                            gptbigcode_sdpa_attention_forward)
+            sdpa_attn = _attn_wrapper(module.GPTBigCodeSdpaAttention._attn)
+            replace_func(model,
+                         module.GPTBigCodeSdpaAttention,
+                         "_attn",
+                         sdpa_attn)
+        except AttributeError:
+            pass
     elif model.config.model_type == "starcoder2":
         # starcoder2
         modeling_module_name = model.__class__.__module__
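
Note: the added block is wrapped in try/except AttributeError so the SDPA-specific patching only runs when the loaded modeling module actually defines GPTBigCodeSdpaAttention (transformers 4.36+); on older releases the attribute lookup fails and the eager path above is kept unchanged. A minimal sketch of the same guard pattern (the helper name and placeholder body are illustrative):

    import importlib

    def maybe_patch_sdpa(model):
        module = importlib.import_module(model.__class__.__module__)
        try:
            sdpa_cls = module.GPTBigCodeSdpaAttention   # only defined on transformers 4.36+
        except AttributeError:
            return False                                # older transformers: nothing to patch
        # ... convert_forward / replace_func would be applied to sdpa_cls here ...
        return True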

ipex_llm/transformers/models/gptbigcode.py

@@ -99,3 +99,99 @@ def gptbigcode_attention_forward(
             outputs += (attn_weights,)
 
         return outputs
+
+
+def gptbigcode_sdpa_attention_forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+) -> Union[
+    Tuple[torch.Tensor, Optional[torch.Tensor]],
+    Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+]:
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn") or not self.is_cross_attention:
+                from ipex_llm.utils.common import invalidInputError
+                invalidInputError(
+                    False,
+                    "If class is used as cross attention," +
+                    "the weights `q_attn` have to be defined. " +
+                    "Please make sure to instantiate class with " +
+                    "`GPTBigCodeAttention(..., is_cross_attention=True)`."
+                )
+
+            query = self.q_attn(hidden_states)
+            key_value = self.c_attn(encoder_hidden_states)
+            attention_mask = encoder_attention_mask
+        elif self.multi_query:
+            query, key_value = self.c_attn(hidden_states).split(
+                (self.embed_dim, 2 * self.kv_dim), dim=2)
+        else:
+            # Note: We split as (self.num_heads, 3, self.head_dim) instead of
+            # (3, self.num_heads, self.head_dim),
+            # i.e., the memory layout is not the same as GPT2.
+            # This makes the concatenation with past_key_value more efficient.
+            query, key_value = (
+                self.c_attn(hidden_states)
+                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+                .transpose(1, 2)
+                .split((self.head_dim, 2 * self.head_dim), dim=3)
+            )
+
+        if layer_past is not None:
+            if layer_past.shape[-2] == key_value.shape[-2]:
+                key_value = torch.cat((layer_past, key_value), dim=-2)
+            else:
+                fill_zeros = torch.zeros(layer_past.shape[0],
+                                         layer_past.shape[1],
+                                         key_value.shape[2] - layer_past.shape[2],
+                                         dtype=layer_past.dtype,
+                                         device=layer_past.device)
+                layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
+                key_value = torch.cat((layer_past, key_value), dim=-2)
+            # key_value = torch.cat((layer_past, key_value), dim=-2)
+        present = key_value if use_cache else None
+
+        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+        if not output_attentions and head_mask is None:
+            # Difference with the original implementation: there is no need to
+            # transpose the key here,
+            # as SDPA expects seq_length to be at index -2 for the key as well
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        else:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
+            # once this is implemented.
+            logger.warning_once(
+                "GPTBigCodeModel is using GPTBigCodeSdpaAttention, "
+                "but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True` and `head_mask` not None."
+                ' Falling back to the manual attention implementation, '
+                'but specifying the manual implementation will be required from '
+                'Transformers version v5.0.0 onwards. '
+                'This warning can be removed using the argument `attn_implementation="eager"` '
+                'when loading the model.'
+            )
+            attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2),
+                                                      value, attention_mask, head_mask)
+
+        if not self.multi_query:
+            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            if self.multi_query:
+                # Transpose to return weights in the usual format
+                # (batch_size, num_heads, query_length, key_length)
+                attn_weights = attn_weights.transpose(1, 2)
+            outputs += (attn_weights,)
+
+        return outputs
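
Note: in the multi_query branch above, c_attn projects hidden_states to embed_dim + 2 * kv_dim, which is split into a query of width embed_dim and a fused key/value of width 2 * kv_dim; for multi-query attention kv_dim equals head_dim because all query heads share a single key/value head. A shape sketch with toy sizes (not StarCoder's real configuration):

    import torch

    batch, seq, embed_dim, head_dim = 2, 5, 64, 16
    kv_dim = head_dim                                   # multi-query: a single shared KV head
    c_attn = torch.nn.Linear(embed_dim, embed_dim + 2 * kv_dim)

    hidden_states = torch.randn(batch, seq, embed_dim)
    query, key_value = c_attn(hidden_states).split((embed_dim, 2 * kv_dim), dim=2)
    key, value = key_value.split((head_dim, head_dim), dim=-1)
    print(query.shape, key.shape, value.shape)
    # torch.Size([2, 5, 64]) torch.Size([2, 5, 16]) torch.Size([2, 5, 16])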