remove old rope usage (#12544)

This commit is contained in:
Yishuo Wang 2024-12-13 16:54:58 +08:00 committed by GitHub
parent 5402fc65c8
commit c090d167dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 4 additions and 42 deletions

View file

@ -51,8 +51,7 @@ import torch.nn.functional as F
from ipex_llm.ggml.quantize import ggml_tensor_qtype from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.utils.common import invalidInputError from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.mistral import should_use_fuse_rope from ipex_llm.transformers.models.mistral import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_decoding_fast_path from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
@ -258,16 +257,9 @@ def mixtral_attention_forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope: if use_fuse_rope:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) import xe_addons
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] query_states, key_states)
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
key_states,
sin,
cos,
"mixtral")
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,

View file

@ -207,36 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k) torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
if q.device.type != "xpu":
invalidInputError(False,
f"only xpu is supported in this function")
import xe_addons
q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
if model_family in ["qwen", "mixtral"]:
xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
q_embed, k_embed)
elif model_family in ["qwen2", "yuan", "stablelm", "qwen2_moe", "internlm"]:
cos = cos.to(q.dtype)
sin = sin.to(q.dtype)
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
q_embed, k_embed)
elif model_family in ["gemma", "phi3"]:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos,
q_embed, k_embed)
else:
invalidInputError(False,
f"{model_family} is not supported.")
return q_embed, k_embed
def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1): def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
# to determinate if is enough kv cache room in transformers==4.36 # to determinate if is enough kv cache room in transformers==4.36
# seq_len for current seq len # seq_len for current seq len