refactor attention_softmax (#12295)

Yishuo Wang 2024-10-30 13:20:50 +08:00 committed by GitHub
parent 2b2cb9c693
commit 540eaeb12c
14 changed files with 18 additions and 18 deletions
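
The refactor is mechanical throughout: attention_softmax no longer takes a training flag, so the fused XPU softmax path is no longer gated on not training, and every call site drops the second argument (previously self.training or a hard-coded False). The pattern repeated across the hunks below is:

# before
attn_weights = attention_softmax(attn_weights, self.training)
# after
attn_weights = attention_softmax(attn_weights)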

@@ -109,7 +109,7 @@ def aquila_attention_forward(
         )
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):

@@ -79,8 +79,8 @@ def mlp_gelu_forward(self, x: torch.Tensor):
     return fuse_mlp_base(self, GELU, x)


-def attention_softmax(attn_weights: torch.Tensor, training: bool):
-    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+def attention_softmax(attn_weights: torch.Tensor):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
         import xe_addons
         xe_addons.attn_softmax_inplaced(attn_weights)
     else:
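
For context, here is a minimal sketch of how the refactored helper reads after this commit. The body of the else branch is not shown in the hunk above, so the fp32 softmax fallback below is an assumption (consistent with the "upcast attention to fp32" comments at the call sites), not the verbatim source:

import torch


def attention_softmax(attn_weights: torch.Tensor):
    # Fused in-place softmax on Intel XPU for contiguous tensors; after this
    # refactor the fast path is no longer skipped in training mode.
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    else:
        # Assumed fallback: upcast to fp32 for numerical stability, then cast back.
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)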

@@ -220,7 +220,7 @@ def gemma_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                          training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -125,7 +125,7 @@ def internlm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2)

@@ -237,7 +237,7 @@ def llama_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

@@ -56,7 +56,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -161,7 +161,7 @@ def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
     query_states, key_states, value_states = qkv.chunk(3, dim=1)

     attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attn_drop(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -85,7 +85,7 @@ def mllama_vision_attention_forward(
     # upcast attention to fp32
     from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, False)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value)
@@ -311,7 +311,7 @@ def mllama_cross_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

@@ -114,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
+    attn_weights = attention_softmax(attn_weights).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)

@@ -185,7 +185,7 @@ def attention_forward(
     attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
         attn_weights.add_(attention_mask)
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)

@@ -219,7 +219,7 @@ def qwen2_vision_attention_forward(
     attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
-    attn_weights = attention_softmax(attn_weights, False)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, v)
     attn_output = attn_output.transpose(0, 1)
     attn_output = attn_output.reshape(seq_length, -1)
@@ -298,7 +298,7 @@ def qwen2_vl_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

@@ -114,7 +114,7 @@ class AttnProcessor2_0:
         attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
-        attn_weights = attention_softmax(attn_weights, False)
+        attn_weights = attention_softmax(attn_weights)
         hidden_states = torch.matmul(attn_weights, value)
         # IPEX-LLM changes end

@@ -175,7 +175,7 @@ def stablelm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attention_dropout(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -134,7 +134,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -240,7 +240,7 @@ def yuan_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2)