remove old rope usage (#12544)
This commit is contained in:
		
							parent
							
								
									5402fc65c8
								
							
						
					
					
						commit
						c090d167dc
					
				
					 2 changed files with 4 additions and 42 deletions
				
			
		| 
						 | 
				
			
			@ -51,8 +51,7 @@ import torch.nn.functional as F
 | 
			
		|||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
			
		||||
    apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 | 
			
		||||
| 
						 | 
				
			
			@ -258,16 +257,9 @@ def mixtral_attention_forward(
 | 
			
		|||
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
        if use_fuse_rope:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
            sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
            cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
            sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
 | 
			
		||||
                                                                           key_states,
 | 
			
		||||
                                                                           sin,
 | 
			
		||||
                                                                           cos,
 | 
			
		||||
                                                                           "mixtral")
 | 
			
		||||
            import xe_addons
 | 
			
		||||
            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                           query_states, key_states)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -207,36 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
 | 
			
		|||
        torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
 | 
			
		||||
    if q.device.type != "xpu":
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"only xpu is supported in this function")
 | 
			
		||||
    import xe_addons
 | 
			
		||||
    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
 | 
			
		||||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["qwen", "mixtral"]:
 | 
			
		||||
        xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
 | 
			
		||||
                                                                 q_embed, k_embed)
 | 
			
		||||
    elif model_family in ["qwen2", "yuan", "stablelm", "qwen2_moe", "internlm"]:
 | 
			
		||||
        cos = cos.to(q.dtype)
 | 
			
		||||
        sin = sin.to(q.dtype)
 | 
			
		||||
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
        xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
 | 
			
		||||
                                                                 q_embed, k_embed)
 | 
			
		||||
    elif model_family in ["gemma", "phi3"]:
 | 
			
		||||
        cos = cos.unsqueeze(1)
 | 
			
		||||
        sin = sin.unsqueeze(1)
 | 
			
		||||
        xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
 | 
			
		||||
                                                                 q_embed, k_embed)
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"{model_family} is not supported.")
 | 
			
		||||
    return q_embed, k_embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
 | 
			
		||||
    # to determinate if is enough kv cache room in transformers==4.36
 | 
			
		||||
    # seq_len for current seq len
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue