refactor phi-2 to reduce old fuse rope usage (#12214)

Author: Yishuo Wang, 2024-10-16 17:08:14 +08:00 (committed via GitHub)
Parent: bb247e991b
Commit: 9104a168f6


@@ -35,7 +35,7 @@ import math
 import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -45,16 +45,6 @@ from transformers.models.phi.modeling_phi import repeat_kv, apply_rotary_pos_emb
 from transformers.models.phi.modeling_phi import PhiModel
 
 
-def should_use_fuse_rope(self, hidden_states, position_ids):
-    use_fuse_rope = (
-        hidden_states.device.type == "xpu" and
-        hidden_states.numel() == hidden_states.size(-1) and
-        not (self.training and hidden_states.requires_grad) and
-        position_ids is not None
-    )
-    return use_fuse_rope
-
-
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "PhiAttention")
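
The module-local should_use_fuse_rope helper is deleted in favor of the shared one imported from ipex_llm.transformers.models.utils, which the attention code below calls as should_use_fuse_rope(hidden_states, position_ids, self.training). As a rough reference, here is a minimal sketch of the condition the deleted helper encoded; the shared utility is assumed to perform an equivalent check (its implementation is not part of this diff, and the function name below is illustrative only):

from typing import Optional

import torch


def should_use_fuse_rope_sketch(hidden_states: torch.Tensor,
                                position_ids: Optional[torch.Tensor],
                                training: bool) -> bool:
    # Restates the deleted module-local helper, with `self.training` passed in
    # explicitly to mirror the argument order at the new call site.
    return (
        hidden_states.device.type == "xpu"                    # fused kernel targets XPU
        and hidden_states.numel() == hidden_states.size(-1)   # batch 1, single new token (decode)
        and not (training and hidden_states.requires_grad)    # inference path only
        and position_ids is not None
    )

When the check passes, the new code rotates the first rotary_emb.dim channels of query_states and key_states in place via xe_addons.rotary_half_inplaced, so the cos/sin cache and the split/concat bookkeeping of the fallback path are skipped entirely.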
@@ -82,31 +72,30 @@ def attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., : self.rotary_emb.dim],
-        query_states[..., self.rotary_emb.dim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., : self.rotary_emb.dim],
-        key_states[..., self.rotary_emb.dim:],
-    )
 
     # IPEX-LLM OPT: fuse rope
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, key_rot, sin,
-                                                                 cos, "stablelm", position_ids)
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        rot_dim = self.rotary_emb.dim
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim:],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim:],
+        )
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
 
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
 
     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
@@ -125,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
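
The last hunk only changes dtype handling around the softmax: per the existing comment the softmax runs in fp32, and the added .to(hidden_states.dtype) casts the result back to the activation dtype before dropout and the matmul with value_states. A minimal sketch of that pattern, written as a standalone function for illustration (not the ipex_llm attention_softmax helper, whose internals are not shown in this diff):

import torch


def softmax_fp32_then_cast(scores: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Softmax in fp32 for numerical stability, then cast back so the following
    # dropout and the matmul with value_states run in the model's compute dtype.
    return torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(out_dtype)


scores = torch.randn(1, 32, 1, 16, dtype=torch.float16)  # toy attention scores
probs = softmax_fp32_then_cast(scores, scores.dtype)      # fp16 again; rows sum to ~1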