diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index 9e6248a4..42cfbfc2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -47,7 +47,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
-    apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
+    apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -198,10 +198,16 @@ def mixtral_attention_forward(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mixtral")
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
+                                                                       key_states,
+                                                                       sin,
+                                                                       cos,
+                                                                       "mixtral")
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 1dc1a36a..5c6cbc18 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -186,7 +186,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
     import linear_q4_0
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["qwen"]:
+    if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
         return q_embed, k_embed
     else:
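
For reference, below is a minimal PyTorch sketch (not part of the patch) of the cos/sin preparation that the new fuse-rope branch performs before handing off to apply_rotary_pos_emb_cache_freq_xpu. It assumes the transformers-4.36-style rotary cache layout of [1, 1, kv_seq_len, head_dim] and applies the standard rotate_half formulation on CPU in place of the linear_q4_0 XPU kernel; gather_cos_sin, rotate_half, and apply_rotary_reference are illustrative helpers, not BigDL APIs.

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def gather_cos_sin(cos, sin, position_ids):
    # cos/sin come from the rotary cache with shape [1, 1, kv_seq_len, head_dim].
    # Squeeze to [kv_seq_len, head_dim], gather the rows for this batch's
    # positions, then re-add a broadcastable heads dimension:
    # result shape [bs, 1, q_len, head_dim].
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    return cos, sin

def apply_rotary_reference(q, k, cos, sin, position_ids):
    # q, k: [bs, num_heads, q_len, head_dim]; CPU reference for what the
    # fused XPU kernel computes from the pre-gathered cos/sin tensors.
    cos, sin = gather_cos_sin(cos, sin, position_ids)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

if __name__ == "__main__":
    bs, heads, q_len, head_dim, kv_len = 1, 8, 4, 64, 4
    q = torch.randn(bs, heads, q_len, head_dim)
    k = torch.randn(bs, heads, q_len, head_dim)
    # Stand-in rotary cache in the [1, 1, kv_seq_len, head_dim] layout.
    freqs = torch.randn(1, 1, kv_len, head_dim)
    cos, sin = freqs.cos(), freqs.sin()
    position_ids = torch.arange(q_len).unsqueeze(0)
    q_out, k_out = apply_rotary_reference(q, k, cos, sin, position_ids)
    print(q_out.shape, k_out.shape)  # torch.Size([1, 8, 4, 64]) each

The only Mixtral-specific part of the change is the dispatch in utils.py: adding "mixtral" to the accepted model families routes it to the same half-rotary cache-frequency kernel already used for Qwen, while the sin/cos gathering above matches the layout that kernel expects.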