add rope theta argument (#10343)

commit 1ac193ba02 (parent 0c8d3c9830)
Author: Yishuo Wang, committed by GitHub
Date:   2024-03-07 17:27:19 +08:00

@@ -178,7 +178,7 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
     torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
 
 
-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
     if q.device.type != "xpu":
         invalidInputError(False,
                           f"only xpu is supported in this function")
@@ -187,7 +187,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
                         "mixtral"]:
-        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
+                                                        q_embed, k_embed, rope_theta)
         return q_embed, k_embed
     else:
         invalidInputError(False,
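
For context, rope_theta is the base of RoPE's frequency ladder. The fused XPU kernel previously assumed the common default of 10000.0; some models configure a larger base (Mixtral uses 1e6), so the kernel now takes it as an argument, with the old value kept as the default. A minimal sketch of what the base controls, assuming standard RoPE (this sketch is illustrative and not taken from the kernel itself):

    import torch

    def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
        # inv_freq[i] = 1 / rope_theta ** (2i / head_dim): the per-channel-pair
        # rotation frequencies RoPE applies to q and k.
        return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # A larger base stretches the low-frequency end of the ladder, which is how
    # long-context models extend their usable positional range.
    print(rope_inv_freq(8, rope_theta=10000.0))  # default assumed before this commit
    print(rope_inv_freq(8, rope_theta=1e6))      # e.g. Mixtral's configured base

At the call site, a caller would typically forward the model config's rope_theta field (Hugging Face configs such as LlamaConfig and MixtralConfig expose one), e.g. getattr(config, "rope_theta", 10000.0), so the 10000.0 default preserves the previous behavior for models that never set it.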