add gpu more data types example (#9592)

* add gpu more data types example

* add int8
dingbaorong 2023-12-05 15:45:38 +08:00 committed by GitHub
parent 65934c9f4f
commit a66fbedd7e
3 changed files with 105 additions and 0 deletions

@ -0,0 +1,45 @@
# BigDL-LLM Transformers Low-Bit Inference Pipeline for Large Language Models
In this example, we show a pipeline that applies BigDL-LLM low-bit optimizations (including FP8/INT8/MixedFP8/FP4/MixedFP4) to any Hugging Face Transformers model, and then runs inference on the optimized low-bit model.
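The key step is a single extra argument to `from_pretrained`; the snippet below is a minimal sketch of that call, taken from the full example script further down (the model id and the `fp4` choice are just this example's defaults):
```python
from bigdl.llm.transformers import AutoModelForCausalLM

# Quantize the weights while loading; load_in_low_bit also accepts
# fp8, sym_int8, mixed_fp8 and mixed_fp4
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_low_bit="fp4",
                                             trust_remote_code=True)
model = model.to('xpu')  # move the low-bit model to the Intel GPU
```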
## Prepare Environment
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.9
conda activate llm
# the command below installs intel_extension_for_pytorch==2.0.110+xpu by default
# you can install a specific ipex/torch version for your needs
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
```
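Before running the example, you can optionally sanity-check that PyTorch can see the Intel GPU. A minimal sketch, assuming the IPEX XPU backend is installed as above (the `torch.xpu` namespace is provided by `intel_extension_for_pytorch`):
```python
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device with PyTorch

# assumption: torch.xpu becomes available once intel_extension_for_pytorch is imported
print(torch.xpu.is_available())      # expect True on a supported Intel GPU
print(torch.xpu.get_device_name(0))  # name of the first XPU device
```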
## Run Example
```bash
python ./transformers_low_bit_pipeline.py --repo-id-or-model-path meta-llama/Llama-2-7b-chat-hf --low-bit fp4 --save-path ./llama-2-7b-fp4
```
Arguments info:
- `--repo-id-or-model-path`: str value, the Hugging Face repo id of the large language model to be downloaded, or the path to a Hugging Face checkpoint folder. The default is `meta-llama/Llama-2-7b-chat-hf`.
- `--low-bit`: str value, one of `fp8`, `sym_int8`, `fp4`, `mixed_fp8` or `mixed_fp4`. The corresponding low-bit optimization will be applied to the model.
- `--save-path`: optional str value, the path to save the low-bit model, so that you can load it directly later (see the sketch after this list).
- `--load-path`: optional str value, the path to load a previously saved low-bit model.
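`--save-path` and `--load-path` are meant to be used together: quantize once from the original checkpoint, then reload the already-converted weights directly. A minimal sketch of that flow, using the same calls as the example script and the paths from the command above:
```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# First run: quantize from the original checkpoint and save the result
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_low_bit="fp4",
                                             trust_remote_code=True)
model.save_low_bit("./llama-2-7b-fp4")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf").save_pretrained("./llama-2-7b-fp4")

# Later runs: load the already-quantized weights without the original checkpoint
model = AutoModelForCausalLM.load_low_bit("./llama-2-7b-fp4").to('xpu')
tokenizer = AutoTokenizer.from_pretrained("./llama-2-7b-fp4")
```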
## Sample Output for Inference
### `meta-llama/Llama-2-7b-chat-hf` Model
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. But her parents were always telling her to stay at home and be careful. They were worried about her safety and didn't want her to get hurt
Model and tokenizer are saved to ./llama-2-7b-fp4
```
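Note that the pipeline's `generated_text` includes the prompt itself, and generation stops after `max_new_tokens=32` new tokens, which is why the continuation above breaks off mid-sentence.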
### Load low-bit model
Command to run:
```bash
python ./transformers_low_bit_pipeline.py --load-path ./llama-2-7b-fp4
```
Output log:
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. But her parents were always telling her to stay at home and be careful. They were worried about her safety and didn't want her to get hurt
```
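For reference, the same generation step can also be written without `TextGenerationPipeline`, using the standard Transformers `generate` API. A minimal sketch, assuming `model`, `tokenizer` and `input_str` as in the example script:
```python
import torch

# Tokenize on CPU, then move the input ids to the XPU where the model lives
input_ids = tokenizer.encode(input_str, return_tensors="pt").to('xpu')
with torch.inference_mode():
    output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```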

@ -0,0 +1,60 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import intel_extension_for_pytorch as ipex  # imported for its side effect: registers the 'xpu' device with PyTorch
import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, TextGenerationPipeline

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer save_load example')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the large language model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default="fp4",
                        choices=['fp8', 'sym_int8', 'fp4', 'mixed_fp8', 'mixed_fp4'],
                        help='The quantization type the model will be converted to.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='The path to save the low-bit model.')
    parser.add_argument('--load-path', type=str, default=None,
                        help='The path to load a previously saved low-bit model.')
    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    load_path = args.load_path
    if load_path:
        # Load a previously saved low-bit model directly, skipping quantization
        model = AutoModelForCausalLM.load_low_bit(load_path)
        model = model.to('xpu')
        tokenizer = AutoTokenizer.from_pretrained(load_path)
    else:
        # load_in_low_bit in bigdl.llm.transformers converts the relevant
        # layers in the model into the corresponding low-bit format
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True)
        model = model.to('xpu')
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=32, device="xpu")
    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
    output = pipeline(input_str)[0]["generated_text"]
    print(f"Prompt: {input_str}")
    print(f"Output: {output}")
    save_path = args.save_path
    if save_path:
        # Persist the quantized weights and tokenizer so later runs can use --load-path
        model.save_low_bit(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model and tokenizer are saved to {save_path}")