parent 8cae897643
commit 6be70283b7
1 changed file with 23 additions and 6 deletions
@@ -104,7 +104,7 @@ def attention_fn(
         present = None
 
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+    if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
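The dropped `query_layer.size(0) > 1` guard appears to have routed single-token decode steps (query length 1) away from the PyTorch 2 path; after this change every call on PyTorch >= 2 uses `scaled_dot_product_attention`. Below is a minimal standalone sketch of that dispatch, not the repository's code; the `[seq_len, batch, heads, head_dim]` input layout is an assumption read off the `permute(1, 2, 0, 3)` above, and the variable names are illustrative.

# Hedged sketch: SDPA dispatch for a decode step (one new query token, cached keys/values).
import torch
import torch.nn.functional as F

seq_len, kv_len, batch, heads, head_dim = 1, 16, 2, 4, 8   # one new token, 16 cached
query_layer = torch.randn(seq_len, batch, heads, head_dim)
key_layer = torch.randn(kv_len, batch, heads, head_dim)
value_layer = torch.randn(kv_len, batch, heads, head_dim)

pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
    # Move to the [batch, heads, seq, head_dim] layout SDPA expects
    q, k, v = (t.permute(1, 2, 0, 3) for t in (query_layer, key_layer, value_layer))
    # q_len != kv_len here, so the is_causal=True shortcut does not apply;
    # with no mask the single query attends to the whole cache, which is what decode needs
    context_layer = F.scaled_dot_product_attention(q, k, v)
    print(context_layer.shape)   # torch.Size([2, 4, 1, 8])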
@@ -129,18 +129,35 @@ def attention_fn(
                                                                                  attention_mask,
                                                                                  is_causal=True)
         else:
+            # attention_mask is not None only when past_key_value is not None and q_len > 1
             if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
                 attention_mask = ~attention_mask
-            attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+            else:
+                attn_bias = None
             if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
                 value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
                 attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             attention_mask)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                                 key_layer,
+                                                                                 value_layer,
+                                                                                 attention_mask)
+            else:
+                head_dim = query_layer.size(-1)
+                attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+                if attn_bias is not None:
+                    attn += attn_bias
+                attn = F.softmax(attn, dim=-1,
+                                 dtype=torch.float32).to(value_layer.dtype)
+                context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
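This hunk replaces the in-place masked_fill of the boolean mask with an explicit additive attn_bias, keeps `scaled_dot_product_attention` for the CPU-autocast case, and otherwise falls back to an explicit matmul/softmax attention. The sketch below is a hedged, self-contained illustration of that logic, not the repository's code: the shapes and the `allowed` mask name are invented for the example, and it only checks that the manual fallback agrees with SDPA when both are given the same additive bias.

# Hedged sketch: boolean mask -> additive bias, and manual fallback vs. SDPA.
import math
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Boolean mask in SDPA convention (True = may attend), analogous to the mask after
# the ~attention_mask flip in the diff; here a simple causal mask for illustration
allowed = torch.ones(q_len, kv_len).tril().bool()

# Same conversion as the new code: start from zeros, write -inf where attention is
# not allowed (a float mask would instead be added to the bias directly)
attn_bias = torch.zeros(allowed.shape, dtype=q.dtype)
attn_bias.masked_fill_(allowed.logical_not(), float("-inf"))

# Fallback path from the diff: explicit matmul + softmax in float32
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn = attn + attn_bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(v.dtype)
manual = torch.matmul(attn, v)

# SDPA with the same additive bias should agree up to float tolerance
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
print(torch.allclose(manual, sdpa, atol=1e-5))   # expected: True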