diff --git a/python/llm/src/ipex_llm/transformers/models/phi.py b/python/llm/src/ipex_llm/transformers/models/phi.py
index 43365623..ca68700a 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi.py
@@ -35,7 +35,7 @@
 import math
 import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
@@ -45,16 +45,6 @@
 from transformers.models.phi.modeling_phi import repeat_kv, apply_rotary_pos_emb
 from transformers.models.phi.modeling_phi import PhiModel
 
-def should_use_fuse_rope(self, hidden_states, position_ids):
-    use_fuse_rope = (
-        hidden_states.device.type == "xpu" and
-        hidden_states.numel() == hidden_states.size(-1) and
-        not (self.training and hidden_states.requires_grad) and
-        position_ids is not None
-    )
-    return use_fuse_rope
-
-
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "PhiAttention")
 
@@ -82,31 +72,30 @@ def attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., : self.rotary_emb.dim],
-        query_states[..., self.rotary_emb.dim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., : self.rotary_emb.dim],
-        key_states[..., self.rotary_emb.dim:],
-    )
 
     # IPEX-LLM OPT: fuse rope
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, key_rot, sin,
-                                                                 cos, "stablelm", position_ids)
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        rot_dim = self.rotary_emb.dim
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim:],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim:],
+        )
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
 
     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
@@ -125,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
 
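
Note on the refactor: attention_forward now calls the shared should_use_fuse_rope helper from ipex_llm.transformers.models.utils with the signature should_use_fuse_rope(hidden_states, position_ids, self.training), replacing the per-model copy deleted above. The sketch below is only an illustration of what that shared helper plausibly checks, inferred from the removed local version; the actual implementation in utils.py is not shown in this diff and may differ (for example, the removed copy additionally required a single decode token, which a shared helper might relax so prefill can also take the fused path).

# Illustrative sketch only -- inferred from the local helper removed in this diff,
# not copied from ipex_llm/transformers/models/utils.py.
import torch


def should_use_fuse_rope(hidden_states: torch.Tensor,
                         position_ids: torch.Tensor,
                         training: bool) -> bool:
    # The fused kernel (xe_addons.rotary_half_inplaced) is an inference-only XPU path,
    # so require an XPU tensor, no training/grad, and explicit position_ids.
    # The removed local copy also required hidden_states.numel() == hidden_states.size(-1)
    # (a single decode token); the shared helper may or may not keep that check.
    return (hidden_states.device.type == "xpu"
            and not (training and hidden_states.requires_grad)
            and position_ids is not None)

On the fused path, xe_addons.rotary_half_inplaced rotates only the first rot_dim = self.rotary_emb.dim channels of query_states and key_states in place, so the query_rot/query_pass split and the torch.cat recombination are needed only on the eager fallback; that is why those lines move inside the else: branch. The added .to(hidden_states.dtype) after attention_softmax casts the fp32-upcast weights back to the activation dtype before dropout and the value matmul.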