diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 9d7f803d..9ddf7bae 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -653,6 +653,9 @@ def _optimize_pre(model): if model.config.model_type == "phi": from ipex_llm.transformers.models.phi import merge_qkv model.apply(merge_qkv) + if model.config.model_type == "phi3": + from ipex_llm.transformers.models.phi3 import split_mlp + model.apply(split_mlp) if model.config.model_type == "qwen": rope_base = model.config.rotary_emb_base from accelerate.big_modeling import init_empty_weights @@ -1426,6 +1429,17 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.phi import model_forward convert_forward(model, module.PhiAttention, attention_forward) convert_forward(model, module.PhiModel, model_forward) + elif model.config.model_type == "phi3": + # for phi-3 + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.phi3 import attention_forward + convert_forward(model, module.Phi3Attention, attention_forward) + from ipex_llm.transformers.models.phi3 import mlp_forward + convert_forward(model, module.Phi3MLP, mlp_forward) + from ipex_llm.transformers.models.phi3 import model_forward_wrapper + model_forward = model_forward_wrapper(module.Phi3Model.forward) + convert_forward(model, module.Phi3Model, model_forward) elif model.config.model_type == 'yuan': modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py new file mode 100644 index 00000000..01c0c34c --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -0,0 +1,193 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py +# which is licensed under Apache License 2.0: +# +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import warnings + +from ipex_llm.transformers.models.utils import ( + rotate_half, should_use_fuse_rope, + apply_rotary_pos_emb_cache_freq_xpu +) +from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU +from ipex_llm.transformers.kv import DynamicNormalCache + +from typing import Optional, Tuple, List +from transformers.models.phi.modeling_phi import repeat_kv +from transformers.cache_utils import Cache + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + warnings.warn("You are not running the flash-attention implementation, " + "expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads], dim=1) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + # IPEX-LLM OPT: fuse rope + if should_use_fuse_rope(hidden_states, position_ids, self.training): + query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, + sin, cos, "phi3") + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, None) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(value_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def split_mlp(module: torch.nn.Module): + if module.__class__.__name__ == "Phi3MLP": + gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0) + + gate_proj = torch.nn.Linear(0, 0, bias=False) + gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False) + gate_proj.in_features = gate_weight.size(1) + gate_proj.out_features = gate_weight.size(0) + + up_proj = torch.nn.Linear(0, 0, bias=False) + up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False) + up_proj.in_features = up_weight.size(1) + up_proj.out_features = up_weight.size(0) + + module.gate_proj = gate_proj + module.up_proj = up_proj + + del module.gate_up_proj + + +def mlp_forward( + self, + hidden_states: torch.FloatTensor +) -> torch.FloatTensor: + x_2d = hidden_states.view(-1, hidden_states.shape[-1]) + qtype = getattr(self.gate_proj, "qtype", None) + if mlp_fusion_check(x_2d, qtype, self.training): + x_2d = x_2d.contiguous() + import linear_q4_0 + return self.down_proj(linear_q4_0.mlp_forward_xpu( + x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, + x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features, + SILU, qtype + )) + return self.down_proj( + self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + + +def model_forward_wrapper(origin_model_forward): + def model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + # IPEX-LLM OPT: kv cache but no sdp (its head_dim 96, cannot use sdp) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + if not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + return origin_model_forward( + self=self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return model_forward diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index e8bd3466..2be095a0 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -167,6 +167,14 @@ def rotate_every_two(x): return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') +def should_use_fuse_rope(hidden_states, position_ids, training): + return ( + hidden_states.device.type == "xpu" + and not training and not hidden_states.requires_grad + and position_ids is not None + ) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral", "mixtral", "qwen2", "yuan", "stablelm", "qwen2_moe"]: @@ -234,7 +242,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed) - elif model_family in ["gemma"]: + elif model_family in ["gemma", "phi3"]: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)