optimize new minicpm model (#12579)
This commit is contained in:
parent
4540424271
commit
80f2fdc37b
3 changed files with 15 additions and 62 deletions
|
|
@ -217,8 +217,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
|
|||
return mask
|
||||
|
||||
|
||||
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, mask: torch.Tensor = None,
|
||||
is_causal: bool = False, scale: float = None) -> torch.Tensor:
|
||||
bsz, n_heads, seq_length, head_dim = query.shape
|
||||
_, n_kv_heads, kv_length, _ = key.shape
|
||||
|
|
@ -268,7 +268,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
|
|||
attn_output = xe_addons.sdp(query, key, value, mask)
|
||||
else:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
|
||||
|
||||
|
|
@ -281,6 +281,8 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
|
|||
key = repeat_kv(key, n_heads // n_kv_heads)
|
||||
value = repeat_kv(value, n_heads // n_kv_heads)
|
||||
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, mask, is_causal=is_causal, scale=scale
|
||||
)
|
||||
attn_output = attn_output.to(dtype) # workaround ipex 2.1's bug
|
||||
return attn_output
|
||||
|
|
|
|||
|
|
@ -127,49 +127,12 @@ def minicpm_attention_forward(
|
|||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
attn_weights = None
|
||||
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||
import xe_addons
|
||||
# [CompressKV]
|
||||
if use_compresskv:
|
||||
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
|
||||
|
||||
if use_quantizekv:
|
||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||
import xe_addons
|
||||
if use_quantizekv:
|
||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||
value_states, attention_mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||
value_states, attention_mask)
|
||||
else:
|
||||
if use_quantizekv:
|
||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||
query_states.dtype)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.attention_dropout, training=self.training
|
||||
attn_output = scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from typing import Optional, List
|
|||
from torch.nn.functional import linear
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
|
||||
from ipex_llm.transformers.models.common import attention_softmax
|
||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||
from transformers import AutoProcessor, TextIteratorStreamer
|
||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||
|
||||
|
|
@ -73,21 +72,10 @@ def siglip_attention_forward(
|
|||
72, 80
|
||||
)
|
||||
|
||||
if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
|
||||
import xe_addons
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
attn_weights = None
|
||||
attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
|
||||
value_states.contiguous(), attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = attention_softmax(attn_weights)
|
||||
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights,
|
||||
p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
|
||||
attention_mask, False, math.sqrt(self.head_dim))
|
||||
|
||||
attn_output = attn_output[:, :, :, :self.head_dim]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue