optimize llama3.2 vision attention again (#12204)

Author: Yishuo Wang, 2024-10-15 16:08:20 +08:00 (committed by GitHub)
parent 9b81236a2e
commit f6611f9d3a

@@ -36,6 +36,7 @@
 import math
 import torch
 from typing import Optional
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 
 
 def mllama_vision_attention_forward(
@@ -55,15 +56,25 @@ def mllama_vision_attention_forward(
     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
-
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+    else:
+        causal_mask = None
 
-    # upcast attention to fp32
-    from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, self.training)
+    if use_sdp_non_causal(self.head_dim, query.device, query.dtype):
+        import xe_addons
+        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                               value.contiguous(), causal_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    attn_output = torch.matmul(attn_weights, value)
+        if attention_mask is not None:
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        from ipex_llm.transformers.models.common import attention_softmax
+        attn_weights = attention_softmax(attn_weights, False)
+
+        attn_output = torch.matmul(attn_weights, value)
 
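
For context on what the new fast path computes: when use_sdp_non_causal accepts the head size, device, and dtype, the whole softmax(QK^T / sqrt(head_dim) + mask) @ V product is handed to the fused xe_addons.sdp_non_causal XPU kernel instead of being built up step by step; the else branch keeps the original matmul-softmax-matmul path for unsupported configurations. The sketch below is a minimal CPU-side reference of that computation, assuming the kernel matches standard non-causal scaled dot-product attention with an additive mask; the helper name reference_sdp_non_causal and the toy tensor shapes are illustrative only and are not part of this commit.

# Minimal reference sketch (assumption: xe_addons.sdp_non_causal behaves like
# standard non-causal scaled dot-product attention with an additive mask).
# The helper name and the tensor shapes below are illustrative, not from the commit.
import math

import torch
import torch.nn.functional as F


def reference_sdp_non_causal(query, key, value, attn_mask=None):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.shape[-1])
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask   # additive mask, as in the diff above
    # upcast the softmax to fp32, mirroring the fallback path
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)


if __name__ == "__main__":
    # Small stand-in shapes; the real vision encoder uses much longer patch sequences.
    batch, num_heads, seq_len, head_dim = 1, 16, 32, 80
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out_manual = reference_sdp_non_causal(q, k, v)
    # PyTorch's fused SDPA with is_causal=False should agree with the manual path.
    out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
    print(torch.allclose(out_manual, out_fused, atol=1e-4))

Note that the kernel path also sets attn_weights = None, since a fused SDP kernel does not return the attention-weight matrix; callers that request output_attentions still get weights only through the fallback branch.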