Add moonlight GPU example (#12929)

* Add moonlight GPU example and update table
* Small fix
* Fix based on comments
* Small fix

This commit is contained in:
parent 33da3a3cb7
commit 68a770745b

6 changed files with 263 additions and 0 deletions

@@ -337,6 +337,7 @@ Over 70 models have been optimized/verified on `ipex-llm`, including *LLaMA/LLaM
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
+| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |

@@ -337,6 +337,7 @@ See the demo of running [*Text-Generation-WebUI*](https://ipex-llm.readthedocs.i
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
+| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |

99 python/llm/example/GPU/HuggingFace/LLM/moonlight/README.md Normal file

@@ -0,0 +1,99 @@
# Moonlight

In this directory, you will find examples of how you could apply IPEX-LLM INT4 optimizations on the Moonlight model on [Intel GPUs](../../../README.md). For illustration purposes, we utilize [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) as a reference Moonlight model.

## 0. Requirements & Installation

To run these examples with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.

### 0.1 Installation
```bash
conda create -n llm python=3.11
conda activate llm

# install IPEX-LLM with PyTorch 2.6 support
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu

pip install transformers==4.45.0
pip install accelerate==0.33.0
pip install "trl<0.12.0"

pip install tiktoken blobfile
```
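After installation, a quick sanity check (our suggestion, not part of the original example) is to confirm that the XPU build of PyTorch can see your Intel GPU:

```bash
# should print True on a correctly configured machine
python -c "import torch; print(torch.xpu.is_available())"
```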

### 0.2 Runtime Configuration

- For Windows users:

  ```cmd
  set SYCL_CACHE_PERSISTENT=1
  :: optional
  set SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
  ```

- For Linux users:

  ```bash
  unset OCL_ICD_VENDORS
  export SYCL_CACHE_PERSISTENT=1
  # optional
  export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
  ```

> [!NOTE]
> The environment variable `SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS` determines the usage of immediate command lists for task submission to the GPU. Enabling this mode may improve performance, but in some cases it may also cause performance degradation. Please consider experimenting with and without this environment variable for best performance. For more details, you can refer to [this article](https://www.intel.com/content/www/us/en/developer/articles/guide/level-zero-immediate-command-lists.html).

## 1. Download & Convert Model

To run the Moonlight model with IPEX-LLM optimizations, we need to first download and convert it, so that it can be successfully loaded by `transformers`.

### 1.1 Download Model

To download [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) from Hugging Face, you could use [download.py](./download.py):

```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct --commit-id 95583251e616c46a80715897a705cd38659afc27
```

By default, Moonlight-16B-A3B-Instruct will be downloaded to the current folder. You could also define the download folder path with `--download-dir-path DOWNLOAD_DIR_PATH`.
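For instance, to download into a specific folder (the target path below is only an illustration):

```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct --commit-id 95583251e616c46a80715897a705cd38659afc27 --download-dir-path ./Moonlight-16B-A3B-Instruct
```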

> [!TIP]
> Refer to [here](https://huggingface.co/docs/hub/en/models-downloading) for alternative methods to download models from Hugging Face.
>
> For [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct), please make sure to use its revision/commit id `95583251e616c46a80715897a705cd38659afc27`.
### 1.2 Convert Model

Next, convert the downloaded model with [convert.py](./convert.py):

```bash
python convert.py --model-path DOWNLOAD_DIR_PATH
```

The converted model will be saved at `<DOWNLOAD_DIR_PATH>-converted`.
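For example, assuming the model was downloaded to `./Moonlight-16B-A3B-Instruct` as above (path for illustration only):

```bash
python convert.py --model-path ./Moonlight-16B-A3B-Instruct
# the converted model is then saved at ./Moonlight-16B-A3B-Instruct-converted
```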

## 2. Example: Predict Tokens using `generate()` API

In the example [generate.py](./generate.py), we show a basic use case of using the `generate()` API to predict the next N tokens with a Moonlight model, with IPEX-LLM INT4 optimizations on Intel GPUs.

### 2.1 Running the Example

```bash
python generate.py --converted-model-path <DOWNLOAD_DIR_PATH>-converted --prompt PROMPT --n-predict N_PREDICT
```

Arguments info:

- `--converted-model-path CONVERTED_MODEL_PATH`: path to the model converted by [`convert.py`](./convert.py)
- `--prompt PROMPT`: the prompt on which to run inference (the chat prompt format is applied automatically). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: the maximum number of tokens to predict. It defaults to `32`.
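A concrete invocation, assuming the converted model path from the steps above, might look like:

```bash
python generate.py --converted-model-path ./Moonlight-16B-A3B-Instruct-converted --prompt "Is 123 a prime?" --n-predict 32
```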

### 2.2 Sample Outputs

#### [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct)

```log
Inference time: xxxx s
-------------------- Prompt --------------------
Is 123 a prime?
-------------------- Output --------------------
<|im_system|>system<|im_middle|>You are a helpful assistant provided by Moonshot-AI.<|im_end|><|im_user|>user<|im_middle|>Is 123 a prime?<|im_end|><|im_assistant|>assistant<|im_middle|>No, 123 is not a prime number. A prime number is a number greater than 1 that has no positive divisors other than 1 and itself
```

45 python/llm/example/GPU/HuggingFace/LLM/moonlight/convert.py Normal file

@@ -0,0 +1,45 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import argparse

from safetensors.torch import load_file, save_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert Moonlight model to be successfully loaded by transformers')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to the downloaded Moonlight model')

    args = parser.parse_args()
    model_path = args.model_path
    converted_model_path = model_path + '-converted'

    if os.path.exists(converted_model_path):
        shutil.rmtree(converted_model_path)

    os.makedirs(converted_model_path)

    for f in os.listdir(model_path):
        f_path = os.path.join(model_path, f)
        f_dst_path = os.path.join(converted_model_path, f)

        if f.endswith(".safetensors"):
            # re-save the weights with `format: pt` metadata so that
            # the checkpoint can be loaded by `transformers`
            save_file(load_file(f_path), f_dst_path, metadata={"format": "pt"})
        elif not f.startswith(".") and os.path.isfile(f_path):  # skip dirs and files whose names start with .
            shutil.copyfile(f_path, f_dst_path)

    print(f"Converted model successfully saved to {converted_model_path}")

40 python/llm/example/GPU/HuggingFace/LLM/moonlight/download.py Normal file

@@ -0,0 +1,40 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

from huggingface_hub import snapshot_download

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Download Moonlight model')
    parser.add_argument('--repo-id', type=str, default='moonshotai/Moonlight-16B-A3B-Instruct',
                        help='Hugging Face repo id of the model to be downloaded')
    parser.add_argument('--commit-id', type=str, required=True,
                        help='Revision of the model to be downloaded')
    parser.add_argument('--download-dir-path', type=str,
                        help='Folder path where the model will be downloaded')

    args = parser.parse_args()

    repo_id = args.repo_id
    download_dir_path = args.download_dir_path
    if download_dir_path is None:
        # default to a folder named after the repo, in the current directory
        download_dir_path = repo_id.rsplit("/", 1)[-1]

    snapshot_download(repo_id=repo_id,
                      revision=args.commit_id,
                      local_dir=download_dir_path)

    print(f'{repo_id} has been downloaded to {download_dir_path}')

77 python/llm/example/GPU/HuggingFace/LLM/moonlight/generate.py Normal file

@@ -0,0 +1,77 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import argparse

import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Moonlight model')
    parser.add_argument('--converted-model-path', type=str, required=True,
                        help='Model path to the Moonlight model converted by convert.py')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    converted_model_path = args.converted_model_path

    # Load the model in 4 bit, which converts the relevant layers
    # in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(converted_model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(converted_model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        # the prompt format here refers to
        # https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct#inference-with-hugging-face-transformers
        messages = [
            {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
            {"role": "user", "content": args.prompt}
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to('xpu')

        # an ipex_llm model needs a warmup run; only after it can inference time be measured accurately
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()

        output_str = tokenizer.decode(output[0], skip_special_tokens=False)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(args.prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)