add rope theta argument (#10343)

parent 0c8d3c9830
commit 1ac193ba02
1 changed file with 3 additions and 2 deletions

@@ -178,7 +178,7 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
         torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
 
 
-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
     if q.device.type != "xpu":
         invalidInputError(False,
                           f"only xpu is supported in this function")

@@ -187,7 +187,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
                         "mixtral"]:
-        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
+                                                        q_embed, k_embed, rope_theta)
         return q_embed, k_embed
     else:
         invalidInputError(False,
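
For context, a minimal sketch of what the new rope_theta argument controls. This is an illustration only, assuming standard rotary position embeddings: it is not the linear_q4_0 XPU kernel invoked in the diff, and the helper name rotary_angles is hypothetical.

# Illustrative sketch only (not the linear_q4_0 XPU kernel from this diff):
# rope_theta is the base of the rotary-embedding frequencies. The default of
# 10000.0 added above matches the common Llama-style setting; long-context
# variants typically raise the base so the rotation advances more slowly
# per position.
import torch


def rotary_angles(head_dim, position_ids, rope_theta=10000.0):
    # Hypothetical helper, not part of the patched module.
    # Inverse frequencies: rope_theta^(-2i/d) for i = 0 .. d/2 - 1.
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # One rotation angle per (position, frequency pair).
    return torch.outer(position_ids.float(), inv_freq)


if __name__ == "__main__":
    pos = torch.arange(8)
    # A larger base yields smaller angles at the same position.
    print(rotary_angles(64, pos, rope_theta=10000.0)[4, :4])
    print(rotary_angles(64, pos, rope_theta=1000000.0)[4, :4])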