refactor attention_softmax (#12295)
This commit is contained in:
		
							parent
							
								
									2b2cb9c693
								
							
						
					
					
						commit
						540eaeb12c
					
				
					 14 changed files with 18 additions and 18 deletions
				
			
		| 
						 | 
				
			
			@ -109,7 +109,7 @@ def aquila_attention_forward(
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -79,8 +79,8 @@ def mlp_gelu_forward(self, x: torch.Tensor):
 | 
			
		|||
    return fuse_mlp_base(self, GELU, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def attention_softmax(attn_weights: torch.Tensor, training: bool):
 | 
			
		||||
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
 | 
			
		||||
def attention_softmax(attn_weights: torch.Tensor):
 | 
			
		||||
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.attn_softmax_inplaced(attn_weights)
 | 
			
		||||
    else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -220,7 +220,7 @@ def gemma_attention_forward(
 | 
			
		|||
        attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                         training=self.training)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -125,7 +125,7 @@ def internlm_attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -237,7 +237,7 @@ def llama_attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -56,7 +56,7 @@ def siglip_attention_forward(
 | 
			
		|||
    if attention_mask is not None:
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +161,7 @@ def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		|||
    query_states, key_states, value_states = qkv.chunk(3, dim=1)
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
    attn_weights = self.attn_drop(attn_weights)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -85,7 +85,7 @@ def mllama_vision_attention_forward(
 | 
			
		|||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        from ipex_llm.transformers.models.common import attention_softmax
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, False)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -311,7 +311,7 @@ def mllama_cross_attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -114,7 +114,7 @@ def attention_forward(
 | 
			
		|||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
 | 
			
		||||
    attn_weights = attention_softmax(attn_weights).to(hidden_states.dtype)
 | 
			
		||||
    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                               training=self.training)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -185,7 +185,7 @@ def attention_forward(
 | 
			
		|||
        attn_weights.div_(math.sqrt(self.head_dim))
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights.add_(attention_mask)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                                   training=self.training)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -219,7 +219,7 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, False)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, v)
 | 
			
		||||
    attn_output = attn_output.transpose(0, 1)
 | 
			
		||||
    attn_output = attn_output.reshape(seq_length, -1)
 | 
			
		||||
| 
						 | 
				
			
			@ -298,7 +298,7 @@ def qwen2_vl_attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -114,7 +114,7 @@ class AttnProcessor2_0:
 | 
			
		|||
            attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attn_weights = attn_weights + attention_mask
 | 
			
		||||
            attn_weights = attention_softmax(attn_weights, False)
 | 
			
		||||
            attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
            hidden_states = torch.matmul(attn_weights, value)
 | 
			
		||||
        # IPEX-LLM changes end
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -175,7 +175,7 @@ def stablelm_attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_weights = self.attention_dropout(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -134,7 +134,7 @@ def attention_forward(
 | 
			
		|||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                                   training=self.training)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -240,7 +240,7 @@ def yuan_attention_forward(
 | 
			
		|||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights, self.training)
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue