add latest optimization in starcoder2 (#11236)
This commit is contained in:
		
							parent
							
								
									ba27e750b1
								
							
						
					
					
						commit
						c4e5806e01
					
				
					 1 changed files with 21 additions and 24 deletions
				
			
		| 
						 | 
				
			
			@ -42,7 +42,7 @@ import warnings
 | 
			
		|||
 | 
			
		||||
from ipex_llm.transformers.models.utils import (
 | 
			
		||||
    use_quantize_kv_cache, restore_fp8_kv_cache,
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
    should_use_fuse_rope, use_sdp, use_sdp_causal
 | 
			
		||||
)
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
| 
						 | 
				
			
			@ -53,16 +53,6 @@ from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_
 | 
			
		|||
from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_fuse_rope(self, hidden_states, position_ids):
 | 
			
		||||
    use_fuse_rope = (
 | 
			
		||||
        hidden_states.device.type == "xpu" and
 | 
			
		||||
        hidden_states.numel() == hidden_states.size(-1) and
 | 
			
		||||
        not (self.training and hidden_states.requires_grad) and
 | 
			
		||||
        position_ids is not None
 | 
			
		||||
    )
 | 
			
		||||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    if isinstance(module, Starcoder2Attention):
 | 
			
		||||
        new_weight = torch.cat([
 | 
			
		||||
| 
						 | 
				
			
			@ -115,12 +105,10 @@ def attention_forward(
 | 
			
		|||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    if should_use_fuse_rope(self, hidden_states, position_ids):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "mistral",
 | 
			
		||||
                                                                     self.rope_theta)
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                       query_states, key_states)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(
 | 
			
		||||
| 
						 | 
				
			
			@ -129,21 +117,30 @@ def attention_forward(
 | 
			
		|||
    # IPEX-LLM OPT: kv cache and quantize kv cache
 | 
			
		||||
    invalidInputError(past_key_value is not None,
 | 
			
		||||
                      "`past_key_value` cannot be None")
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
 | 
			
		||||
 | 
			
		||||
    key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                     self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    if use_quantize_kv and q_len == 1:
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue