parent cfdf8ad496
commit 021d77fd22
1 changed file with 6 additions and 9 deletions
@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, value_states)
         if use_cache:
             cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
         if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
 
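For context (not part of the commit): a minimal sketch of the memory trade-off the new comment refers to, assuming PyTorch and an illustrative fp16 attention-weights tensor. Passing dtype=torch.float32 to the softmax computes and returns the weights in fp32 before they are cast back, so an fp32 temporary of the full attention shape is materialized; the patched call keeps the weights in their native half precision.

import torch
import torch.nn as nn

# Illustrative shape only (batch, heads, q_len, kv_len); not taken from the commit.
attn_weights = torch.randn(1, 32, 1024, 1024, dtype=torch.float16)

# Old path: upcast to fp32 for the softmax, then cast back to the activation dtype.
upcast = nn.functional.softmax(attn_weights, dim=-1,
                               dtype=torch.float32).to(attn_weights.dtype)

# New path: softmax in the native dtype, so no fp32 temporary is materialized.
native = nn.functional.softmax(attn_weights, dim=-1)

# The fp32 temporary is twice the size of the fp16 weights:
# 32 * 1024 * 1024 elements -> 128 MiB in fp32 vs 64 MiB in fp16.
print(f"{attn_weights.numel() * 4 / 2**20:.0f} MiB (fp32) vs "
      f"{attn_weights.numel() * 2 / 2**20:.0f} MiB (fp16)")

# The two results agree up to half-precision rounding.
print(torch.allclose(upcast, native, atol=1e-3))

The fp32 upcast is the numerically safer default for training, but at inference the half-precision softmax is typically accurate enough and halves the peak size of the attention-weights temporary.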