optimize phi-3-mini-128 (#10959)

This commit is contained in:
Yishuo Wang 2024-05-08 16:33:17 +08:00 committed by GitHub
parent dfa3147278
commit 2ebec0395c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 0 deletions

View file

@ -1508,6 +1508,8 @@ def _optimize_post(model, lightweight_bmm=False):
# for phi-3
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward
convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward)
from ipex_llm.transformers.models.phi3 import attention_forward
convert_forward(model, module.Phi3Attention, attention_forward)
from ipex_llm.transformers.models.phi3 import mlp_forward

View file

@ -57,6 +57,47 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
if self.inv_freq is None:
short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
inv_freq_shape = torch.arange(0, self.dim, 2,
dtype=torch.int64, device=x.device).float() / self.dim
self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape)
long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
self.register_buffer("long_inv_freq", None, persistent=False)
self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape)
seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
inv_freq = self.long_inv_freq
else:
inv_freq = self.inv_freq
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def attention_forward(
self,
hidden_states: torch.Tensor,