Add phi3 cached RotaryEmbedding (#11013)
* phi3cachedrotaryembed * pep8
This commit is contained in:
parent
0b7e78b592
commit
0a732bebe7
2 changed files with 56 additions and 0 deletions
|
|
@ -848,6 +848,15 @@ def convert_forward(m, target_m, new_forward):
|
|||
convert_forward(sub_m, target_m, new_forward)
|
||||
|
||||
|
||||
def replace_RotaryEmbed(m, target_m, replace_embed):
|
||||
for attr_name, sub_m in m.named_children():
|
||||
if isinstance(sub_m, target_m):
|
||||
setattr(m, attr_name, replace_embed(sub_m.dim,
|
||||
sub_m.max_position_embeddings,
|
||||
sub_m.base))
|
||||
replace_RotaryEmbed(sub_m, target_m, replace_embed)
|
||||
|
||||
|
||||
def replace_func(m, target_m, func_name, new_func):
|
||||
for _, sub_m in m.named_children():
|
||||
if isinstance(sub_m, target_m):
|
||||
|
|
@ -1517,6 +1526,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
||||
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
||||
convert_forward(model, module.Phi3Model, model_forward)
|
||||
from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
|
||||
replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
|
||||
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
||||
convert_forward(
|
||||
model,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@
|
|||
import math
|
||||
import torch
|
||||
import warnings
|
||||
from torch import nn
|
||||
|
||||
from ipex_llm.transformers.models.utils import (
|
||||
rotate_half, should_use_fuse_rope,
|
||||
|
|
@ -255,6 +256,50 @@ def model_forward_wrapper(origin_model_forward):
|
|||
return model_forward
|
||||
|
||||
|
||||
class Phi3RotaryEmbeddingCached(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
|
||||
dtype=torch.int64,
|
||||
device=device).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
# self.gen_seq_len = None
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings,
|
||||
device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
position_ids_expanded = torch.arange(self.max_seq_len_cached,
|
||||
device=device,
|
||||
dtype=self.inv_freq.dtype).reshape(1, 1, -1)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
|
||||
with torch.autocast(device_type=device.type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
# Different from paper,
|
||||
# but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
||||
self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
def phi3_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
import linear_q4_0
|
||||
|
|
|
|||
Loading…
Reference in a new issue