Fix llama when rope scaling is not None (#9086)
* Fix llama when rope scaling is not None
* fix style
* fix style
parent fcb1c618a0
commit 36dd4afd61

1 changed file with 5 additions and 1 deletion
@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
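For context, a minimal sketch of the new gating logic (the helper name and the example rope_scaling value below are illustrative, not part of the patch): the fused XPU rotary kernel is now skipped whenever rope scaling is configured, so Llama checkpoints that use scaled RoPE fall back to the regular rotary-embedding path instead of producing incorrect positions.

# Minimal sketch (not the patched source): mirrors the three conditions the
# commit chains together before taking the fused XPU rotary-embedding path.
def should_use_fused_rope(device_type, training, requires_grad, rope_scaling):
    use_fuse_rope = device_type == "xpu"
    use_fuse_rope = use_fuse_rope and not (training and requires_grad)
    use_fuse_rope = use_fuse_rope and rope_scaling is None
    return use_fuse_rope

# A model with rope scaling configured (e.g. {"type": "linear", "factor": 2.0})
# no longer takes the fused path, which is exactly what this fix enforces.
assert should_use_fused_rope("xpu", False, False, None)
assert not should_use_fused_rope("xpu", False, False, {"type": "linear", "factor": 2.0})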