From 2ebec0395cc2d9fedd8b44bf5a839d0f77fe2246 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 8 May 2024 16:33:17 +0800
Subject: [PATCH] optimize phi-3-mini-128 (#10959)

---
 .../llm/src/ipex_llm/transformers/convert.py |  2 +
 .../src/ipex_llm/transformers/models/phi3.py | 41 +++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index dcd4597a..fc0766ae 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1508,6 +1508,8 @@ def _optimize_post(model, lightweight_bmm=False):
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward
+        convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward)
         from ipex_llm.transformers.models.phi3 import attention_forward
         convert_forward(model, module.Phi3Attention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 6092ceab..ac3b65c2 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -57,6 +57,47 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
+    if self.inv_freq is None:
+        short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+        inv_freq_shape = torch.arange(0, self.dim, 2,
+                                      dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape)
+
+        long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        self.register_buffer("long_inv_freq", None, persistent=False)
+        self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape)
+
+    seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
+    if seq_len > self.original_max_position_embeddings:
+        inv_freq = self.long_inv_freq
+    else:
+        inv_freq = self.inv_freq
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    with torch.autocast(device_type=device_type, enabled=False):
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = math.sqrt(
+                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
+            )
+
+        cos = emb.cos() * scaling_factor
+        sin = emb.sin() * scaling_factor
+    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 def attention_forward(
     self,
     hidden_states: torch.Tensor,
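
Note (not part of the patch): the sketch below illustrates the frequency-table selection and cos/sin amplitude scaling that `su_scaled_rope_forward` performs, as a standalone Python snippet. The head dimension of 96 and the 4096 to 131072 context lengths are assumptions about a phi-3-mini-128k-style configuration, not values read from this diff, and the real `short_factor`/`long_factor` rescaling lists come from the model config; plain ones are used here to keep the sketch self-contained.

import math

import torch

# Assumed phi-3-mini-128k-style configuration (hypothetical values for illustration).
dim = 96                                  # head_dim = hidden_size / num_heads (assumed)
base = 10000.0
original_max_position_embeddings = 4096   # assumed original training context
max_position_embeddings = 131072          # assumed extended context

# In the real model, short_factor / long_factor are per-frequency rescaling lists
# taken from the config; ones keep the example minimal.
short_factor = torch.ones(dim // 2)
long_factor = torch.ones(dim // 2)

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
short_inv_freq = 1.0 / (short_factor * base ** inv_freq_shape)
long_inv_freq = 1.0 / (long_factor * base ** inv_freq_shape)


def pick_inv_freq(seq_len: int) -> torch.Tensor:
    # Same branch as su_scaled_rope_forward: switch to the long-context
    # frequency table only once the sequence exceeds the original limit.
    if seq_len > original_max_position_embeddings:
        return long_inv_freq
    return short_inv_freq


# Amplitude scaling applied to cos/sin, mirroring the patched forward.
scale = max_position_embeddings / original_max_position_embeddings
if scale <= 1.0:
    scaling_factor = 1.0
else:
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings)
    )

print(pick_inv_freq(2048).shape)   # torch.Size([48]) -> short table for short prompts
print(pick_inv_freq(8192).shape)   # torch.Size([48]) -> long table past 4096 tokens
print(round(scaling_factor, 4))    # ~1.1902 for a 4096 -> 131072 extension

Design-wise, the patched forward computes both `inv_freq` and `long_inv_freq` on the first call and then only selects between the two cached tables, which appears to be the main saving over rebuilding the extension factors on every forward pass.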