Add phi3 cached RotaryEmbedding (#11013)
* phi3cachedrotaryembed * pep8
This commit is contained in:
parent
0b7e78b592
commit
0a732bebe7
2 changed files with 56 additions and 0 deletions
|
|
@ -848,6 +848,15 @@ def convert_forward(m, target_m, new_forward):
|
||||||
convert_forward(sub_m, target_m, new_forward)
|
convert_forward(sub_m, target_m, new_forward)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_RotaryEmbed(m, target_m, replace_embed):
|
||||||
|
for attr_name, sub_m in m.named_children():
|
||||||
|
if isinstance(sub_m, target_m):
|
||||||
|
setattr(m, attr_name, replace_embed(sub_m.dim,
|
||||||
|
sub_m.max_position_embeddings,
|
||||||
|
sub_m.base))
|
||||||
|
replace_RotaryEmbed(sub_m, target_m, replace_embed)
|
||||||
|
|
||||||
|
|
||||||
def replace_func(m, target_m, func_name, new_func):
|
def replace_func(m, target_m, func_name, new_func):
|
||||||
for _, sub_m in m.named_children():
|
for _, sub_m in m.named_children():
|
||||||
if isinstance(sub_m, target_m):
|
if isinstance(sub_m, target_m):
|
||||||
|
|
@ -1517,6 +1526,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
||||||
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
||||||
convert_forward(model, module.Phi3Model, model_forward)
|
convert_forward(model, module.Phi3Model, model_forward)
|
||||||
|
from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
|
||||||
|
replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
|
||||||
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
||||||
convert_forward(
|
convert_forward(
|
||||||
model,
|
model,
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from ipex_llm.transformers.models.utils import (
|
from ipex_llm.transformers.models.utils import (
|
||||||
rotate_half, should_use_fuse_rope,
|
rotate_half, should_use_fuse_rope,
|
||||||
|
|
@ -255,6 +256,50 @@ def model_forward_wrapper(origin_model_forward):
|
||||||
return model_forward
|
return model_forward
|
||||||
|
|
||||||
|
|
||||||
|
class Phi3RotaryEmbeddingCached(nn.Module):
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device).float() / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
# self.gen_seq_len = None
|
||||||
|
|
||||||
|
# Build here to make `torch.jit.trace` work.
|
||||||
|
self._set_cos_sin_cache(
|
||||||
|
seq_len=max_position_embeddings,
|
||||||
|
device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
position_ids_expanded = torch.arange(self.max_seq_len_cached,
|
||||||
|
device=device,
|
||||||
|
dtype=self.inv_freq.dtype).reshape(1, 1, -1)
|
||||||
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
|
||||||
|
with torch.autocast(device_type=device.type, enabled=False):
|
||||||
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||||
|
# Different from paper,
|
||||||
|
# but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
||||||
|
self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def phi3_rms_norm_forward(self, hidden_states):
|
def phi3_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue