add rope theta argument (#10343)
commit 1ac193ba02
parent 0c8d3c9830
1 changed file with 3 additions and 2 deletions
@@ -178,7 +178,7 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
     torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)


-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
     if q.device.type != "xpu":
         invalidInputError(False,
                           f"only xpu is supported in this function")
@@ -187,7 +187,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
                         "mixtral"]:
-        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
+                                                        q_embed, k_embed, rope_theta)
         return q_embed, k_embed
     else:
         invalidInputError(False,
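
For context (not part of the diff): rope_theta is the RoPE frequency base, and the point of threading it through is that not every supported checkpoint uses the historical 10000.0 default. Below is a minimal sketch of the standard RoPE frequency formula this base feeds into, assuming the XPU kernel follows the usual formulation; rope_inv_freq and head_dim are illustrative names for this sketch, not APIs from this repository.

import torch

# Assumed standard RoPE math (illustration only, not the linear_q4_0 kernel):
# channel pair (2i, 2i+1) rotates by angle position * inv_freq[i], where
# inv_freq[i] = 1 / rope_theta ** (2i / head_dim).
def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)

# A larger base lowers the rotation frequencies, which is how long-context
# variants stretch RoPE; hence the configured value must reach the kernel.
print(rope_inv_freq(head_dim=128, rope_theta=1_000_000.0)[:4])

In practice, a caller of apply_rotary_pos_emb_no_cache_xpu would presumably forward the checkpoint's configured base (e.g. the model config's rope_theta field) rather than relying on the 10000.0 default.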