diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md index a3150d0d..1033c9cf 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md @@ -13,6 +13,7 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o | Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | | Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) | | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) | +| Deepseek | [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) | ## 0. Requirements To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU. diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py index 091336c0..f76c6798 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py @@ -33,6 +33,9 @@ import torch +from torch.nn import functional as F +import importlib +from typing import Optional, Tuple from ipex_llm.transformers.npu_models.common import merge_linear @@ -51,3 +54,64 @@ def baichuan_mlp_forward(self, x): gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1) down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj) return down_proj + + +def baichuan_attention_fwd( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + modeling_module_name = self.__class__.__module__ + module = importlib.import_module(modeling_module_name) + apply_rotary_pos_emb = module.apply_rotary_pos_emb + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if query_states.size(2) == key_states.size(2): + # first token + from intel_npu_acceleration_library.functional import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask + ) + attn_weights = None + else: + with torch.backends.cuda.sdp_kernel(enable_flash=True, + 
enable_math=True, enable_mem_efficient=True): + attn_output = F.scaled_dot_product_attention(query_states, key_states, + value_states, attn_mask=attention_mask) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 6d3c95ee..7129cc54 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -174,9 +174,11 @@ def optimize_llm(model: torch.nn.Module): modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) from ipex_llm.transformers.npu_models.baichuan import baichuan_mlp_forward, merge_mlp + from ipex_llm.transformers.npu_models.baichuan import baichuan_attention_fwd model.apply(merge_mlp) convert_forward(model, module.MLP, baichuan_mlp_forward) + convert_forward(model, module.Attention, baichuan_attention_fwd) elif model.config.model_type == "phi3_v": modeling_module_name = model.__class__.__module__ @@ -189,3 +191,13 @@ def optimize_llm(model: torch.nn.Module): from transformers.models.clip.modeling_clip import CLIPAttention convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward) convert_forward(model, module.Phi3VModel, phi3v_model_forward) + + from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward + convert_forward(model, module.Phi3Attention, phi3_attention_forward) + + elif model.config.model_type == "phi3": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward + + convert_forward(model, module.Phi3Attention, phi3_attention_forward) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/phi3.py b/python/llm/src/ipex_llm/transformers/npu_models/phi3.py new file mode 100644 index 00000000..6889c9ee --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/phi3.py @@ -0,0 +1,157 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# which is licensed under Apache License 2.0: +# +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple, List
+import torch
+from torch import nn
+import math
+import importlib
+from transformers.cache_utils import Cache
+from ipex_llm.utils.common.log4Error import invalidInputError
+
+
+def phi3_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    modeling_module_name = self.__class__.__module__
+    module = importlib.import_module(modeling_module_name)
+    apply_rotary_pos_emb, repeat_kv = module.apply_rotary_pos_emb, module.repeat_kv
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos:query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim:]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+                                 self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+                                     self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            invalidInputError(
+                False,
+                f"The cache structure has changed since version v4.36. "
+                f"If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, "
+                "please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, cache_kwargs)
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+    else:
+        causal_mask = None
+
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+        attn_weights = None
+    else:
+
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of "
+                f"size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights,
+                                             p=self.attention_dropout, training=self.training)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(
+            False,
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
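Note on how the new forwards above get wired in: `convert_forward` is called in convert.py but its definition is not part of this diff. Below is a hypothetical, minimal sketch of that pattern (recursively rebinding `forward` on every instance of a target module class); the names `convert_forward_sketch`, `ToyAttention`, and `patched_forward` are illustrative stand-ins, not ipex-llm APIs.

# Hypothetical sketch -- not the actual ipex-llm helper.
import torch
from torch import nn


def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Walk the module tree and rebind `forward` on every instance of target_cls,
    # mirroring how convert_forward(model, module.Attention, baichuan_attention_fwd)
    # is used in convert.py above.
    for _, submodule in model.named_children():
        if isinstance(submodule, target_cls):
            submodule.forward = new_forward.__get__(submodule, submodule.__class__)
        convert_forward_sketch(submodule, target_cls, new_forward)


class ToyAttention(nn.Module):
    # Stand-in for a real attention block (illustrative only).
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


def patched_forward(self, x):
    # Replacement forward, analogous to baichuan_attention_fwd / phi3_attention_forward.
    return torch.relu(self.proj(x))


model = nn.Sequential(ToyAttention(), ToyAttention())
convert_forward_sketch(model, ToyAttention, patched_forward)
output = model(torch.randn(2, 8))  # both blocks now execute patched_forward

Rebinding `forward` per instance, rather than editing the transformers modeling files, is what lets the stock Baichuan2/Phi-3 modules keep their original behavior elsewhere while running the NPU-friendly attention paths once `optimize_llm` has been applied.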