use apply_rotary_pos_emb_cache_freq_xpu in mixtral (#10060)

* use apply_rotary_pos_emb_cache_freq_xpu in mixtral

* fix style
Xin Qiu 2024-02-01 15:40:49 +08:00 committed by GitHub
parent aae20d728e
commit 6e0f1a1e92
2 changed files with 12 additions and 6 deletions

@@ -47,7 +47,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
-    apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
+    apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -198,9 +198,15 @@ def mixtral_attention_forward(
     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mixtral")
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
+                                                                       key_states,
+                                                                       sin,
+                                                                       cos,
+                                                                       "mixtral")
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

@@ -186,7 +186,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
     import linear_q4_0
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["qwen"]:
+    if model_family in ["qwen", "mixtral"]:
        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
        return q_embed, k_embed
     else:
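For reference, a plain-PyTorch sketch of the half-rotation that the fused linear_q4_0 kernel computes on XPU for this branch. This is an illustrative equivalent under the standard rotate-half RoPE convention used by Qwen/Mixtral-style attention, not the actual kernel implementation.

    import torch

    def rotate_half(x):
        # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_half_reference(q, k, sin, cos):
        # Unfused reference of the half-rotation with cached cos/sin frequencies.
        # q/k: [bs, heads, seq_len, head_dim]; cos/sin broadcastable to that shape.
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed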