parent
8982ab73d5
commit
3c16c9f725
3 changed files with 65 additions and 3 deletions
@@ -1,7 +1,7 @@
# Run Large Language Model on Intel NPU
In this directory, you will find examples of how to apply IPEX-LLM INT4 or INT8 optimizations to LLM models on [Intel NPUs](../../../README.md). See the table below for verified models.

## Verified Models

| Model | Model Link |
|------------|----------------------------------------------------------------|
@@ -12,6 +12,7 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o
| MiniCPM | [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
| Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |

## 0. Requirements

To run these examples with IPEX-LLM on Intel NPUs, make sure you have installed the latest Intel NPU driver.
@@ -54,7 +55,7 @@ python ./generate.py
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. The default is `'meta-llama/Llama-2-7b-chat-hf'`; for more verified models, see the list in [Verified Models](#verified-models).
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default is `'Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default is `32`.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. The default is `sym_int8`; `sym_int4` can also be used. See the sketch below for how these options map onto the loading API.
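
For reference, here is a minimal sketch of how a script like `generate.py` typically wires these arguments up. It assumes the `ipex_llm.transformers.npu_model` loading API used by these NPU examples; the exact argument handling in the shipped script may differ.

```python
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Sketch only: mirrors the CLI arguments documented above.
model_path = "meta-llama/Llama-2-7b-chat-hf"          # --repo-id-or-model-path

# Load the model with IPEX-LLM low-bit optimizations targeting the NPU.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int8",                        # --load_in_low_bit
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "Once upon a time, there existed a little girl"  # --prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=32)     # --n-predict
print(tokenizer.decode(output[0], skip_special_tokens=True))
```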
53 python/llm/src/ipex_llm/transformers/npu_models/baichuan.py Normal file

@@ -0,0 +1,53 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ipex_llm.transformers.npu_models.common import merge_linear


def merge_mlp(module: torch.nn.Module):
    # Fuse the gate and up projections of a Baichuan MLP block into a single
    # linear layer, so one matmul replaces two at inference time.
    if type(module).__name__ == "MLP":
        gate_up_proj = merge_linear([
            module.gate_proj,
            module.up_proj,
        ])
        module.gate_up_proj = gate_up_proj
        del module.gate_proj, module.up_proj


def baichuan_mlp_forward(self, x):
    # Run the fused projection once, then split the result back into its
    # gate and up halves before applying the activation.
    gate_up_proj = self.gate_up_proj(x)
    gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1)
    down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj)
    return down_proj
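
The fusion above is behavior-preserving. As a sanity check, here is a minimal, self-contained sketch (toy sizes, plain `torch.nn.Linear` layers; not part of the commit) showing that concatenating the two weight matrices and chunking the fused output reproduces the separate gate and up projections:

```python
import torch

hidden, intermediate = 8, 16
gate_proj = torch.nn.Linear(hidden, intermediate, bias=False)
up_proj = torch.nn.Linear(hidden, intermediate, bias=False)

# Fused layer: stack the two weight matrices along the output dimension,
# mirroring what merge_linear is expected to do.
fused = torch.nn.Linear(hidden, 2 * intermediate, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))

x = torch.randn(2, hidden)
gate, up = fused(x).chunk(2, dim=-1)
assert torch.allclose(gate, gate_proj(x), atol=1e-6)
assert torch.allclose(up, up_proj(x), atol=1e-6)
```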
@@ -169,3 +169,11 @@ def optimize_llm(model: torch.nn.Module):
        convert_forward(model, StableLmModel, stablelm_model_forward)
        convert_forward(model, StableLmAttention, stablelm_attention_forward)
        convert_forward(model, StableLmMLP, stablelm_mlp_forward)

    elif model.config.model_type == "baichuan":
        # Locate the model's own modeling module so its MLP class can be patched.
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.npu_models.baichuan import baichuan_mlp_forward, merge_mlp
        # Fuse gate/up projections in every MLP block, then swap in the
        # forward that expects the fused layer.
        model.apply(merge_mlp)

        convert_forward(model, module.MLP, baichuan_mlp_forward)
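
For readers unfamiliar with `convert_forward`, here is a hedged sketch of the general pattern such a helper follows; the actual helper lives elsewhere in the repo (in `ipex_llm.transformers.npu_models.convert`), and names here are illustrative:

```python
import types
import torch

def convert_forward(model: torch.nn.Module, target_cls, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_cls,
    # so the optimized implementation is what runs at inference time.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```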