parent
							
								
									439c834ed3
								
							
						
					
					
						commit
						754b0ffecf
					
				
					 1 changed file with 5 additions and 1 deletion
				
			
		| 
						 | 
				
			
			@ -1348,7 +1348,11 @@ def llama_attention_forward_4_36_original(
 | 
			
		|||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
        # otherwise, use native attention
 | 
			
		||||
-        if not output_attentions:
 | 
			
		||||
+        if query_states.device.type == "xpu":
 | 
			
		||||
+            dev_name = torch.xpu.get_device_name(query_states.device.index)
 | 
			
		||||
+        else:
 | 
			
		||||
+            dev_name = "CPU"
 | 
			
		||||
+        if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
 | 
			
		||||
            attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
                query_states,
 | 
			
		||||
                key_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue