Add moonlight GPU example (#12929)

* Add moonlight GPU example and update table

* Small fix

* Fix based on comments

* Small fix
Yuwen Hu 2025-03-05 11:31:14 +08:00 committed by GitHub
parent 33da3a3cb7
commit 68a770745b
6 changed files with 263 additions and 0 deletions

View file

@ -337,6 +337,7 @@ Over 70 models have been optimized/verified on `ipex-llm`, including *LLaMA/LLaM
| MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
| Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
| StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
| Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
| Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |

View file

@ -337,6 +337,7 @@ See the demo of running [*Text-Generation-WebUI*](https://ipex-llm.readthedocs.i
| MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
| Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
| StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
| Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
| Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |

View file

@ -0,0 +1,99 @@
# Moonlight
In this directory, you will find examples of how to apply IPEX-LLM INT4 optimizations to the Moonlight model on [Intel GPUs](../../../README.md). For illustration purposes, we use [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) as the reference Moonlight model.
## 0. Requirements & Installation
To run these examples with IPEX-LLM on Intel GPUs, your machine should meet some recommended requirements; please refer to [here](../../../README.md#requirements) for more information.
### 0.1 Installation
```bash
conda create -n llm python=3.11
conda activate llm
# install IPEX-LLM with PyTorch 2.6 support
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu
pip install transformers==4.45.0
pip install accelerate==0.33.0
pip install "trl<0.12.0"
pip install tiktoken blobfile
```
### 0.2 Runtime Configuration
- For Windows users:
```cmd
set SYCL_CACHE_PERSISTENT=1
:: optional
set SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
- For Linux users:
```bash
unset OCL_ICD_VENDOR
export SYCL_CACHE_PERSISTENT=1
# optional
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
> [!NOTE]
> The environment variable `SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS` determines the usage of immediate command lists for task submission to the GPU. Enabling this mode may improve performance, but it may also cause performance degradation in some cases. Please consider experimenting with and without this environment variable for best performance. For more details, you can refer to [this article](https://www.intel.com/content/www/us/en/developer/articles/guide/level-zero-immediate-command-lists.html).
## 1. Download & Convert Model
To run the Moonlight model with IPEX-LLM optimizations, we need to first download and convert it to make sure it can be successfully loaded by `transformers`.
### 1.1 Download Model
To download [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) from Hugging Face, you could use [download.py](./download.py) through:
```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct --commit-id 95583251e616c46a80715897a705cd38659afc27
```
By default, Moonlight-16B-A3B-Instruct will be downloaded to the current folder. You can also specify the download folder path with `--download-dir-path DOWNLOAD_DIR_PATH`.
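For example, to download the model into a custom folder (the folder path below is illustrative), you could run:
```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct \
  --commit-id 95583251e616c46a80715897a705cd38659afc27 \
  --download-dir-path ./models/Moonlight-16B-A3B-Instruct
```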
> [!TIP]
> Refer to [here](https://huggingface.co/docs/hub/en/models-downloading) for alternative methods to download models from Hugging Face.
>
> For [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct), please make sure to use its revision/commit id `95583251e616c46a80715897a705cd38659afc27`.
### 1.2 Convert Model
Next, convert the downloaded model by [convert.py](./convert.py):
```bash
python convert.py --model-path DOWNLOAD_DIR_PATH
```
The converted model will be saved at `<DOWNLOAD_DIR_PATH>-converted`.
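For instance, if the model was downloaded to `./models/Moonlight-16B-A3B-Instruct` (an illustrative path), the conversion step would be:
```bash
python convert.py --model-path ./models/Moonlight-16B-A3B-Instruct
# the converted model will be saved to ./models/Moonlight-16B-A3B-Instruct-converted
```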
## 2. Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case of predicting the next N tokens for a Moonlight model with the `generate()` API, using IPEX-LLM INT4 optimizations on Intel GPUs.
### 2.1 Running example
```bash
python generate.py --converted-model-path <DOWNLOAD_DIR_PATH>-converted --prompt PROMPT --n-predict N_PREDICT
```
Arguments info:
- `--converted-model-path CONVERTED_MODEL_PATH`: argument defining the path to the model converted by [`convert.py`](./convert.py).
- `--prompt PROMPT`: argument defining the prompt to be inferred (the chat prompt format is applied automatically). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the maximum number of tokens to predict. It defaults to `32`.
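
For example, a complete invocation with an explicit prompt and token budget (the converted model path below is illustrative) could look like:
```bash
python generate.py --converted-model-path ./models/Moonlight-16B-A3B-Instruct-converted \
  --prompt "Is 123 a prime?" \
  --n-predict 32
```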
### 2.2 Sample Outputs
#### [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
Is 123 a prime?
-------------------- Output --------------------
<|im_system|>system<|im_middle|>You are a helpful assistant provided by Moonshot-AI.<|im_end|><|im_user|>user<|im_middle|>Is 123 a prime?<|im_end|><|im_assistant|>assistant<|im_middle|>No, 123 is not a prime number. A prime number is a number greater than 1 that has no positive divisors other than 1 and itself
```

View file

@ -0,0 +1,45 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import argparse
from safetensors.torch import load_file, save_file


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert Moonlight model to be successfully loaded by transformers')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to the downloaded Moonlight model')
    args = parser.parse_args()

    model_path = args.model_path
    converted_model_path = model_path + '-converted'

    if os.path.exists(converted_model_path):
        shutil.rmtree(converted_model_path)
    os.makedirs(converted_model_path)

    for f in os.listdir(model_path):
        f_path = os.path.join(model_path, f)
        f_dst_path = os.path.join(converted_model_path, f)
        if f.endswith(".safetensors"):
            save_file(load_file(f_path), f_dst_path, metadata={"format": "pt"})
        elif not f.startswith(".") and os.path.isfile(f_path):  # skip directories and files whose names start with '.'
            shutil.copyfile(f_path, f_dst_path)

    print(f"Converted model successfully saved to {converted_model_path}")

View file

@ -0,0 +1,40 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
from huggingface_hub import snapshot_download


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Download Moonlight model')
    parser.add_argument('--repo-id', type=str, default='moonshotai/Moonlight-16B-A3B-Instruct',
                        help='Hugging Face repo id of the model to be downloaded')
    parser.add_argument('--commit-id', type=str, required=True,
                        help='Revision of the model to be downloaded')
    parser.add_argument('--download-dir-path', type=str,
                        help='Folder path where the model will be downloaded')
    args = parser.parse_args()

    repo_id = args.repo_id
    download_dir_path = args.download_dir_path
    if download_dir_path is None:
        download_dir_path = repo_id.rsplit("/", 1)[-1]

    snapshot_download(repo_id=repo_id,
                      revision=args.commit_id,
                      local_dir=download_dir_path)

    print(f'{repo_id} has been downloaded to {download_dir_path}')

View file

@ -0,0 +1,77 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import argparse
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Moonlight model')
    parser.add_argument('--converted-model-path', type=str, required=True,
                        help='Model path to the converted Moonlight model by convert.py')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    args = parser.parse_args()

    converted_model_path = args.converted_model_path

    # Load model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(converted_model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(converted_model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        # the prompt format here follows
        # https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct#inference-with-hugging-face-transformers
        messages = [
            {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
            {"role": "user", "content": args.prompt}
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to('xpu')

        # ipex_llm model needs a warmup, then inference time can be accurate
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()

        output_str = tokenizer.decode(output[0], skip_special_tokens=False)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(args.prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)