optimize llama3.2 vision attention again (#12204)
This commit is contained in:
parent 9b81236a2e
commit f6611f9d3a
1 changed file with 18 additions and 7 deletions
@@ -36,6 +36,7 @@ import math
 import torch
 
 from typing import Optional
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 
 
 def mllama_vision_attention_forward(
@@ -55,17 +56,27 @@ def mllama_vision_attention_forward(
     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
-
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+    else:
+        causal_mask = None
 
-    # upcast attention to fp32
-    from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, self.training)
+    if use_sdp_non_causal(self.head_dim, query.device, query.dtype):
+        import xe_addons
+        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                               value.contiguous(), causal_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    attn_output = torch.matmul(attn_weights, value)
+        if attention_mask is not None:
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        from ipex_llm.transformers.models.common import attention_softmax
+        attn_weights = attention_softmax(attn_weights, False)
+
+        attn_output = torch.matmul(attn_weights, value)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
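
For readers without the Intel GPU stack, the sketch below reproduces in plain PyTorch the two-way dispatch this patch introduces: a fused non-causal SDP call on the fast path, and the original matmul / fp32-softmax / matmul sequence as the fallback. This is a minimal sketch, not the patch itself: the function name vision_attention_reference, the use_fused_sdp flag, and the use of torch.nn.functional.scaled_dot_product_attention as a stand-in for xe_addons.sdp_non_causal are illustrative assumptions.

import math
from typing import Optional

import torch
import torch.nn.functional as F


def vision_attention_reference(query: torch.Tensor,
                               key: torch.Tensor,
                               value: torch.Tensor,
                               attention_mask: Optional[torch.Tensor] = None,
                               use_fused_sdp: bool = True):
    # query/key/value: [batch, num_heads, seq_len, head_dim], as produced by the
    # view(...).transpose(1, 2) calls in the patched function.
    batch_size, _, q_seq_len, head_dim = query.shape

    # Slice the mask to the key length, as the patch does, or keep None.
    causal_mask = (attention_mask[:, :, :, : key.shape[-2]]
                   if attention_mask is not None else None)

    if use_fused_sdp:
        # Fast path: a single fused SDP kernel (stand-in for xe_addons.sdp_non_causal);
        # the attention weights are never returned.
        attn_output = F.scaled_dot_product_attention(query, key.contiguous(),
                                                     value.contiguous(),
                                                     attn_mask=causal_mask,
                                                     is_causal=False)
        attn_weights = None
    else:
        # Fallback path: explicit matmul -> add mask -> fp32 softmax -> matmul,
        # mirroring the else branch of the patched function.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
        if causal_mask is not None:
            attn_weights = attn_weights + causal_mask
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_output = torch.matmul(attn_weights, value)

    # [batch, heads, q_len, head_dim] -> [batch, q_len, heads * head_dim]
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
    return attn_output, attn_weights

As in the patch, attn_weights stays None on the fused path, so the full attention matrix is not handed back to the caller; avoiding that materialization is typically the point of routing vision attention through a fused SDP kernel.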