parent 68d78fb57e
commit 427f75000b

1 changed file with 9 additions and 13 deletions

@@ -369,8 +369,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
-        if attention_mask is None and (use_flash_attention(query_layer, key_layer) or
-                                       L == S and query_layer.device.type == "cpu"):
+        if attention_mask is None and L == S:
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
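
The simplified condition above routes every unmasked prefill (L == S) through PyTorch's fused kernel. A minimal sketch of that fast path follows; the toy shapes in the permuted (batch, heads, seq, head_dim) layout and the is_causal=True completion of the truncated call are assumptions, not part of the commit.

# Sketch of the fast path: attention_mask is None and L == S (prefill).
# The toy shapes and is_causal=True are assumptions; the hunk truncates the call.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 2, 8, 16
query_layer = torch.randn(batch, heads, seq, head_dim)   # already permuted (1, 2, 0, 3)
key_layer = torch.randn(batch, heads, seq, head_dim)
value_layer = torch.randn(batch, heads, seq, head_dim)

L, S = query_layer.shape[2], key_layer.shape[2]
if L == S:  # prefill: one fused call with an implicit causal mask
    context_layer = F.scaled_dot_product_attention(query_layer,
                                                   key_layer,
                                                   value_layer,
                                                   is_causal=True)
    print(context_layer.shape)  # torch.Size([1, 2, 8, 16])
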
@@ -380,19 +379,16 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
             attn = torch.matmul(query_layer,
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
             if attention_mask is not None:
-                attention_mask = ~attention_mask
-                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                attn += attention_mask
-            elif L == S:
-                # first token, need attention mask
-                attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                         device=query_layer.device)
-                temp_mask = torch.ones(L, S, dtype=torch.bool,
-                                       device=query_layer.device).tril(diagonal=0)
-                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-                attn_bias.to(query_layer.dtype)
+                attention_mask = ~attention_mask
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
                 attn += attn_bias
-            attn = torch.softmax(attn, -1)
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
             context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
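
The surviving fallback branch converts whatever mask it receives into an additive bias: a zero tensor shaped like attention_mask, filled with -inf where attention is blocked for boolean masks or added as-is for float masks, followed by a float32 softmax cast back to the value dtype. A self-contained sketch of that conversion follows; the toy shapes and the example boolean mask (True marks a blocked position before the ~ inversion, as the fill pattern in the hunk implies) are illustrative assumptions.

# Sketch of the masked fallback: raw scores, mask -> additive bias, fp32 softmax.
# Shapes and the example mask are assumptions; the bias logic mirrors the hunk.
import math
import torch
import torch.nn.functional as F

batch, heads, L, S, head_dim = 1, 2, 1, 8, 16
query_layer = torch.randn(batch, heads, L, head_dim)
key_layer = torch.randn(batch, heads, S, head_dim)
value_layer = torch.randn(batch, heads, S, head_dim)
# Boolean mask, True = position that must NOT be attended to (pre-inversion).
attention_mask = torch.zeros(batch, 1, L, S, dtype=torch.bool)
attention_mask[..., :2] = True  # e.g. two padded key positions

attn = torch.matmul(query_layer,
                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
if attention_mask is not None:
    attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                            device=query_layer.device)
    attention_mask = ~attention_mask              # True now marks allowed positions
    if attention_mask.dtype == torch.bool:
        attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
    else:
        attn_bias += attention_mask               # float masks are already additive
    attn += attn_bias                             # broadcasts over the head dimension
attn = F.softmax(attn, dim=-1,
                 dtype=torch.float32).to(value_layer.dtype)  # fp32 softmax for stability
context_layer = torch.matmul(attn, value_layer)
print(context_layer.shape)  # torch.Size([1, 2, 1, 16])
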
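As a sanity check, not part of the commit: for the unmasked L == S case the two branches should agree, since is_causal=True is the fused form of an explicit lower-triangular bias. A short comparison under the same toy shapes as above:

# Illustrative consistency check (not from the commit): fused causal SDPA vs.
# explicit lower-triangular bias plus fp32 softmax.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 1, 2, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Fast path: fused kernel with an implicit causal mask.
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Fallback-style computation: explicit scores plus a causal additive bias.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
causal = torch.ones(seq, seq, dtype=torch.bool).tril(diagonal=0)
scores = scores.masked_fill(causal.logical_not(), float("-inf"))
manual_out = torch.matmul(F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype), v)

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True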