refactor phi-2 to reduce old fuse rope usage (#12214)
parent bb247e991b
commit 9104a168f6
1 changed file with 20 additions and 31 deletions
@@ -35,7 +35,7 @@ import math
 import torch

 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError

@@ -45,16 +45,6 @@ from transformers.models.phi.modeling_phi import repeat_kv, apply_rotary_pos_emb
 from transformers.models.phi.modeling_phi import PhiModel


-def should_use_fuse_rope(self, hidden_states, position_ids):
-    use_fuse_rope = (
-        hidden_states.device.type == "xpu" and
-        hidden_states.numel() == hidden_states.size(-1) and
-        not (self.training and hidden_states.requires_grad) and
-        position_ids is not None
-    )
-    return use_fuse_rope
-
-
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "PhiAttention")

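The module-local helper deleted above is superseded by the shared `should_use_fuse_rope` imported from `ipex_llm.transformers.models.utils` in the first hunk. As a rough illustration of the kind of eligibility check involved, here is a standalone sketch that re-expresses the deleted helper's conditions with the new three-argument call shape; the body of the shared helper is not part of this diff, so the function below is a hypothetical stand-in, not the library's implementation.

```python
import torch

def should_use_fuse_rope_sketch(hidden_states: torch.Tensor,
                                position_ids: torch.Tensor,
                                training: bool) -> bool:
    # Hypothetical stand-in for ipex_llm.transformers.models.utils.should_use_fuse_rope;
    # the conditions mirror the module-local helper removed in the hunk above.
    return (
        hidden_states.device.type == "xpu"                    # fused kernel exists only on XPU
        and hidden_states.numel() == hidden_states.size(-1)   # single-token decode step
        and not (training and hidden_states.requires_grad)    # skip when gradients are needed
        and position_ids is not None                          # kernel needs explicit positions
    )

# Example: a 1-token decode step on CPU falls back to the non-fused path.
h = torch.randn(1, 1, 2560)
pos = torch.tensor([[42]])
print(should_use_fuse_rope_sketch(h, pos, training=False))  # False on CPU
```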
@@ -82,31 +72,30 @@ def attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., : self.rotary_emb.dim],
-        query_states[..., self.rotary_emb.dim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., : self.rotary_emb.dim],
-        key_states[..., self.rotary_emb.dim:],
-    )

     # IPEX-LLM OPT: fuse rope
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, key_rot, sin,
-                                                                 cos, "stablelm", position_ids)
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        rot_dim = self.rotary_emb.dim
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim:],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim:],
+        )
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)

     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
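Phi-2 uses a partial rotary embedding: only the first `self.rotary_emb.dim` channels of each head are rotated and the remainder passes through untouched. On the fused path, `xe_addons.rotary_half_inplaced` rotates that leading slice in place, which is why the refactored code no longer needs the explicit split and concat; the fallback branch above still does it by hand. The snippet below is a self-contained sketch of that split, rotate, concat bookkeeping with toy tensor sizes; the shapes and the 10000 rope base are illustrative assumptions, not values read from the Phi-2 config.

```python
import torch

def rotate_half(x):
    # Standard "rotate half": swap the two halves of the last dim and negate one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Toy shapes; rot_dim plays the role of self.rotary_emb.dim above.
batch, heads, seq, head_dim, rot_dim = 1, 4, 8, 64, 32

query_states = torch.randn(batch, heads, seq, head_dim)
key_states = torch.randn(batch, heads, seq, head_dim)

# cos/sin tables for the rotated slice only.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # [seq, rot_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                    # [seq, rot_dim]
cos, sin = emb.cos(), emb.sin()

# Split into the slice that receives the rotation and the pass-through slice,
# mirroring the non-fused branch in the diff.
query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:]
key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:]

query_rot = query_rot * cos + rotate_half(query_rot) * sin
key_rot = key_rot * cos + rotate_half(key_rot) * sin

# Re-assemble the full head_dim; an in-place kernel over the first rot_dim
# channels makes this split/concat unnecessary on the fused path.
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
assert query_states.shape == (batch, heads, seq, head_dim)
```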
@@ -125,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
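The only change here appends `.to(hidden_states.dtype)` to the softmax result. The surrounding "upcast attention to fp32" comment indicates the softmax runs in float32 for stability, so casting back keeps the subsequent dropout and the matmul with the half-precision value states in the activations' dtype. Below is a minimal sketch of that pattern with plain PyTorch ops; it does not reproduce the fused `attention_softmax` itself, and the shapes are illustrative assumptions.

```python
import torch

# Half-precision attention scores, as produced by an fp16/bf16 forward pass.
attn_weights = torch.randn(1, 32, 8, 8, dtype=torch.float16)
hidden_dtype = torch.float16  # stands in for hidden_states.dtype

# Upcast to fp32 inside the softmax for numerical stability ...
probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

# ... then cast back, as the added `.to(hidden_states.dtype)` does, so dropout
# and the following matmul with the value states stay in half precision.
probs = probs.to(hidden_dtype)
assert probs.dtype == torch.float16
```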