use new rope in phi3 (#11047)
parent 00d4410746
commit 8cae897643
2 changed files with 36 additions and 91 deletions
@@ -706,6 +706,8 @@ def _optimize_pre(model):
         from ipex_llm.transformers.models.phi import merge_qkv
         model.apply(merge_qkv)
     if model.config.model_type == "phi3":
+        from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
+        model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.phi3 import split_mlp
         model.apply(split_mlp)
     if model.config.model_type == "qwen":
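Note: `model.apply(pre_compute_inv_freq)` relies on `torch.nn.Module.apply`, which calls the given function on every submodule, so the hook only has to dispatch on each module's class name. A minimal sketch of that pattern (the `DummyRope` module and `pre_compute_demo` helper are illustrative stand-ins, not part of the patch):

import torch


class DummyRope(torch.nn.Module):
    """Stand-in for a rotary-embedding module that starts without a cached inv_freq."""
    def __init__(self, dim=32, base=10000):
        super().__init__()
        self.dim, self.base, self.inv_freq = dim, base, None


def pre_compute_demo(module: torch.nn.Module):
    # Mirrors the shape of pre_compute_inv_freq: dispatch on the class name,
    # then attach a precomputed inverse-frequency tensor to the module.
    if module.__class__.__name__ == "DummyRope":
        module.inv_freq = 1.0 / (
            module.base ** (torch.arange(0, module.dim, 2).float() / module.dim)
        )


model = torch.nn.Sequential(DummyRope(), torch.nn.Linear(8, 8))
model.apply(pre_compute_demo)          # nn.Module.apply visits every submodule once
assert model[0].inv_freq is not None   # rope frequencies are now cached on the module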
@@ -1525,8 +1527,6 @@ def _optimize_post(model, lightweight_bmm=False):
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward
-        convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward)
         from ipex_llm.transformers.models.phi3 import attention_forward
         convert_forward(model, module.Phi3Attention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
@@ -1534,13 +1534,8 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi3 import model_forward_wrapper
         model_forward = model_forward_wrapper(module.Phi3Model.forward)
         convert_forward(model, module.Phi3Model, model_forward)
-        from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
-        replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
         from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
-        convert_forward(
-            model,
-            module.Phi3RMSNorm,
-            phi3_rms_norm_forward)
+        convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
     elif model.config.model_type == 'yuan':
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
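Note: `convert_forward` is ipex-llm's helper for swapping an optimized forward onto every instance of a given module class. The real implementation lives elsewhere in the library and may differ; a minimal sketch of the kind of rebinding it performs, under that assumption (`convert_forward_sketch` is a hypothetical name):

import types


def convert_forward_sketch(model, target_cls, new_forward):
    # Walk the model and rebind `forward` on every instance of target_cls,
    # so the optimized function runs in place of the original method.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)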
@@ -38,7 +38,6 @@ from torch import nn

 from ipex_llm.transformers.models.utils import (
     rotate_half, should_use_fuse_rope,
-    apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -58,46 +57,29 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
-    if self.inv_freq is None:
-        short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-        inv_freq_shape = torch.arange(0, self.dim, 2,
-                                      dtype=torch.int64, device=x.device).float() / self.dim
-        self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape)
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "Phi3RotaryEmbedding":
+        module.inv_freq = 1.0 / (
+            module.base **
+            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
+        )
+    elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding":
+        inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
+        short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
+        module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)

-        long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
-        self.register_buffer("long_inv_freq", None, persistent=False)
-        self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape)
+        long_ext_factors = torch.tensor(module.long_factor, dtype=torch.float32)
+        module.register_buffer("long_inv_freq", None, persistent=False)
+        module.long_inv_freq = 1.0 / (long_ext_factors * module.base ** inv_freq_shape)

-    seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
-    if seq_len > self.original_max_position_embeddings:
-        inv_freq = self.long_inv_freq
-    else:
-        inv_freq = self.inv_freq
-
-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
-
-    # Force float32 since bfloat16 loses precision on long contexts
-    # See https://github.com/huggingface/transformers/pull/29285
-    device_type = x.device.type
-    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-    with torch.autocast(device_type=device_type, enabled=False):
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-    scale = self.max_position_embeddings / self.original_max_position_embeddings
-    if scale <= 1.0:
-        scaling_factor = 1.0
+        if module.max_position_embeddings <= module.original_max_position_embeddings:
+            module.scaling_factor = 1.0
         else:
-        scaling_factor = math.sqrt(
-            1 + math.log(scale) / math.log(self.original_max_position_embeddings)
+            scale = module.max_position_embeddings / module.original_max_position_embeddings
+            module.scaling_factor = math.sqrt(
+                1 + math.log(scale) / math.log(module.original_max_position_embeddings)
             )

-    cos = emb.cos() * scaling_factor
-    sin = emb.sin() * scaling_factor
-    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 def attention_forward(
     self,
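Note: the removed `su_scaled_rope_forward` shows what the precomputed `inv_freq` feeds into: an outer product with the position ids, duplicated along the last dim, then passed through cos/sin (and, for the su-scaled 128k variant, multiplied by `scaling_factor`). A small illustrative sketch of that cos/sin construction (`rope_cos_sin` is a made-up name, not from the patch):

import torch


def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor):
    # inv_freq: [dim // 2], position_ids: [batch, seq_len]
    # freqs[b, t, i] = position_ids[b, t] * inv_freq[i]
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :].float()
    emb = torch.cat((freqs, freqs), dim=-1)        # [batch, seq_len, dim]
    return emb.cos(), emb.sin()


inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
cos, sin = rope_cos_sin(inv_freq, torch.arange(8)[None, :])
print(cos.shape)  # torch.Size([1, 8, 64])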
@@ -124,12 +106,24 @@ def attention_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
     # IPEX-LLM OPT: fuse rope
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
-        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                       sin, cos, "phi3")
+        import linear_q4_0
+        if self.rotary_emb.__class__.__name__ == "Phi3RotaryEmbedding":    # 4k
+            linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                             query_states, key_states)
+        else:    # 128k
+            if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
+                linear_q4_0.rotary_half_inplaced(self.rotary_emb.long_inv_freq, position_ids,
+                                                 query_states, key_states)
+            else:
+                linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                                 query_states, key_states)
+            # todo: fuse scaling_factor
+            query_states *= self.rotary_emb.scaling_factor
+            key_states *= self.rotary_emb.scaling_factor
     else:
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                         cos, sin, position_ids)
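Note: `linear_q4_0.rotary_half_inplaced` is the fused native kernel this patch switches to; the unfused math it stands in for is the usual rotate-half application of RoPE. An illustrative PyTorch reference, not the kernel's actual signature or implementation:

import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(q, k, inv_freq, position_ids):
    # q, k: [batch, heads, seq_len, head_dim]; position_ids: [batch, seq_len]
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :].float()
    emb = torch.cat((freqs, freqs), dim=-1)               # [batch, seq_len, head_dim]
    cos, sin = emb.cos()[:, None], emb.sin()[:, None]     # broadcast over heads
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

The fused kernel performs this in place on the query/key tensors; for the 128k model the su-scaling factor is still applied as a separate multiply afterwards, matching the `# todo: fuse scaling_factor` comment in the hunk above.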
@@ -257,50 +251,6 @@ def model_forward_wrapper(origin_model_forward):
     return model_forward


-class Phi3RotaryEmbeddingCached(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
-                                                     dtype=torch.int64,
-                                                     device=device).float() / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        # self.gen_seq_len = None
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings,
-            device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        position_ids_expanded = torch.arange(self.max_seq_len_cached,
-                                             device=device,
-                                             dtype=self.inv_freq.dtype).reshape(1, 1, -1)
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
-        with torch.autocast(device_type=device.type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            # Different from paper,
-            # but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
-            self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
-        )
-
-
 def phi3_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
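Note: `phi3_rms_norm_forward` routes XPU tensors to a fused `linear_q4_0` kernel; the math being fused is the standard RMSNorm used by the Hugging Face Phi-3 modeling code. A rough illustrative reference (the `rms_norm_reference` name is hypothetical):

import torch


def rms_norm_reference(hidden_states, weight, eps=1e-6):
    # Standard RMSNorm: scale by the inverse root-mean-square of the last dim.
    dtype = hidden_states.dtype
    hidden_states = hidden_states.float()
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (weight * hidden_states).to(dtype)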