Add NPU HF example (#11358)
This commit is contained in:
parent
1eb884a249
commit
ae452688c2
2 changed files with 124 additions and 0 deletions
|
|
@ -0,0 +1,63 @@
|
|||
# Run LLama2 on Intel NPU
|
||||
In this directory, you will find examples on how you could apply IPEX-LLM INT4 optimizations on Llama2 models on [Intel NPUs](../../../README.md). For illustration purposes, we utilize the [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) as reference Llama2 models.
|
||||
|
||||
## 0. Requirements
|
||||
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
|
||||
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
|
||||
Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
|
||||
Right click and select **Update Driver**. And then manually select the folder unzipped from the driver.
|
||||
|
||||
## Example: Predict Tokens using `generate()` API
|
||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
|
||||
### 1. Install
|
||||
#### 1.1 Installation on Windows
|
||||
We suggest using conda to manage environment:
|
||||
```bash
|
||||
conda create -n llm python=3.10 libuv
|
||||
conda activate llm
|
||||
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
|
||||
# below command will install intel_npu_acceleration_library
|
||||
conda install cmake
|
||||
git clone https://github.com/intel/intel-npu-acceleration-library npu-library
|
||||
cd npu-library
|
||||
git checkout bcb1315
|
||||
python setup.py bdist_wheel
|
||||
pip install dist\intel_npu_acceleration_library-1.2.0-cp310-cp310-win_amd64.whl
|
||||
```
|
||||
|
||||
### 2. Runtime Configurations
|
||||
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
|
||||
#### 2.1 Configurations for Windows
|
||||
<details>
|
||||
|
||||
```cmd
|
||||
set BIGDL_USE_NPU=1
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 3. Running examples
|
||||
|
||||
```
|
||||
python ./generate.py
|
||||
```
|
||||
|
||||
Arguments info:
|
||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
|
||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun'`.
|
||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
|
||||
- `--load_in_low_bit`: argument defining the load_in_low_bit format used. It is default to be `sym_int8`, `sym_int4` can also be used.
|
||||
|
||||
#### Sample Output
|
||||
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
|
||||
```log
|
||||
Inference time: xxxx s
|
||||
-------------------- Output --------------------
|
||||
<s> Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. But her parents were always telling her to stay at home and be careful. They were worried about her safety, and they didn't want her to
|
||||
--------------------------------------------------------------------------------
|
||||
done
|
||||
```
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model')
|
||||
parser.add_argument('--repo-id-or-model-path', type=str, default="D:\llm-models\Llama-2-7b-chat-hf",
|
||||
help='The huggingface repo id for the Llama2 model to be downloaded'
|
||||
', or the path to the huggingface checkpoint folder')
|
||||
parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
|
||||
help='Prompt to infer')
|
||||
parser.add_argument('--n-predict', type=int, default=32,
|
||||
help='Max tokens to predict')
|
||||
parser.add_argument('--load_in_low_bit', type=str, default="sym_int8",
|
||||
help='Load in low bit to use')
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.repo_id_or_model_path
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
|
||||
load_in_low_bit=args.load_in_low_bit)
|
||||
|
||||
print(model)
|
||||
|
||||
with torch.inference_mode():
|
||||
prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
print("finish to load")
|
||||
print('input length:', len(input_ids[0]))
|
||||
st = time.time()
|
||||
output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict)
|
||||
end = time.time()
|
||||
print(f'Inference time: {end-st} s')
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
|
||||
print('-'*20, 'Output', '-'*20)
|
||||
print(output_str)
|
||||
|
||||
print('-'*80)
|
||||
print('done')
|
||||
Loading…
Reference in a new issue