From 8cae8976431563144c840cf9ad1e19fdc8b0b02a Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 16 May 2024 15:12:35 +0800 Subject: [PATCH] use new rope in phi3 (#11047) --- .../llm/src/ipex_llm/transformers/convert.py | 11 +- .../src/ipex_llm/transformers/models/phi3.py | 116 +++++------------- 2 files changed, 36 insertions(+), 91 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index e58da1a3..3baac65b 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -706,6 +706,8 @@ def _optimize_pre(model): from ipex_llm.transformers.models.phi import merge_qkv model.apply(merge_qkv) if model.config.model_type == "phi3": + from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq + model.apply(pre_compute_inv_freq) from ipex_llm.transformers.models.phi3 import split_mlp model.apply(split_mlp) if model.config.model_type == "qwen": @@ -1525,8 +1527,6 @@ def _optimize_post(model, lightweight_bmm=False): # for phi-3 modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward - convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward) from ipex_llm.transformers.models.phi3 import attention_forward convert_forward(model, module.Phi3Attention, attention_forward) from ipex_llm.transformers.models.phi3 import mlp_forward @@ -1534,13 +1534,8 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.phi3 import model_forward_wrapper model_forward = model_forward_wrapper(module.Phi3Model.forward) convert_forward(model, module.Phi3Model, model_forward) - from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached - replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached) from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward - convert_forward( - model, - module.Phi3RMSNorm, - phi3_rms_norm_forward) + convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward) elif model.config.model_type == 'yuan': modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 6593e81b..adc61373 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -38,7 +38,6 @@ from torch import nn from ipex_llm.transformers.models.utils import ( rotate_half, should_use_fuse_rope, - apply_rotary_pos_emb_cache_freq_xpu ) from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal @@ -58,46 +57,29 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None): - if self.inv_freq is None: - short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - inv_freq_shape = torch.arange(0, self.dim, 2, - dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape) +def pre_compute_inv_freq(module: torch.nn.Module): + if module.__class__.__name__ == "Phi3RotaryEmbedding": + module.inv_freq = 1.0 / ( + module.base ** + (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim) + ) + elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding": + inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim + short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32) + module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape) - long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - self.register_buffer("long_inv_freq", None, persistent=False) - self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape) + long_ext_factors = torch.tensor(module.long_factor, dtype=torch.float32) + module.register_buffer("long_inv_freq", None, persistent=False) + module.long_inv_freq = 1.0 / (long_ext_factors * module.base ** inv_freq_shape) - seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.inv_freq - - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 + if module.max_position_embeddings <= module.original_max_position_embeddings: + module.scaling_factor = 1.0 else: - scaling_factor = math.sqrt( - 1 + math.log(scale) / math.log(self.original_max_position_embeddings) + scale = module.max_position_embeddings / module.original_max_position_embeddings + module.scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(module.original_max_position_embeddings) ) - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - def attention_forward( self, @@ -124,12 +106,24 @@ def attention_forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) # IPEX-LLM OPT: fuse rope if should_use_fuse_rope(hidden_states, position_ids, self.training): - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "phi3") + import linear_q4_0 + if self.rotary_emb.__class__.__name__ == "Phi3RotaryEmbedding": # 4k + linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) + else: # 128k + if kv_seq_len > self.rotary_emb.original_max_position_embeddings: + linear_q4_0.rotary_half_inplaced(self.rotary_emb.long_inv_freq, position_ids, + query_states, key_states) + else: + linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) + # todo: fuse scaling_factor + query_states *= self.rotary_emb.scaling_factor + key_states *= self.rotary_emb.scaling_factor else: + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -257,50 +251,6 @@ def model_forward_wrapper(origin_model_forward): return model_forward -class Phi3RotaryEmbeddingCached(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, - dtype=torch.int64, - device=device).float() / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # self.gen_seq_len = None - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - position_ids_expanded = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype).reshape(1, 1, -1) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1) - with torch.autocast(device_type=device.type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - # Different from paper, - # but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype), - self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype), - ) - - def phi3_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0