use new rope in phi3 (#11047)
This commit is contained in:
parent
00d4410746
commit
8cae897643
2 changed files with 36 additions and 91 deletions
|
|
@ -706,6 +706,8 @@ def _optimize_pre(model):
|
||||||
from ipex_llm.transformers.models.phi import merge_qkv
|
from ipex_llm.transformers.models.phi import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
if model.config.model_type == "phi3":
|
if model.config.model_type == "phi3":
|
||||||
|
from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
|
||||||
|
model.apply(pre_compute_inv_freq)
|
||||||
from ipex_llm.transformers.models.phi3 import split_mlp
|
from ipex_llm.transformers.models.phi3 import split_mlp
|
||||||
model.apply(split_mlp)
|
model.apply(split_mlp)
|
||||||
if model.config.model_type == "qwen":
|
if model.config.model_type == "qwen":
|
||||||
|
|
@ -1525,8 +1527,6 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
# for phi-3
|
# for phi-3
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward
|
|
||||||
convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward)
|
|
||||||
from ipex_llm.transformers.models.phi3 import attention_forward
|
from ipex_llm.transformers.models.phi3 import attention_forward
|
||||||
convert_forward(model, module.Phi3Attention, attention_forward)
|
convert_forward(model, module.Phi3Attention, attention_forward)
|
||||||
from ipex_llm.transformers.models.phi3 import mlp_forward
|
from ipex_llm.transformers.models.phi3 import mlp_forward
|
||||||
|
|
@ -1534,13 +1534,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
||||||
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
||||||
convert_forward(model, module.Phi3Model, model_forward)
|
convert_forward(model, module.Phi3Model, model_forward)
|
||||||
from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
|
|
||||||
replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
|
|
||||||
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
||||||
convert_forward(
|
convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
|
||||||
model,
|
|
||||||
module.Phi3RMSNorm,
|
|
||||||
phi3_rms_norm_forward)
|
|
||||||
elif model.config.model_type == 'yuan':
|
elif model.config.model_type == 'yuan':
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,6 @@ from torch import nn
|
||||||
|
|
||||||
from ipex_llm.transformers.models.utils import (
|
from ipex_llm.transformers.models.utils import (
|
||||||
rotate_half, should_use_fuse_rope,
|
rotate_half, should_use_fuse_rope,
|
||||||
apply_rotary_pos_emb_cache_freq_xpu
|
|
||||||
)
|
)
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
||||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
||||||
|
|
@ -58,45 +57,28 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
|
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||||
if self.inv_freq is None:
|
if module.__class__.__name__ == "Phi3RotaryEmbedding":
|
||||||
short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
module.inv_freq = 1.0 / (
|
||||||
inv_freq_shape = torch.arange(0, self.dim, 2,
|
module.base **
|
||||||
dtype=torch.int64, device=x.device).float() / self.dim
|
(torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
|
||||||
self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape)
|
|
||||||
|
|
||||||
long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
|
||||||
self.register_buffer("long_inv_freq", None, persistent=False)
|
|
||||||
self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape)
|
|
||||||
|
|
||||||
seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
|
|
||||||
if seq_len > self.original_max_position_embeddings:
|
|
||||||
inv_freq = self.long_inv_freq
|
|
||||||
else:
|
|
||||||
inv_freq = self.inv_freq
|
|
||||||
|
|
||||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
|
||||||
|
|
||||||
# Force float32 since bfloat16 loses precision on long contexts
|
|
||||||
# See https://github.com/huggingface/transformers/pull/29285
|
|
||||||
device_type = x.device.type
|
|
||||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
|
||||||
with torch.autocast(device_type=device_type, enabled=False):
|
|
||||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
|
||||||
if scale <= 1.0:
|
|
||||||
scaling_factor = 1.0
|
|
||||||
else:
|
|
||||||
scaling_factor = math.sqrt(
|
|
||||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
|
||||||
)
|
)
|
||||||
|
elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding":
|
||||||
|
inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
|
||||||
|
short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
|
||||||
|
module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)
|
||||||
|
|
||||||
cos = emb.cos() * scaling_factor
|
long_ext_factors = torch.tensor(module.long_factor, dtype=torch.float32)
|
||||||
sin = emb.sin() * scaling_factor
|
module.register_buffer("long_inv_freq", None, persistent=False)
|
||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
module.long_inv_freq = 1.0 / (long_ext_factors * module.base ** inv_freq_shape)
|
||||||
|
|
||||||
|
if module.max_position_embeddings <= module.original_max_position_embeddings:
|
||||||
|
module.scaling_factor = 1.0
|
||||||
|
else:
|
||||||
|
scale = module.max_position_embeddings / module.original_max_position_embeddings
|
||||||
|
module.scaling_factor = math.sqrt(
|
||||||
|
1 + math.log(scale) / math.log(module.original_max_position_embeddings)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def attention_forward(
|
def attention_forward(
|
||||||
|
|
@ -124,12 +106,24 @@ def attention_forward(
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
|
||||||
# IPEX-LLM OPT: fuse rope
|
# IPEX-LLM OPT: fuse rope
|
||||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
import linear_q4_0
|
||||||
sin, cos, "phi3")
|
if self.rotary_emb.__class__.__name__ == "Phi3RotaryEmbedding": # 4k
|
||||||
|
linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
|
query_states, key_states)
|
||||||
|
else: # 128k
|
||||||
|
if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
|
||||||
|
linear_q4_0.rotary_half_inplaced(self.rotary_emb.long_inv_freq, position_ids,
|
||||||
|
query_states, key_states)
|
||||||
else:
|
else:
|
||||||
|
linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
|
query_states, key_states)
|
||||||
|
# todo: fuse scaling_factor
|
||||||
|
query_states *= self.rotary_emb.scaling_factor
|
||||||
|
key_states *= self.rotary_emb.scaling_factor
|
||||||
|
else:
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
cos, sin, position_ids)
|
cos, sin, position_ids)
|
||||||
|
|
||||||
|
|
@ -257,50 +251,6 @@ def model_forward_wrapper(origin_model_forward):
|
||||||
return model_forward
|
return model_forward
|
||||||
|
|
||||||
|
|
||||||
class Phi3RotaryEmbeddingCached(nn.Module):
|
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.base = base
|
|
||||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=device).float() / self.dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
# self.gen_seq_len = None
|
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
|
||||||
self._set_cos_sin_cache(
|
|
||||||
seq_len=max_position_embeddings,
|
|
||||||
device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
position_ids_expanded = torch.arange(self.max_seq_len_cached,
|
|
||||||
device=device,
|
|
||||||
dtype=self.inv_freq.dtype).reshape(1, 1, -1)
|
|
||||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
|
|
||||||
with torch.autocast(device_type=device.type, enabled=False):
|
|
||||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
|
||||||
# Different from paper,
|
|
||||||
# but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x, position_ids, seq_len=None):
|
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
|
||||||
self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def phi3_rms_norm_forward(self, hidden_states):
|
def phi3_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue