refactor phi-2 to reduce old fuse rope usage (#12214)

Yishuo Wang authored 2024-10-16 17:08:14 +08:00, committed by GitHub
parent bb247e991b
commit 9104a168f6

@@ -35,7 +35,7 @@ import math
 import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -45,16 +45,6 @@ from transformers.models.phi.modeling_phi import repeat_kv, apply_rotary_pos_emb
 from transformers.models.phi.modeling_phi import PhiModel
 
 
-def should_use_fuse_rope(self, hidden_states, position_ids):
-    use_fuse_rope = (
-        hidden_states.device.type == "xpu" and
-        hidden_states.numel() == hidden_states.size(-1) and
-        not (self.training and hidden_states.requires_grad) and
-        position_ids is not None
-    )
-    return use_fuse_rope
-
-
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "PhiAttention")
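
The hunk above drops phi's private copy of the fuse-rope gate in favour of the shared `should_use_fuse_rope` imported from `ipex_llm.transformers.models.utils`, which the new call site below invokes as `should_use_fuse_rope(hidden_states, position_ids, self.training)`. The shared helper's body is not part of this diff; as a rough guide, here is a standalone sketch of the check the deleted local version performed (the name `should_use_fuse_rope_sketch` and the example shapes are illustrative only):

# A minimal sketch of the gating logic the removed local helper implemented.
# The shared should_use_fuse_rope from ipex_llm.transformers.models.utils is
# assumed to apply an equivalent check with a (hidden_states, position_ids,
# training) signature; this standalone version exists only for illustration.
import torch


def should_use_fuse_rope_sketch(hidden_states: torch.Tensor,
                                position_ids: torch.Tensor,
                                training: bool) -> bool:
    return (
        hidden_states.device.type == "xpu"                    # fused kernel is XPU-only
        and hidden_states.numel() == hidden_states.size(-1)   # a single decode token
        and not (training and hidden_states.requires_grad)    # inference-only fast path
        and position_ids is not None
    )


if __name__ == "__main__":
    h = torch.randn(1, 1, 2560)     # one token of hidden states (decode step)
    pos = torch.tensor([[16]])
    print(should_use_fuse_rope_sketch(h, pos, training=False))  # False on CPU

The `numel() == size(-1)` test is what restricts the fused path to single-token, batch-1 decode inputs.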
@@ -82,8 +72,16 @@ def attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    # IPEX-LLM OPT: fuse rope
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        rot_dim = self.rotary_emb.dim
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
     # Partial rotary embedding
     query_rot, query_pass = (
         query_states[..., : self.rotary_emb.dim],
@@ -93,15 +91,6 @@ def attention_forward(
         key_states[..., : self.rotary_emb.dim],
         key_states[..., self.rotary_emb.dim:],
     )
-
-    # IPEX-LLM OPT: fuse rope
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, key_rot, sin,
-                                                                 cos, "stablelm", position_ids)
-    else:
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
 
     # [batch_size, seq_length, num_heads, head_dim]
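
Taken together, the two hunks above move the rope decision ahead of the rotary computation: on XPU the in-place `xe_addons.rotary_half_inplaced` kernel rotates only the first `self.rotary_emb.dim` channels of the query/key states, while the fallback branch keeps phi's partial-rotary route of splitting at `rotary_emb.dim`, applying cos/sin, and concatenating the untouched tail back on. A plain-PyTorch sketch of that fallback structure follows; `partial_rope`, `rotate_half`, the cos/sin construction, and the phi-2-like shapes are simplified stand-ins, not the exact transformers/ipex-llm implementations:

# Illustration of the non-fused fallback: partial rotary embedding, where only the
# first rot_dim channels of each head are rotated and the rest pass through.
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def partial_rope(q, k, position_ids, rot_dim, base=10000.0):
    # q, k: [batch, num_heads, seq_len, head_dim]
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    freqs = position_ids.float()[..., None] * inv_freq        # [batch, seq_len, rot_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                   # [batch, seq_len, rot_dim]
    cos, sin = emb.cos()[:, None], emb.sin()[:, None]         # broadcast over heads

    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)


q = torch.randn(1, 32, 4, 80)        # hypothetical phi-2-like shapes: 32 heads, head_dim 80
k = torch.randn(1, 32, 4, 80)
pos = torch.arange(4)[None]
q_out, k_out = partial_rope(q, k, pos, rot_dim=32)   # rotate 32 of 80 channels per head
print(q_out.shape, torch.allclose(q_out[..., 32:], q[..., 32:]))  # pass-through part unchanged

The final `allclose` check confirms the channels beyond `rot_dim` are left untouched, which is the same property the fused kernel preserves by rotating only the `[..., :rot_dim]` slices in place.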
@@ -125,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
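
The last hunk casts the softmax output back to the activation dtype: phi computes attention scores in fp32 (per the "upcast attention to fp32" comment), and the explicit `.to(hidden_states.dtype)` brings the probabilities back to the, typically half-precision, hidden-state dtype before dropout and the value matmul. Below is a minimal stand-in for that upcast-softmax-downcast pattern; `softmax_in_fp32` is a hypothetical helper and does not reproduce `attention_softmax`'s fused XPU path:

# Sketch of the upcast-softmax-downcast pattern the last hunk enforces: scores are
# softmaxed in float32 for numerical stability, then cast back to the activation
# dtype so the subsequent dropout and value matmul stay in half precision.
import torch


def softmax_in_fp32(attn_weights: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
    return probs.to(out_dtype)


scores = torch.randn(1, 32, 4, 4, dtype=torch.float32)    # fp32 attention scores
probs = softmax_in_fp32(scores, out_dtype=torch.float16)  # back to the hidden-state dtype
print(probs.dtype, probs.sum(-1))                         # float16, rows sum to ~1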