refactor mllama, gpt2 and internvl (#12602)

Yishuo Wang 2024-12-24 14:18:31 +08:00 committed by GitHub
parent 7aaf02f602
commit ad2dc965c5
3 changed files with 21 additions and 67 deletions

ipex_llm/transformers/models/gpt2.py

@@ -15,6 +15,7 @@
 #
 import torch
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
@@ -44,10 +45,11 @@ def gpt2_attention_attn(
     else:
         attention_mask = attention_mask.expand(-1, -1, seq_len, seq_len)

-    import xe_addons
     attn_weights = None
-    attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                           value.contiguous(), attention_mask)
+    attn_output = scaled_dot_product_attention(
+        query, key.contiguous(), value.contiguous(),
+        attention_mask, False
+    )

     return attn_output, attn_weights
     # ipex-llm changes end

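Note: the hunks in this commit replace direct xe_addons.sdp_non_causal calls with a shared helper imported from ipex_llm.transformers.models.common. Its implementation is not part of this diff; as a rough sketch only, assuming the positional signature (query, key, value, mask, is_causal, scale=None) used by the call sites above, a plain-PyTorch stand-in could look like this (the real helper presumably dispatches to fused XPU kernels such as xe_addons when they apply):

from typing import Optional

import torch


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
                                 value: torch.Tensor,
                                 mask: Optional[torch.Tensor],
                                 is_causal: bool,
                                 scale: Optional[float] = None) -> torch.Tensor:
    # Hypothetical sketch, not the actual ipex-llm implementation: fall back to
    # PyTorch's fused SDPA with the same argument order as the call sites above.
    # The scale keyword requires torch >= 2.1.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=None if is_causal else mask,  # causal calls ignore the mask here
        is_causal=is_causal,
        scale=scale,
    )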
ipex_llm/transformers/models/internvl.py

@@ -26,6 +26,7 @@
 import torch
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
@@ -177,8 +178,10 @@ def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
     k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
-        x = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+        x = scaled_dot_product_attention(
+            q, k.contiguous(), v.contiguous(),
+            None, False, self.scale
+        )
     else:
         attn = ((q * self.scale) @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)

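The InternVL call passes self.scale as an extra argument because the ViT attention applies q/k normalization and its own softmax scale, so the fused path has to use the same scale as the matmul branch kept in the else above to stay numerically equivalent. A small self-contained check of that equivalence, using PyTorch's built-in SDPA as a stand-in for the helper and made-up shapes:

import torch
import torch.nn.functional as F

# Made-up ViT-like shapes: batch 1, 8 heads, 197 tokens, head_dim 64.
q = torch.randn(1, 8, 197, 64)
k = torch.randn(1, 8, 197, 64)
v = torch.randn(1, 8, 197, 64)
scale = 64 ** -0.5  # plays the role of self.scale in intern_attention_forward

# Reference path, mirroring the else-branch in the diff above.
attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
ref = attn @ v

# Fused path with the same explicit scale (scale keyword requires torch >= 2.1).
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                       is_causal=False, scale=scale)

print((ref - fused).abs().max())  # expected to be around 1e-6 in float32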
ipex_llm/transformers/models/mllama.py

@@ -32,7 +32,6 @@
 # limitations under the License.

-import math
 import torch
 from typing import Optional, Tuple, Union
@@ -40,11 +39,10 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.mllama.modeling_mllama import MllamaVisionAttention
 from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
-from transformers.models.mllama.modeling_mllama import repeat_kv
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_sdp_non_causal
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.utils import invalidInputError
@@ -67,27 +65,11 @@ def mllama_vision_attention_forward(
     qkv = qkv.transpose(1, 2)
     query, key, value = qkv.chunk(3, dim=1)

-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-    else:
-        causal_mask = None
-
-    if use_sdp_non_causal(self.head_dim, query.device, query.dtype):
-        import xe_addons
-        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                               value.contiguous(), causal_mask)
-        attn_weights = None
-    else:
-        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        from ipex_llm.transformers.models.common import attention_softmax
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value)
+    attn_output = scaled_dot_product_attention(
+        query, key.contiguous(), value.contiguous(),
+        attention_mask, False
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -278,44 +260,11 @@ def mllama_cross_attention_forward(
             past_key_value.value_cache[self.layer_idx],
         )

-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-    else:
-        causal_mask = None
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
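
The largest deletion is in mllama_cross_attention_forward, where the per-call-site dispatch (xe_addons.sdp / sdp_causal, their fp8 variants after restore_fp8_kv_cache, and the repeat_kv plus matmul fallback) collapses into one helper call whose causal flag is simply q_len == key_states.size(2). As a reference-only sketch of the semantics the removed fallback branch implemented, which the shared helper is presumably expected to reproduce (hypothetical code, not the ipex-llm implementation; GQA head repetition and the additive mask follow the removed lines above):

from typing import Optional

import torch


def reference_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                   mask: Optional[torch.Tensor], is_causal: bool,
                   scale: Optional[float] = None) -> torch.Tensor:
    # Hypothetical reference path mirroring the removed math branch.
    scale = scale if scale is not None else query.size(-1) ** -0.5

    # repeat k/v heads if n_kv_heads < n_heads (stands in for repeat_kv above)
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
    if mask is not None:
        attn_weights = attn_weights + mask
    elif is_causal:
        # build a causal mask that also covers the cached-prefix case (kv_len > q_len)
        q_len, kv_len = query.size(2), key.size(2)
        causal = torch.ones(q_len, kv_len, dtype=torch.bool,
                            device=query.device).tril(kv_len - q_len)
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))

    # upcast attention to fp32, as in the removed branch
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
    return torch.matmul(attn_weights.to(value.dtype), value)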