add rope theta argument (#10343)
This commit is contained in:
parent
0c8d3c9830
commit
1ac193ba02
1 changed files with 3 additions and 2 deletions
|
|
@ -178,7 +178,7 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
|
|||
torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
|
||||
def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
|
||||
if q.device.type != "xpu":
|
||||
invalidInputError(False,
|
||||
f"only xpu is supported in this function")
|
||||
|
|
@ -187,7 +187,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
|
|||
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
|
||||
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
|
||||
"mixtral"]:
|
||||
linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
|
||||
linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
|
||||
q_embed, k_embed, rope_theta)
|
||||
return q_embed, k_embed
|
||||
else:
|
||||
invalidInputError(False,
|
||||
|
|
|
|||
Loading…
Reference in a new issue