diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 77117bf5..aa791f26 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -178,7 +178,7 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
     torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
 
 
-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
     if q.device.type != "xpu":
         invalidInputError(False,
                           f"only xpu is supported in this function")
@@ -187,7 +187,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
                         "mixtral"]:
-        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
+                                                        q_embed, k_embed, rope_theta)
         return q_embed, k_embed
     else:
         invalidInputError(False,
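
For context (not part of the patch): rope_theta is the base of the rotary-embedding frequency schedule, which the fused linear_q4_0 XPU kernel now receives instead of assuming the LLaMA default of 10000.0. Below is a minimal plain-PyTorch sketch of the equivalent "rotate-half" RoPE math; the helper name reference_rotary_half is hypothetical and only illustrates what the new parameter controls, not the fused kernel itself.

import torch

def reference_rotary_half(q, k, position_ids, rope_theta=10000.0):
    # q, k: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len]
    dim = q.shape[-1]
    # rope_theta sets the base of the geometric inverse-frequency schedule
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [batch, seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [batch, seq, dim]
    cos = emb.cos()[:, None, :, :]
    sin = emb.sin()[:, None, :, :]

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed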