refactor attention_softmax (#12295)

parent 2b2cb9c693
commit 540eaeb12c

14 changed files with 18 additions and 18 deletions

@@ -109,7 +109,7 @@ def aquila_attention_forward(
    )

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):

@@ -79,8 +79,8 @@ def mlp_gelu_forward(self, x: torch.Tensor):
    return fuse_mlp_base(self, GELU, x)


-def attention_softmax(attn_weights: torch.Tensor, training: bool):
-    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+def attention_softmax(attn_weights: torch.Tensor):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
    else:
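
For reference, a minimal sketch of the refactored helper as a whole. Only the signature change and the xpu fast path come from the hunk above; the body of the `else:` branch is not shown in this diff, so the fp32-upcast fallback below is an assumption based on the "# upcast attention to fp32" comments at the call sites.

import torch

def attention_softmax(attn_weights: torch.Tensor):
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        # in-place softmax kernel on the XPU tensor; no extra allocation
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    else:
        # assumed fallback: upcast to fp32 for numerical stability, then cast back
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)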

@@ -220,7 +220,7 @@ def gemma_attention_forward(
    attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                         training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

@@ -125,7 +125,7 @@ def internlm_attention_forward(
    attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)

@@ -237,7 +237,7 @@ def llama_attention_forward(
    attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()

@@ -56,7 +56,7 @@ def siglip_attention_forward(
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)

    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

@@ -161,7 +161,7 @@ def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
    query_states, key_states, value_states = qkv.chunk(3, dim=1)

    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_weights = self.attn_drop(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

@@ -85,7 +85,7 @@ def mllama_vision_attention_forward(

    # upcast attention to fp32
    from ipex_llm.transformers.models.common import attention_softmax
-   attn_weights = attention_softmax(attn_weights, False)
+   attn_weights = attention_softmax(attn_weights)

    attn_output = torch.matmul(attn_weights, value)

@@ -311,7 +311,7 @@ def mllama_cross_attention_forward(
    attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()

@@ -114,7 +114,7 @@ def attention_forward(
    attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
+   attn_weights = attention_softmax(attn_weights).to(hidden_states.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                               training=self.training)

@@ -185,7 +185,7 @@ def attention_forward(
    attn_weights.div_(math.sqrt(self.head_dim))
    if attention_mask is not None:
        attn_weights.add_(attention_mask)
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)

    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                               training=self.training)

@@ -219,7 +219,7 @@ def qwen2_vision_attention_forward(
    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
-   attn_weights = attention_softmax(attn_weights, False)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, v)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)

@@ -298,7 +298,7 @@ def qwen2_vl_attention_forward(
    attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()

@@ -114,7 +114,7 @@ class AttnProcessor2_0:
        attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
-       attn_weights = attention_softmax(attn_weights, False)
+       attn_weights = attention_softmax(attn_weights)
        hidden_states = torch.matmul(attn_weights, value)
        # IPEX-LLM changes end

@@ -175,7 +175,7 @@ def stablelm_attention_forward(
    attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_weights = self.attention_dropout(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

@@ -134,7 +134,7 @@ def attention_forward(
    attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                               training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

@@ -240,7 +240,7 @@ def yuan_attention_forward(
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # upcast attention to fp32
-   attn_weights = attention_softmax(attn_weights, self.training)
+   attn_weights = attention_softmax(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)
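
To illustrate the call-site pattern shared by the hunks above, here is a small self-contained sketch. It assumes ipex-llm is installed; the shapes and the CPU tensor are illustrative only, and the `xe_addons` fast path inside the helper only triggers for contiguous tensors on an xpu device.

import torch
from ipex_llm.transformers.models.common import attention_softmax

bsz, num_heads, q_len, kv_len = 1, 8, 16, 16
attn_weights = torch.randn(bsz, num_heads, q_len, kv_len)  # raw attention scores
attn_weights = attention_softmax(attn_weights)  # training flag is no longer passed
print(attn_weights.sum(dim=-1))  # each row sums to ~1.0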