Add phi3 cached RotaryEmbedding (#11013)

* phi3cachedrotaryembed

* pep8
Zhao Changmin 2024-05-15 08:16:43 +08:00 committed by GitHub
parent 0b7e78b592
commit 0a732bebe7
2 changed files with 56 additions and 0 deletions


@@ -848,6 +848,15 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


def replace_RotaryEmbed(m, target_m, replace_embed):
    for attr_name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, attr_name, replace_embed(sub_m.dim,
                                                sub_m.max_position_embeddings,
                                                sub_m.base))
        replace_RotaryEmbed(sub_m, target_m, replace_embed)


def replace_func(m, target_m, func_name, new_func):
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
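A minimal, self-contained sketch of what the new replace_RotaryEmbed helper does: it walks the module tree recursively and swaps every instance of the target rotary-embedding class for the replacement class, reusing the original module's dim, max_position_embeddings, and base. The ToyRotary, CachedToyRotary, and ToyBlock classes below are hypothetical stand-ins, not part of the patch.

from torch import nn


class ToyRotary(nn.Module):            # hypothetical stand-in for Phi3RotaryEmbedding
    def __init__(self, dim=32, max_position_embeddings=4096, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base


class CachedToyRotary(nn.Module):      # hypothetical stand-in for Phi3RotaryEmbeddingCached
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base


class ToyBlock(nn.Module):             # hypothetical stand-in for a decoder layer's attention
    def __init__(self):
        super().__init__()
        self.rotary_emb = ToyRotary()


def replace_RotaryEmbed(m, target_m, replace_embed):
    # same recursion as the patch: replace matching children in place, then descend
    for attr_name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, attr_name, replace_embed(sub_m.dim,
                                                sub_m.max_position_embeddings,
                                                sub_m.base))
        replace_RotaryEmbed(sub_m, target_m, replace_embed)


model = nn.Sequential(ToyBlock(), ToyBlock())
replace_RotaryEmbed(model, ToyRotary, CachedToyRotary)
assert all(isinstance(block.rotary_emb, CachedToyRotary) for block in model)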
@@ -1517,6 +1526,8 @@ def _optimize_post(model, lightweight_bmm=False):
        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
        model_forward = model_forward_wrapper(module.Phi3Model.forward)
        convert_forward(model, module.Phi3Model, model_forward)
        from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
        replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
        convert_forward(
            model,
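For context, how this code path is typically reached (a sketch assuming the standard ipex-llm loading flow; the checkpoint name and attribute path are illustrative): _optimize_post runs while the model is loaded, and with this patch every Phi3RotaryEmbedding inside the attention layers is replaced by Phi3RotaryEmbeddingCached.

# Sketch, assuming the usual ipex-llm loading flow; checkpoint name is illustrative.
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    load_in_4bit=True,
    trust_remote_code=True,
)

# After optimization, the rotary modules should be the cached variant.
print(type(model.model.layers[0].self_attn.rotary_emb).__name__)
# expected: Phi3RotaryEmbeddingCached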


@@ -34,6 +34,7 @@
import math
import torch
import warnings
from torch import nn

from ipex_llm.transformers.models.utils import (
    rotate_half, should_use_fuse_rope,
@@ -255,6 +256,50 @@ def model_forward_wrapper(origin_model_forward):
    return model_forward


class Phi3RotaryEmbeddingCached(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
                                                     dtype=torch.int64,
                                                     device=device).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # self.gen_seq_len = None
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        position_ids_expanded = torch.arange(self.max_seq_len_cached,
                                             device=device,
                                             dtype=self.inv_freq.dtype).reshape(1, 1, -1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
        with torch.autocast(device_type=device.type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        # Different from the paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, seq_len - position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
            self.sin_cached[:, seq_len - position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
        )


def phi3_rms_norm_forward(self, hidden_states):
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
        import linear_q4_0
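For context, a sketch of how the cached cos/sin tensors returned by Phi3RotaryEmbeddingCached.forward are typically applied to query/key states. The apply_rotary_pos_emb_cached helper below is illustrative, not part of the patch (the actual phi3 attention path in ipex-llm may instead use fused rope kernels via should_use_fuse_rope); it uses the standard HF-style rotate_half formulation, which is imported at the top of this file.

import torch

from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
from ipex_llm.transformers.models.utils import rotate_half


def apply_rotary_pos_emb_cached(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: [bs, num_heads, q_len, head_dim]; cos, sin: [1, q_len, head_dim]
    cos = cos.unsqueeze(unsqueeze_dim)   # -> [1, 1, q_len, head_dim]
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Usage sketch with illustrative shapes (e.g. Phi-3-mini: 32 heads, head_dim 96).
rot = Phi3RotaryEmbeddingCached(dim=96, max_position_embeddings=4096)
q = torch.randn(1, 32, 16, 96)
k = torch.randn(1, 32, 16, 96)
position_ids = torch.arange(16).unsqueeze(0)
cos, sin = rot(q, position_ids, seq_len=16)   # slices the precomputed cos/sin cache
q, k = apply_rotary_pos_emb_cached(q, k, cos, sin)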