From 0a732bebe72906c4a7dc85a4b65b363e09b8eaa0 Mon Sep 17 00:00:00 2001 From: Zhao Changmin Date: Wed, 15 May 2024 08:16:43 +0800 Subject: [PATCH] Add phi3 cached RotaryEmbedding (#11013) * phi3cachedrotaryembed * pep8 --- .../llm/src/ipex_llm/transformers/convert.py | 11 +++++ .../src/ipex_llm/transformers/models/phi3.py | 45 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 37d1fba1..e60d5e1d 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -848,6 +848,15 @@ def convert_forward(m, target_m, new_forward): convert_forward(sub_m, target_m, new_forward) +def replace_RotaryEmbed(m, target_m, replace_embed): + for attr_name, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + setattr(m, attr_name, replace_embed(sub_m.dim, + sub_m.max_position_embeddings, + sub_m.base)) + replace_RotaryEmbed(sub_m, target_m, replace_embed) + + def replace_func(m, target_m, func_name, new_func): for _, sub_m in m.named_children(): if isinstance(sub_m, target_m): @@ -1517,6 +1526,8 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.phi3 import model_forward_wrapper model_forward = model_forward_wrapper(module.Phi3Model.forward) convert_forward(model, module.Phi3Model, model_forward) + from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached + replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached) from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward convert_forward( model, diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index d8715571..98fba35e 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -34,6 +34,7 @@ import math import torch import warnings +from torch import nn from ipex_llm.transformers.models.utils import ( rotate_half, should_use_fuse_rope, @@ -255,6 +256,50 @@ def model_forward_wrapper(origin_model_forward): return model_forward +class Phi3RotaryEmbeddingCached(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, + dtype=torch.int64, + device=device).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # self.gen_seq_len = None + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + position_ids_expanded = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype).reshape(1, 1, -1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1) + with torch.autocast(device_type=device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + # Different from paper, + # but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype), + self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype), + ) + + def phi3_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0