diff --git a/.github/workflows/llm_performance_tests.yml b/.github/workflows/llm_performance_tests.yml
index 273def7c..de991822 100644
--- a/.github/workflows/llm_performance_tests.yml
+++ b/.github/workflows/llm_performance_tests.yml
@@ -127,5 +127,8 @@ jobs:
cd python/llm/dev/benchmark/all-in-one
export http_proxy=${HTTP_PROXY}
export https_proxy=${HTTPS_PROXY}
- taskset -c 0-$((THREAD_NUM - 1)) python run.py
+ python run.py
curl -T ./*.csv ${LLM_FTP_URL}/llm/ggml-actions/perf/
+ cd ../../../test/benchmark
+ python csv_to_html.py -f ../../dev/benchmark/all-in-one
+ cp ./*.html /mnt/disk1/nightly_perf/
diff --git a/.github/workflows/llm_unit_tests.yml b/.github/workflows/llm_unit_tests.yml
index 47a90421..d74965d1 100644
--- a/.github/workflows/llm_unit_tests.yml
+++ b/.github/workflows/llm_unit_tests.yml
@@ -83,6 +83,7 @@ jobs:
echo "ORIGINAL_CHATGLM2_6B_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
echo "ORIGINAL_REPLIT_CODE_PATH=${ORIGIN_DIR}/replit-code-v1-3b" >> "$GITHUB_ENV"
echo "ORIGINAL_WHISPER_TINY_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"
+ echo "MISTRAL_ORIGIN_PATH=${ORIGIN_DIR}/Mistral-7B-v0.1" >> "$GITHUB_ENV"
echo "LLAMA_INT4_CKPT_PATH=${INT4_CKPT_DIR}/bigdl_llm_llama_7b_q4_0.bin" >> "$GITHUB_ENV"
echo "GPTNEOX_INT4_CKPT_PATH=${INT4_CKPT_DIR}/bigdl_llm_redpajama_7b_q4_0.bin" >> "$GITHUB_ENV"
@@ -146,6 +147,11 @@ jobs:
echo "wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR"
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
fi
+ if [ ! -d $MISTRAL_ORIGIN_PATH ]; then
+ echo "Directory $MISTRAL_ORIGIN_PATH not found. Downloading from FTP server..."
+ echo "wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/Mistral-7B-v0.1 -P $ORIGIN_DIR"
+ wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/Mistral-7B-v0.1 -P $ORIGIN_DIR
+ fi
if [ ! -d $LLAMA_ORIGIN_PATH ]; then
echo "Directory $LLAMA_ORIGIN_PATH not found. Downloading from FTP server..."
echo "wget --no-verbose $LLM_FTP_URL/llm/llama-7b-hf -P $ORIGIN_DIR"
@@ -186,3 +192,78 @@ jobs:
pip install -U pandas==2.0.3
pip install -U typing_extensions==4.5.0
bash python/llm/test/run-llm-langchain-tests.sh
+ llm-unit-test-on-arc:
+ needs: llm-cpp-build
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.9"]
+ runs-on: [self-hosted, llm, perf]
+ env:
+ OMP_NUM_THREADS: 16
+ THREAD_NUM: 16
+ ANALYTICS_ZOO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: Set environment variables
+ shell: bash
+ run: |
+ echo "LLAMA2_7B_ORIGIN_PATH=${ORIGIN_DIR}/Llama-2-7b-chat-hf" >> "$GITHUB_ENV"
+ echo "CHATGLM2_6B_ORIGIN_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
+ echo "FALCON_7B_ORIGIN_PATH=${ORIGIN_DIR}/falcon-7b-instruct-with-patch" >> "$GITHUB_ENV"
+ echo "MPT_7B_ORIGIN_PATH=${ORIGIN_DIR}/mpt-7b-chat" >> "$GITHUB_ENV"
+
+ - name: Checkout repo
+ uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ shell: bash
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install --upgrade setuptools
+ python -m pip install --upgrade wheel
+
+ - name: Download llm binary
+ uses: ./.github/actions/llm/download-llm-binary
+
+ - name: Run LLM install (all) test
+ uses: ./.github/actions/llm/setup-llm-env
+ with:
+ extra-dependency: "xpu"
+
+ - name: Test installed xpu version
+ shell: bash
+ run: |
+ source /opt/intel/oneapi/setvars.sh
+ bash python/llm/test/run-llm-install-tests.sh
+
+ - name: Download LLMs
+ shell: bash
+ run: |
+ if [ ! -d $LLAMA2_7B_ORIGIN_PATH ]; then
+ echo "Directory $LLAMA2_7B_ORIGIN_PATH not found. Downloading from FTP server..."
+ wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/Llama-2-7b-chat-hf -P $ORIGIN_DIR
+ fi
+ if [ ! -d $CHATGLM2_6B_ORIGIN_PATH ]; then
+ echo "Directory $CHATGLM2_6B_ORIGIN_PATH not found. Downloading from FTP server..."
+ wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/chatglm2-6b -P $ORIGIN_DIR
+ fi
+ if [ ! -d $FALCON_7B_ORIGIN_PATH ]; then
+ echo "Directory $FALCON_7B_ORIGIN_PATH not found. Downloading from FTP server..."
+ wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/falcon-7b-instruct-with-patch -P $ORIGIN_DIR
+ fi
+ if [ ! -d $MPT_7B_ORIGIN_PATH ]; then
+ echo "Directory $MPT_7B_ORIGIN_PATH not found. Downloading from FTP server..."
+ wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/mpt-7b-chat -P $ORIGIN_DIR
+ fi
+
+ - name: Run LLM inference test
+ shell: bash
+ run: |
+ source /opt/intel/oneapi/setvars.sh
+ python -m pip install expecttest
+ bash python/llm/test/run-llm-inference-tests-gpu.sh
diff --git a/.github/workflows/manually_build.yml b/.github/workflows/manually_build.yml
index 0601b2d0..f4728c02 100644
--- a/.github/workflows/manually_build.yml
+++ b/.github/workflows/manually_build.yml
@@ -10,6 +10,7 @@ on:
type: choice
options:
- all
+ - bigdl-llm-finetune-xpu
- bigdl-llm-xpu
- bigdl-llm-cpu
- bigdl-llm-serving-xpu
@@ -58,6 +59,33 @@ permissions:
packages: write
jobs:
+ bigdl-llm-finetune-xpu:
+ if: ${{ github.event.inputs.artifact == 'bigdl-llm-finetune-xpu' || github.event.inputs.artifact == 'all' }}
+ runs-on: [self-hosted, Shire]
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: docker login
+ run: |
+ docker login -u ${DOCKERHUB_USERNAME} -p ${DOCKERHUB_PASSWORD}
+ - name: bigdl-llm-finetune-xpu
+ run: |
+ echo "##############################################################"
+ echo "####### bigdl-llm-finetune-xpu ########"
+ echo "##############################################################"
+ export image=intelanalytics/bigdl-llm-finetune-xpu
+ cd docker/llm/finetune/qlora/xpu/docker
+ sudo docker build \
+ --no-cache=true \
+ --build-arg http_proxy=${HTTP_PROXY} \
+ --build-arg https_proxy=${HTTPS_PROXY} \
+ --build-arg no_proxy=${NO_PROXY} \
+ -t ${image}:${TAG} -f ./Dockerfile .
+ sudo docker push ${image}:${TAG}
+ sudo docker tag ${image}:${TAG} 10.239.45.10/arda/${image}:${TAG}
+ sudo docker push 10.239.45.10/arda/${image}:${TAG}
+ sudo docker rmi -f ${image}:${TAG} 10.239.45.10/arda/${image}:${TAG}
+
bigdl-llm-xpu:
if: ${{ github.event.inputs.artifact == 'bigdl-llm-xpu' || github.event.inputs.artifact == 'all' }}
runs-on: [self-hosted, Shire]
diff --git a/.github/workflows/manually_build_for_testing.yml b/.github/workflows/manually_build_for_testing.yml
index 3c4aa874..b9c8932b 100644
--- a/.github/workflows/manually_build_for_testing.yml
+++ b/.github/workflows/manually_build_for_testing.yml
@@ -14,6 +14,7 @@ on:
type: choice
options:
- all
+ - bigdl-llm-finetune-xpu
- bigdl-llm-xpu
- bigdl-llm-cpu
- bigdl-llm-serving-xpu
@@ -55,6 +56,35 @@ permissions:
packages: write
jobs:
+ bigdl-llm-finetune-xpu:
+ if: ${{ github.event.inputs.artifact == 'bigdl-llm-finetune-xpu' || github.event.inputs.artifact == 'all' }}
+ runs-on: [self-hosted, Shire]
+
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ ref: ${{ github.event.inputs.sha }}
+ - name: docker login
+ run: |
+ docker login -u ${DOCKERHUB_USERNAME} -p ${DOCKERHUB_PASSWORD}
+ - name: bigdl-llm-finetune-xpu
+ run: |
+ echo "##############################################################"
+ echo "####### bigdl-llm-finetune-xpu ########"
+ echo "##############################################################"
+ export image=intelanalytics/bigdl-llm-finetune-xpu
+ cd docker/llm/finetune/qlora/xpu/docker
+ sudo docker build \
+ --no-cache=true \
+ --build-arg http_proxy=${HTTP_PROXY} \
+ --build-arg https_proxy=${HTTPS_PROXY} \
+ --build-arg no_proxy=${NO_PROXY} \
+ -t ${image}:${TAG} -f ./Dockerfile .
+ sudo docker push ${image}:${TAG}
+ sudo docker tag ${image}:${TAG} 10.239.45.10/arda/${image}:${TAG}
+ sudo docker push 10.239.45.10/arda/${image}:${TAG}
+ sudo docker rmi -f ${image}:${TAG} 10.239.45.10/arda/${image}:${TAG}
+
bigdl-llm-xpu:
if: ${{ github.event.inputs.artifact == 'bigdl-llm-xpu' || github.event.inputs.artifact == 'all' }}
runs-on: [self-hosted, Shire]
diff --git a/README.md b/README.md
index 33aa7e00..ff02fb83 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,9 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Aquila | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/aquila) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/aquila) |
| MOSS | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/moss) | |
| Whisper | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/whisper) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/whisper) |
+| Phi-1_5 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5) |
+| Flan-t5 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5) |
+| Qwen-VL | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl) | |
***For more details, please refer to the `bigdl-llm` [Document](https://test-bigdl-llm.readthedocs.io/en/main/doc/LLM/index.html), [Readme](python/llm), [Tutorial](https://github.com/intel-analytics/bigdl-llm-tutorial) and [API Doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/LLM/index.html).***
diff --git a/python/llm/README.md b/python/llm/README.md
index 92ec5c11..6032a992 100644
--- a/python/llm/README.md
+++ b/python/llm/README.md
@@ -58,20 +58,31 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Aquila | [link](example/CPU/HF-Transformers-AutoModels/Model/aquila) | [link](example/GPU/HF-Transformers-AutoModels/Model/aquila) |
| MOSS | [link](example/CPU/HF-Transformers-AutoModels/Model/moss) | |
| Whisper | [link](example/CPU/HF-Transformers-AutoModels/Model/whisper) | [link](example/GPU/HF-Transformers-AutoModels/Model/whisper) |
-
+| Phi-1_5 | [link](example/CPU/HF-Transformers-AutoModels/Model/phi-1_5) | [link](example/GPU/HF-Transformers-AutoModels/Model/phi-1_5) |
+| Flan-t5 | [link](example/CPU/HF-Transformers-AutoModels/Model/flan-t5) | [link](example/GPU/HF-Transformers-AutoModels/Model/flan-t5) |
+| Qwen-VL | [link](example/CPU/HF-Transformers-AutoModels/Model/qwen-vl) | |
### Working with `bigdl-llm`
Table of Contents
-- [Install](#install)
-- [Run Model](#run-model)
- - [Hugging Face `transformers` API](#1-hugging-face-transformers-api)
- - [Native INT4 Model](#2-native-int4-model)
- - [LangChain API](#l3-angchain-api)
- - [CLI Tool](#4-cli-tool)
-- [`bigdl-llm` API Doc](#bigdl-llm-api-doc)
-- [`bigdl-llm` Dependency](#bigdl-llm-dependency)
+- [BigDL-LLM](#bigdl-llm)
+ - [Demos](#demos)
+ - [Verified models](#verified-models)
+ - [Working with `bigdl-llm`](#working-with-bigdl-llm)
+ - [Install](#install)
+ - [CPU](#cpu)
+ - [GPU](#gpu)
+ - [Run Model](#run-model)
+ - [1. Hugging Face `transformers` API](#1-hugging-face-transformers-api)
+ - [CPU INT4](#cpu-int4)
+ - [GPU INT4](#gpu-int4)
+ - [More Low-Bit Support](#more-low-bit-support)
+ - [2. Native INT4 model](#2-native-int4-model)
+ - [3. LangChain API](#3-langchain-api)
+ - [4. CLI Tool](#4-cli-tool)
+ - [`bigdl-llm` API Doc](#bigdl-llm-api-doc)
+ - [`bigdl-llm` Dependency](#bigdl-llm-dependency)
diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml
index 615e8325..2e57873c 100644
--- a/python/llm/dev/benchmark/all-in-one/config.yaml
+++ b/python/llm/dev/benchmark/all-in-one/config.yaml
@@ -6,6 +6,7 @@ local_model_hub: 'path to your local model hub'
warm_up: 1
num_trials: 3
num_beams: 1 # default to greedy search
+low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
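+# other low-bit formats (e.g. 'asym_int4', 'sym_int5', 'asym_int5', 'sym_int8') may
+# also be accepted here, depending on your bigdl-llm version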
in_out_pairs:
- '32-32'
- '1024-128'
diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index b72effa4..3bf48e72 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -38,19 +38,19 @@ LLAMA_IDS = ['meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-13b-chat-hf',
results = []
-def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4'):
# TODO: make a parameter
result= {}
if test_api == 'transformer_int4':
- result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
+ result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
elif test_api == 'native_int4':
run_native_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
elif test_api == 'optimize_model':
- result = run_optimize_model(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
+ result = run_optimize_model(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
elif test_api == 'transformer_int4_gpu':
- result = run_transformer_int4_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
+ result = run_transformer_int4_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
elif test_api == 'optimize_model_gpu':
- result = run_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
+ result = run_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
elif test_api == 'pytorch_autocast_bf16':
result = run_pytorch_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
elif test_api == 'ipex_fp16_gpu':
@@ -59,13 +59,14 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
for in_out_pair in in_out_pairs:
if result:
results.append([repo_id,
- np.mean(result[in_out_pair], axis=0)[0],
- np.mean(result[in_out_pair], axis=0)[1],
- np.mean(result[in_out_pair], axis=0)[2],
+ round(np.mean(result[in_out_pair], axis=0)[0]*1000.0, 2),
+ round(np.mean(result[in_out_pair], axis=0)[1]*1000.0, 2),
+ round(np.mean(result[in_out_pair], axis=0)[2]*1000.0, 2),
in_out_pair,
f'{int(np.mean(result[in_out_pair], axis=0)[3])}' +
f'-{int(np.mean(result[in_out_pair], axis=0)[4])}',
- num_beams])
+ num_beams,
+ low_bit])
def get_model_path(repo_id, local_model_hub):
@@ -123,7 +124,8 @@ def run_transformer_int4(repo_id,
in_out_pairs,
warm_up,
num_trials,
- num_beams):
+ num_beams,
+ low_bit):
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer
@@ -132,14 +134,14 @@ def run_transformer_int4(repo_id,
# which convert the relevant layers in the model into INT4 format
st = time.perf_counter()
if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
- model = AutoModel.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True, torch_dtype='auto')
+ model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype='auto')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
elif repo_id in LLAMA_IDS:
- model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True,
+ model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
use_cache=True)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
- model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True,
+ model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
end = time.perf_counter()
@@ -250,7 +252,8 @@ def run_optimize_model(repo_id,
in_out_pairs,
warm_up,
num_trials,
- num_beams):
+ num_beams,
+ low_bit):
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from bigdl.llm import optimize_model
@@ -260,16 +263,16 @@ def run_optimize_model(repo_id,
st = time.perf_counter()
if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
model = AutoModel.from_pretrained(model_path, torch_dtype='auto', low_cpu_mem_usage=True, trust_remote_code=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
use_cache=True, low_cpu_mem_usage=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', low_cpu_mem_usage=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = AutoTokenizer.from_pretrained(model_path)
end = time.perf_counter()
print(">> loading of model costs {}s".format(end - st))
@@ -317,7 +320,8 @@ def run_transformer_int4_gpu(repo_id,
in_out_pairs,
warm_up,
num_trials,
- num_beams):
+ num_beams,
+ low_bit):
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
import intel_extension_for_pytorch as ipex
@@ -326,17 +330,17 @@ def run_transformer_int4_gpu(repo_id,
# which convert the relevant layers in the model into INT4 format
st = time.perf_counter()
if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
- model = AutoModel.from_pretrained(model_path, load_in_4bit=True, optimize_model=True,
+ model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
trust_remote_code=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif repo_id in LLAMA_IDS:
- model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True,
+ model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
use_cache=True)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
else:
- model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_4bit=True,
+ model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
trust_remote_code=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
@@ -392,7 +396,8 @@ def run_optimize_model_gpu(repo_id,
in_out_pairs,
warm_up,
num_trials,
- num_beams):
+ num_beams,
+ low_bit):
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
from bigdl.llm import optimize_model
import intel_extension_for_pytorch as ipex
@@ -403,19 +408,19 @@ def run_optimize_model_gpu(repo_id,
if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
model = AutoModel.from_pretrained(model_path, torch_dtype='auto', low_cpu_mem_usage=True,
trust_remote_code=True, use_cache=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True,
use_cache=True, low_cpu_mem_usage=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
else:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', low_cpu_mem_usage=True,
trust_remote_code=True, use_cache=True)
- model = optimize_model(model)
+ model = optimize_model(model, low_bit=low_bit)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
if isinstance(model, GPTJForCausalLM):
@@ -544,8 +549,9 @@ if __name__ == '__main__':
import pandas as pd
for api in conf.test_api:
for model in conf.repo_id:
- run_model(model, api, conf['in_out_pairs'], conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'])
- df = pd.DataFrame(results, columns=['model', '1st token avg latency (s)', '2+ avg latency (s/token)', 'encoder time (s)',
- 'input/output tokens', 'actual input/output tokens', 'num_beams'])
+ run_model(model, api, conf['in_out_pairs'], conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'], conf['low_bit'])
+ df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
+ 'input/output tokens', 'actual input/output tokens', 'num_beams', 'low_bit'])
+
df.to_csv(f'{current_dir}/{api}-results-{today}.csv')
results = []
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/README.md
index 55d4c41b..f7bd3b55 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/README.md
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/README.md
@@ -24,6 +24,9 @@ You can use BigDL-LLM to run any Huggingface Transformer models with INT4 optimi
| Aquila | [link](aquila) |
| Replit | [link](replit) |
| Mistral | [link](mistral) |
+| Flan-t5 | [link](flan-t5) |
+| Phi-1_5 | [link](phi-1_5) |
+| Qwen-VL | [link](qwen-vl) |
## Recommended Requirements
To run the examples, we recommend using Intel® Xeon® processors (server), or >= 12th Gen Intel® Core™ processor (client).
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/README.md
new file mode 100644
index 00000000..0b6ab57a
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/README.md
@@ -0,0 +1,66 @@
+# Flan-t5
+
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on Flan-t5 models. For illustration purposes, we utilize the [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) as a reference Flan-t5 model.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Flan-t5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will require approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Flan-t5 model based on the capabilities of your machine.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.2 Server
+For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Flan-t5 model (e.g. `google/flan-t5-xxl`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'google/flan-t5-xxl'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'Translate to German: My name is Arthur'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+
+#### 2.4 Sample Output
+#### [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+<|User|>:Translate to German: My name is Arthur
+-------------------- Output --------------------
+Ich bin Arthur.
+```
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py
new file mode 100644
index 00000000..58d5e446
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py
@@ -0,0 +1,73 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import time
+import argparse
+import numpy as np
+
+from bigdl.llm.transformers import AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer
+
+# you could tune the prompt based on your own model,
+FLAN_T5_PROMPT_FORMAT = "<|User|>:{prompt}"
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for flan-t5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="google/flan-t5-xxl",
+ help='The huggingface repo id for the flan-t5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="Translate to German: My name is Arthur",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+    # Load model in 4-bit,
+    # which converts the relevant layers in the model into INT4 format.
+ # "wo" module is not converted due to some issues of T5 model
+ # (https://github.com/huggingface/transformers/issues/20287),
+ # "lm_head" module is not converted to generate outputs with better quality
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path,
+ load_in_4bit=True,
+ trust_remote_code=True,
+ modules_to_not_convert=["wo", "lm_head"])
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path,
+ trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = FLAN_T5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ st = time.time()
+ # if your selected model is capable of utilizing previous key/value attentions
+ # to enhance decoding speed, but has `"use_cache": false` in its model config,
+ # it is important to set `use_cache=True` explicitly in the `generate` function
+ # to obtain optimal performance with BigDL-LLM INT4 optimizations
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+ end = time.time()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md
new file mode 100644
index 00000000..013704c8
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md
@@ -0,0 +1,74 @@
+# phi-1_5
+
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on phi-1_5 models. For illustration purposes, we utilize the [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5) as a reference phi-1_5 model.
+
+> **Note**: If you want to download the Hugging Face *Transformers* model, please refer to [here](https://huggingface.co/docs/hub/models-downloading#using-git).
+>
+> BigDL-LLM optimizes the *Transformers* model in INT4 precision at runtime, and thus no explicit conversion is needed.
+
+## Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a phi-1_5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+pip install einops # additional package required for phi-1_5 to conduct generation
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will require approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the phi-1_5 model based on the capabilities of your machine.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py --prompt 'What is AI?'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.2 Server
+For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py --prompt 'What is AI?'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path`: str, argument defining the huggingface repo id for the phi-1_5 model to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'microsoft/phi-1_5'`.
+- `--prompt`: str, argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `What is AI?`.
+- `--n-predict`: int, argument defining the max number of tokens to predict. It defaults to `32`.
+
+#### 2.4 Sample Output
+#### [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+Question: What is AI?
+
+ Answer:
+-------------------- Output --------------------
+Question: What is AI?
+
+ Answer: AI stands for Artificial Intelligence, which refers to the development of computer systems that can perform tasks that typically require human intelligence, such as visual perception, speech recognition,
+
+```
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py
new file mode 100644
index 00000000..e934d99b
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py
@@ -0,0 +1,72 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import time
+import argparse
+import numpy as np
+
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+from transformers import AutoTokenizer, GenerationConfig
+
+# you could tune the prompt based on your own model,
+# here the prompt tuning refers to https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py
+PHI1_5_PROMPT_FORMAT = " Question:{prompt}\n\n Answer:"
+generation_config = GenerationConfig(use_cache = True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-1_5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/phi-1_5",
+ help='The huggingface repo id for the phi-1_5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="What is AI?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+    # Load model in 4-bit,
+    # which converts the relevant layers in the model into INT4 format
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_4bit=True,
+ trust_remote_code=True)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path,
+ trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = PHI1_5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ st = time.time()
+ # if your selected model is capable of utilizing previous key/value attentions
+ # to enhance decoding speed, but has `"use_cache": false` in its model config,
+ # it is important to set `use_cache=True` explicitly in the `generate` function
+ # to obtain optimal performance with BigDL-LLM INT4 optimizations
+
+ # Note that phi-1_5 uses GenerationConfig to enable 'use_cache'
+ output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
+
+ end = time.time()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/README.md
new file mode 100644
index 00000000..8c04bbb6
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/README.md
@@ -0,0 +1,91 @@
+# Qwen-VL
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on Qwen-VL models. For illustration purposes, we utilize the [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat) as a reference Qwen-VL model.
+
+## Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Multimodal chat using `chat()` API
+In the example [chat.py](./chat.py), we show a basic use case for a Qwen-VL model to start a multimodal chat using `chat()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+
+pip install accelerate tiktoken einops transformers_stream_generator==0.0.4 scipy torchvision pillow tensorboard matplotlib # additional package required for Qwen-VL-Chat to conduct generation
+
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./chat.py
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.2 Server
+For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./chat.py
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path`: str, argument defining the huggingface repo id for the Qwen-VL model to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'Qwen/Qwen-VL-Chat'`.
+- `--n-predict`: int, argument defining the max number of tokens to predict. It defaults to `32`.
+
+In every session, an image and a piece of text can be entered in the terminal (the user can skip an input by pressing **'Enter'**); please type **'exit'** at any time to quit the dialogue.
+
+Every output image is named after the session round and saved in the current directory.
+
+#### 2.4 Sample Chat
+#### [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
+
+```log
+-------------------- Session 1 --------------------
+ Please input a picture: https://images.unsplash.com/photo-1533738363-b7f9aef128ce?auto=format&fit=crop&q=60&w=500&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8NHx8Y2F0fGVufDB8fDB8fHwy
+ Please enter the text: 这是什么
+---------- Response ----------
+图中是一只戴着墨镜的酷炫猫咪,正坐在窗边,看着窗外。
+
+-------------------- Session 2 --------------------
+ Please input a picture:
+ Please enter the text: 这只猫猫多大了?
+---------- Response ----------
+由于只猫猫戴着太阳镜,无法判断年龄,但可以猜测它应该是一只成年猫猫,已经成年。
+
+-------------------- Session 3 --------------------
+ Please input a picture:
+ Please enter the text: 在图中检测框出猫猫的墨镜
+---------- Response ----------
+[猫猫的墨镜](398,313),(994,506)
+
+-------------------- Session 4 --------------------
+ Please input a picture: exit
+```
+The sample input image in Session 1 (fetched from [here](https://images.unsplash.com/photo-1533738363-b7f9aef128ce?auto=format&fit=crop&q=60&w=500&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8NHx8Y2F0fGVufDB8fDB8fHwy)) is:
+
+
+
+The sample output image in Session 3 is:
+
+
+
+
+
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py
new file mode 100644
index 00000000..6c017755
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+from transformers import AutoTokenizer, LlamaTokenizer
+from transformers.generation import GenerationConfig
+import torch
+import time
+import os
+import argparse
+from bigdl.llm import optimize_model
+torch.manual_seed(1234)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `chat()` API for Qwen-VL model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="Qwen/Qwen-VL-Chat",
+ help='The huggingface repo id for the Qwen-VL model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--n-predict', type=int, default=32, help='Max tokens to predict')
+
+ current_path = os.path.dirname(os.path.abspath(__file__))
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ # For successful BigDL-LLM optimization on Qwen-VL-Chat, skip the 'c_fc' and 'out_proj' modules during optimization
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_4bit=True,
+ device_map="cpu",
+ trust_remote_code=True,
+ modules_to_not_convert=['c_fc', 'out_proj'] )
+
+ # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
+ model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Session ID
+ session_id = 1
+
+ while True:
+ print('-'*20, 'Session %d' % session_id, '-'*20)
+ image_input = input(f' Please input a picture: ')
+        if image_input.lower() == 'exit': # type 'exit' to quit the dialogue
+ break
+
+ text_input = input(f' Please enter the text: ')
+        if text_input.lower() == 'exit': # type 'exit' to quit the dialogue
+ break
+
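+        # start with an empty chat history only in the first session; later sessions reuse
+        # the `history` returned by `model.chat`, so follow-up questions can refer to earlier rounds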
+ if session_id == 1:
+ history = None
+
+ all_input = [{'image': image_input}, {'text': text_input}]
+ input_list = [_input for _input in all_input if list(_input.values())[0] != '']
+
+ if len(input_list) == 0:
+ print("Input list should not be empty. Please try again with valid input.")
+ continue
+
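+        # `from_list_format` packs the image reference and the text into a single query
+        # string following Qwen-VL's multimodal input format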
+ query = tokenizer.from_list_format(input_list)
+ response, history = model.chat(tokenizer, query = query, history = history)
+
+ print('-'*10, 'Response', '-'*10)
+ print(response, '\n')
+
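+        # if the response contains bounding-box annotations, draw them on the latest input
+        # picture; `draw_bbox_on_latest_picture` returns None when there is nothing to draw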
+ image = tokenizer.draw_bbox_on_latest_picture(response, history)
+ if image is not None:
+ image.save(os.path.join(current_path, f'Session_{session_id}.png'), )
+
+ session_id += 1
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/README.md b/python/llm/example/CPU/PyTorch-Models/Model/README.md
index 268534df..40288dd4 100644
--- a/python/llm/example/CPU/PyTorch-Models/Model/README.md
+++ b/python/llm/example/CPU/PyTorch-Models/Model/README.md
@@ -10,6 +10,9 @@ You can use `optimize_model` API to accelerate general PyTorch models on Intel s
| BERT | [link](bert) |
| Bark | [link](bark) |
| Mistral | [link](mistral) |
+| Flan-t5 | [link](flan-t5) |
+| Phi-1_5 | [link](phi-1_5) |
+| Qwen-VL | [link](qwen-vl) |
## Recommended Requirements
To run the examples, we recommend using Intel® Xeon® processors (server), or >= 12th Gen Intel® Core™ processor (client).
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/README.md b/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/README.md
new file mode 100644
index 00000000..0b6ab57a
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/README.md
@@ -0,0 +1,66 @@
+# Flan-t5
+
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on Flan-t5 models. For illustration purposes, we utilize the [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) as a reference Flan-t5 model.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Flan-t5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will require approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Flan-t5 model based on the capabilities of your machine.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.2 Server
+For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Flan-t5 model (e.g. `google/flan-t5-xxl`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'google/flan-t5-xxl'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'Translate to German: My name is Arthur'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+
+#### 2.4 Sample Output
+#### [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+<|User|>:Translate to German: My name is Arthur
+-------------------- Output --------------------
+Ich bin Arthur.
+```
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/generate.py b/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/generate.py
new file mode 100644
index 00000000..51ba2500
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/flan-t5/generate.py
@@ -0,0 +1,65 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import time
+import argparse
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from bigdl.llm import optimize_model
+
+# you could tune the prompt based on your own model,
+FLAN_T5_PROMPT_FORMAT = "<|User|>:{prompt}"
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for flan-t5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="google/flan-t5-xxl",
+ help='The huggingface repo id for the flan-t5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="Translate to German: My name is Arthur",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True)
+
+    # Enable BigDL-LLM optimization on the model with only one line.
+ # "wo" module is not converted due to some issues of T5 model
+ # (https://github.com/huggingface/transformers/issues/20287),
+ # "lm_head" module is not converted to generate outputs with better quality
+ model = optimize_model(model, modules_to_not_convert=["wo", "lm_head"])
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = FLAN_T5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ st = time.time()
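+        # if your selected model has `"use_cache": false` in its model config, setting
+        # `use_cache=True` explicitly in the `generate` call may be needed to obtain
+        # optimal performance with BigDL-LLM optimizations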
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+ end = time.time()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/README.md b/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/README.md
new file mode 100644
index 00000000..f925d0d6
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/README.md
@@ -0,0 +1,60 @@
+# phi-1_5
+In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate phi-1_5 models. For illustration purposes, we utilize the [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5) as a reference phi-1_5 model.
+
+## Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a phi-1_5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+pip install einops
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py --prompt 'What is AI?'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.2 Server
+For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py --prompt 'What is AI?'
+```
+More information about arguments can be found in [Arguments Info](#23-arguments-info) section. The expected output can be found in [Sample Output](#24-sample-output) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path`: str, argument defining the huggingface repo id for the phi-1_5 model to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'microsoft/phi-1_5'`.
+- `--prompt`: str, argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict`: int, argument defining the max number of tokens to predict. It defaults to `32`.
+
+#### 2.4 Sample Output
+#### [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5)
+```log
+Inference time: xxxx s
+-------------------- Output --------------------
+Question: What is AI?
+
+Answer: AI stands for Artificial Intelligence, which refers to the development of computer systems that can perform tasks that typically require human intelligence, such as visual perception, speech recognition,
+```
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/generate.py b/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/generate.py
new file mode 100644
index 00000000..f819a47f
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/phi-1_5/generate.py
@@ -0,0 +1,61 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import time
+import argparse
+
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from bigdl.llm import optimize_model
+
+# you could tune the prompt based on your own model,
+# here the prompt tuning refers to https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py
+PHI_1_5_V1_PROMPT_FORMAT = "Question: {prompt}\n\n Answer:"
+generation_config = GenerationConfig(use_cache = True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-1_5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/phi-1_5",
+ help='The huggingface repo id for the phi-1_5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="What is AI?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+
+ # With only one line to enable BigDL-LLM optimization on model
+ model = optimize_model(model)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = PHI_1_5_V1_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ st = time.time()
+ output = model.generate(input_ids, max_new_tokens=args.n_predict, generation_config = generation_config)
+ end = time.time()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/README.md b/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/README.md
new file mode 100644
index 00000000..444929ff
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/README.md
@@ -0,0 +1,90 @@
+# Qwen-VL
+In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate Qwen-VL models. For illustration purposes, we utilize the [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat) as a reference Qwen-VL model.
+
+## Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Multimodal chat using `chat()` API
+In the example [chat.py](./chat.py), we show a basic use case for a Qwen-VL model to start a multimodal chat using `chat()` API, with BigDL-LLM 'optimize_model' API.
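+
+At its core, [chat.py](./chat.py) applies a single `optimize_model` call while keeping a couple of Qwen-VL-specific modules unconverted; the sketch below condenses the relevant loading lines (the interactive chat loop is omitted):
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+from bigdl.llm import optimize_model
+
+model_path = "Qwen/Qwen-VL-Chat"
+model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu", trust_remote_code=True)
+# skip the 'c_fc' and 'out_proj' modules so that the INT4-optimized Qwen-VL-Chat still works correctly
+model = optimize_model(model, low_bit='sym_int4', modules_to_not_convert=['c_fc', 'out_proj'])
+model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+```
+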
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
+
+pip install accelerate tiktoken einops transformers_stream_generator==0.0.4 scipy torchvision pillow tensorboard matplotlib # additional packages required for Qwen-VL-Chat to conduct generation
+
+```
+
+### 2. Run
+After setting up the Python environment, you can run the example by following the steps below.
+
+#### 2.1 Client
+On client Windows machines, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./chat.py
+```
+More information about arguments can be found in the [Arguments Info](#23-arguments-info) section. The expected output can be found in the [Sample Chat](#24-sample-chat) section.
+
+#### 2.2 Server
+For optimal performance on servers, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information) and to run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./chat.py
+```
+More information about arguments can be found in the [Arguments Info](#23-arguments-info) section. The expected output can be found in the [Sample Chat](#24-sample-chat) section.
+
+#### 2.3 Arguments Info
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path`: str, argument defining the huggingface repo id for the Qwen-VL model to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'Qwen/Qwen-VL-Chat'`.
+- `--n-predict`: int, argument defining the max number of tokens to predict. The default value is `32`.
+
+In every session, an image and a text query can be entered at the command line (either input can be skipped by pressing **'Enter'**); type **'exit'** at any time to quit the dialogue.
+
+Every output image is named after the session round and saved in the current directory.
+
+#### 2.4 Sample Chat
+#### [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
+
+```log
+-------------------- Session 1 --------------------
+ Please input a picture: https://images.unsplash.com/photo-1533738363-b7f9aef128ce?auto=format&fit=crop&q=60&w=500&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8NHx8Y2F0fGVufDB8fDB8fHwy
+ Please enter the text: 这是什么
+---------- Response ----------
+图中是一只戴着墨镜的酷炫猫咪,正坐在窗边,看着窗外。
+
+-------------------- Session 2 --------------------
+ Please input a picture:
+ Please enter the text: 这只猫猫多大了?
+---------- Response ----------
+由于只猫猫戴着太阳镜,无法判断年龄,但可以猜测它应该是一只成年猫猫,已经成年。
+
+-------------------- Session 3 --------------------
+ Please input a picture:
+ Please enter the text: 在图中检测框出猫猫的墨镜
+---------- Response ----------
+[猫猫的墨镜](398,313),(994,506)
+
+-------------------- Session 4 --------------------
+ Please input a picture: exit
+```
+
+The sample input image in Session 1 is fetched from [here](https://images.unsplash.com/photo-1533738363-b7f9aef128ce?auto=format&fit=crop&q=60&w=500&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8NHx8Y2F0fGVufDB8fDB8fHwy).
+
+The sample output image in Session 3 is the same picture with the detected bounding box drawn on it, saved as `Session_3.png` in the current directory.
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/chat.py b/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/chat.py
new file mode 100644
index 00000000..5502a697
--- /dev/null
+++ b/python/llm/example/CPU/PyTorch-Models/Model/qwen-vl/chat.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+import torch
+import time
+import os
+import argparse
+from bigdl.llm import optimize_model
+torch.manual_seed(1234)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `chat()` API for Qwen-VL model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="Qwen/Qwen-VL-Chat",
+ help='The huggingface repo id for the Qwen-VL model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--n-predict', type=int, default=32, help='Max tokens to predict')
+
+ current_path = os.path.dirname(os.path.abspath(__file__))
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu", trust_remote_code=True)
+
+ # With only one line to enable BigDL-LLM optimization on model
+ # For successful BigDL-LLM optimization on Qwen-VL-Chat, skip the 'c_fc' and 'out_proj' modules during optimization
+ model = optimize_model(model,
+ low_bit='sym_int4',
+ modules_to_not_convert=['c_fc', 'out_proj'])
+
+ # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
+ model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Session ID
+ session_id = 1
+
+ while True:
+ print('-'*20, 'Session %d' % session_id, '-'*20)
+ image_input = input(f' Please input a picture: ')
+        if image_input.lower() == 'exit':  # type 'exit' to quit the dialogue
+ break
+
+ text_input = input(f' Please enter the text: ')
+        if text_input.lower() == 'exit':  # type 'exit' to quit the dialogue
+ break
+
+ if session_id == 1:
+ history = None
+
+ all_input = [{'image': image_input}, {'text': text_input}]
+ input_list = [_input for _input in all_input if list(_input.values())[0] != '']
+
+ if len(input_list) == 0:
+ print("Input list should not be empty. Please try again with valid input.")
+ continue
+
+ query = tokenizer.from_list_format(input_list)
+ response, history = model.chat(tokenizer, query = query, history = history)
+
+ print('-'*10, 'Response', '-'*10)
+ print(response, '\n')
+
+ image = tokenizer.draw_bbox_on_latest_picture(response, history)
+ if image is not None:
+            image.save(os.path.join(current_path, f'Session_{session_id}.png'))
+
+ session_id += 1
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
new file mode 100644
index 00000000..6b1309a7
--- /dev/null
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
@@ -0,0 +1,103 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+import transformers
+import deepspeed
+
+local_rank = int(os.getenv("LOCAL_RANK", "0"))
+world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+from bigdl.llm import optimize_model
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+
+from transformers import AutoModelForCausalLM  # import AutoModelForCausalLM from transformers so that DeepSpeed uses it
+from transformers import LlamaTokenizer, AutoTokenizer
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+ help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
+ low_cpu_mem_usage=True,
+ torch_dtype=torch.float16,
+ trust_remote_code=True,
+ use_cache=True)
+
+ model = deepspeed.init_inference(
+ model,
+ mp_size=world_size,
+ dtype=torch.float16,
+ replace_method="auto",
+ )
+
+ # move model to cpu and use bigdl-llm `optimize_model` to convert the
+ # model into optimized low bit format
+ # convert the rest of the model into float16 to reduce allreduce traffic
+    model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)
+
+ # move model back to xpu
+ model = model.to(f'xpu:{local_rank}')
+
+ print(model)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
+ prompt = args.prompt
+ # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
+ # ipex model needs a warmup, then inference time can be accurate
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict,
+ use_cache=True)
+
+ # start inference
+ st = time.time()
+ # if your selected model is capable of utilizing previous key/value attentions
+ # to enhance decoding speed, but has `"use_cache": false` in its model config,
+ # it is important to set `use_cache=True` explicitly in the `generate` function
+ # to obtain optimal performance with BigDL-LLM INT4 optimizations
+ output = model.generate(input_ids,
+ do_sample=False,
+ max_new_tokens=args.n_predict)
+ torch.xpu.synchronize()
+ end = time.time()
+ if local_rank == 0:
+ output = output.cpu()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh
new file mode 100644
index 00000000..972e8c9d
--- /dev/null
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh
@@ -0,0 +1,12 @@
+source bigdl-llm-init -t -g
+export MASTER_ADDR=127.0.0.1
+export CCL_ZE_IPC_EXCHANGE=sockets
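+# give each of the 4 ranks launched below (--nproc-per-node 4) an equal share of the CPU threads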
+if [[ -n $OMP_NUM_THREADS ]]; then
+ export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4))
+else
+ export OMP_NUM_THREADS=$(($(nproc) / 4))
+fi
+torchrun --standalone \
+ --nnodes=1 \
+ --nproc-per-node 4 \
+ deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf"
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/README.md
index 6f8fc22b..ec5db97f 100644
--- a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/README.md
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/README.md
@@ -23,6 +23,7 @@ You can use BigDL-LLM to run almost every Huggingface Transformer models with IN
| Vicuna | [link](vicuna) |
| Whisper | [link](whisper) |
| Replit | [link](replit) |
+| Flan-t5 | [link](flan-t5) |
## Verified Hardware Platforms
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/README.md
new file mode 100644
index 00000000..a7d58fe5
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/README.md
@@ -0,0 +1,55 @@
+# Flan-t5
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on Flan-t5 models on [Intel GPUs](../README.md). For illustration purposes, we utilize the [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) as a reference Flan-t5 model.
+
+## 0. Requirements
+To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Flan-t5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
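+
+The key step in [generate.py](./generate.py) is loading the model directly in INT4 while leaving a few sensitive modules in their original precision; a condensed sketch of that part:
+```python
+from bigdl.llm.transformers import AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer
+
+model_path = "google/flan-t5-xxl"
+# "wo" is kept unconverted due to a T5 issue (https://github.com/huggingface/transformers/issues/20287);
+# "lm_head" is kept unconverted for better output quality
+model = AutoModelForSeq2SeqLM.from_pretrained(model_path,
+                                              load_in_4bit=True,
+                                              optimize_model=False,
+                                              trust_remote_code=True,
+                                              use_cache=True,
+                                              modules_to_not_convert=["wo", "lm_head"]).to('xpu')
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+```
+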
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
+# you can install specific ipex/torch version for your need
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+```
+
+### 2. Configures OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```bash
+python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Flan-t5 model (e.g. `google/flan-t5-xxl`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'google/flan-t5-xxl'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'Translate to German: My name is Arthur'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
+
+#### Sample Output
+#### [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+<|User|>:Translate to German: My name is Arthur
+-------------------- Output --------------------
+Ich bin Arthur.
+```
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py
new file mode 100644
index 00000000..8d6ec148
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/flan-t5/generate.py
@@ -0,0 +1,83 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+
+from bigdl.llm.transformers import AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer
+
+# you could tune the prompt based on your own model,
+FLAN_T5_PROMPT_FORMAT = "<|User|>:{prompt}"
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for flan-t5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="google/flan-t5-xxl",
+ help='The huggingface repo id for the flan-t5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="Translate to German: My name is Arthur",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+    # Load model in 4 bit,
+    # which converts the relevant layers in the model into INT4 format.
+    # The "wo" module is not converted due to a known issue with the T5 model
+    # (https://github.com/huggingface/transformers/issues/20287);
+    # the "lm_head" module is not converted so that outputs keep better quality.
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path,
+ load_in_4bit=True,
+ optimize_model=False,
+ trust_remote_code=True,
+ use_cache=True,
+ modules_to_not_convert=["wo", "lm_head"])
+ model = model.to('xpu')
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path,
+ trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = FLAN_T5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
+ # ipex model needs a warmup, then inference time can be accurate
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+
+ # start inference
+ st = time.time()
+ # if your selected model is capable of utilizing previous key/value attentions
+ # to enhance decoding speed, but has `"use_cache": false` in its model config,
+ # it is important to set `use_cache=True` explicitly in the `generate` function
+ # to obtain optimal performance with BigDL-LLM INT4 optimizations
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+ torch.xpu.synchronize()
+ end = time.time()
+ output = output.cpu()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+        output_str = output_str.split("</s>")[0]
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md
new file mode 100644
index 00000000..09d07df7
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/README.md
@@ -0,0 +1,56 @@
+# phi-1_5
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on phi-1_5 models on [Intel GPUs](../README.md). For illustration purposes, we utilize the [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5) as a reference phi-1_5 model.
+
+## 0. Requirements
+To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a phi-1_5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
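+
+The loading part of [generate.py](./generate.py) boils down to the sketch below; note that phi-1_5 relies on a `GenerationConfig` to enable the KV cache during generation:
+```python
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer, GenerationConfig
+
+model_path = "microsoft/phi-1_5"
+# load the model directly in INT4 and move it to the Intel GPU
+model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
+                                             trust_remote_code=True).to('xpu')
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+generation_config = GenerationConfig(use_cache=True)  # pass this to model.generate(...)
+```
+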
+### 1. Install
+We suggest using conda to manage environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
+# you can install specific ipex/torch version for your need
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install einops # additional package required for phi-1_5 to conduct generation
+```
+
+### 2. Configures OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```bash
+python ./generate.py --prompt 'What is AI?'
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the phi-1_5 model (e.g. `microsoft/phi-1_5`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'microsoft/phi-1_5'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
+
+#### Sample Output
+#### [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+Question: What is AI?
+
+ Answer:
+-------------------- Output --------------------
+Question: What is AI?
+
+ Answer: AI stands for Artificial Intelligence, which refers to the development of computer systems that can perform tasks that typically require human intelligence, such as visual perception, speech recognition,
+```
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py
new file mode 100644
index 00000000..2108654d
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/phi-1_5/generate.py
@@ -0,0 +1,82 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+import numpy as np
+
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+from transformers import AutoTokenizer, GenerationConfig
+
+# you could tune the prompt based on your own model,
+# here the prompt format refers to https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py
+PHI1_5_PROMPT_FORMAT = " Question:{prompt}\n\n Answer:"
+generation_config = GenerationConfig(use_cache = True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-1_5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/phi-1_5",
+ help='The huggingface repo id for the phi-1_5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="What is AI?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model in 4 bit,
+ # which convert the relevant layers in the model into INT4 format
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_4bit=True,
+ trust_remote_code=True)
+
+ model = model.to('xpu')
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path,
+ trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = PHI1_5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
+
+ # ipex model needs a warmup, then inference time can be accurate
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict,
+ generation_config = generation_config)
+ # start inference
+ st = time.time()
+ # if your selected model is capable of utilizing previous key/value attentions
+ # to enhance decoding speed, but has `"use_cache": false` in its model config,
+ # it is important to set `use_cache=True` explicitly in the `generate` function
+ # to obtain optimal performance with BigDL-LLM INT4 optimizations
+
+ # Note that phi-1_5 uses GenerationConfig to enable 'use_cache'
+ output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
+ torch.xpu.synchronize()
+ end = time.time()
+ output = output.cpu()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/README.md b/python/llm/example/GPU/PyTorch-Models/Model/README.md
index ad74542c..41beaaee 100644
--- a/python/llm/example/GPU/PyTorch-Models/Model/README.md
+++ b/python/llm/example/GPU/PyTorch-Models/Model/README.md
@@ -13,6 +13,7 @@ You can use `optimize_model` API to accelerate general PyTorch models on Intel G
| StarCoder | [link](starcoder) |
| Dolly v1 | [link](dolly-v1) |
| Dolly v2 | [link](dolly-v2) |
+| Flan-t5 | [link](flan-t5) |
## Verified Hardware Platforms
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/README.md b/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/README.md
new file mode 100644
index 00000000..a7d58fe5
--- /dev/null
+++ b/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/README.md
@@ -0,0 +1,55 @@
+# Flan-t5
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on Flan-t5 models on [Intel GPUs](../README.md). For illustration purposes, we utilize the [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) as a reference Flan-t5 model.
+
+## 0. Requirements
+To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Flan-t5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
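+
+In [generate.py](./generate.py), the general PyTorch model is first loaded with Hugging Face `transformers`, then optimized with `optimize_model` and moved to the Intel GPU; a condensed sketch of the loading code:
+```python
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from bigdl.llm import optimize_model
+
+model_path = "google/flan-t5-xxl"
+model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True,
+                                              torch_dtype='auto', low_cpu_mem_usage=True)
+# keep "wo" and "lm_head" in their original precision and convert the rest to low bit
+model = optimize_model(model, modules_to_not_convert=["wo", "lm_head"]).to('xpu')
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+```
+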
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
+# you can install specific ipex/torch version for your need
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+```
+
+### 2. Configures OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```bash
+python ./generate.py --prompt 'Translate to German: My name is Arthur'
+```
+
+In the example, several arguments can be passed to satisfy your requirements:
+
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Flan-t5 model (e.g. `google/flan-t5-xxl`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'google/flan-t5-xxl'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'Translate to German: My name is Arthur'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
+
+#### Sample Output
+#### [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+<|User|>:Translate to German: My name is Arthur
+-------------------- Output --------------------
+Ich bin Arthur.
+```
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/generate.py b/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/generate.py
new file mode 100644
index 00000000..9dfbabc8
--- /dev/null
+++ b/python/llm/example/GPU/PyTorch-Models/Model/flan-t5/generate.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from bigdl.llm import optimize_model
+
+# you could tune the prompt based on your own model,
+FLAN_T5_PROMPT_FORMAT = "<|User|>:{prompt}"
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for flan-t5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="google/flan-t5-xxl",
+ help='The huggingface repo id for the flan-t5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="Translate to German: My name is Arthur",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path,
+ trust_remote_code=True,
+ torch_dtype='auto',
+ low_cpu_mem_usage=True)
+
+    # With only one line to enable BigDL-LLM optimization on the model.
+    # The "wo" module is not converted due to a known issue with the T5 model
+    # (https://github.com/huggingface/transformers/issues/20287);
+    # the "lm_head" module is not converted so that outputs keep better quality.
+ model = optimize_model(model, modules_to_not_convert=["wo", "lm_head"])
+
+ model = model.to('xpu')
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = FLAN_T5_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
+ # ipex model needs a warmup, then inference time can be accurate
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+
+ # start inference
+ st = time.time()
+ output = model.generate(input_ids,
+ max_new_tokens=args.n_predict)
+ torch.xpu.synchronize()
+ end = time.time()
+ output = output.cpu()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/README.md b/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/README.md
new file mode 100644
index 00000000..0c1d0a76
--- /dev/null
+++ b/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/README.md
@@ -0,0 +1,53 @@
+# phi-1_5
+In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate phi-1_5 models on [Intel GPUs](../README.md). For illustration purposes, we utilize the [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5) as a reference phi-1_5 model.
+
+## Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a phi-1_5 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
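+
+In [generate.py](./generate.py), the original Hugging Face model is wrapped with a single `optimize_model` call and then moved to the Intel GPU; a condensed sketch:
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from bigdl.llm import optimize_model
+
+model_path = "microsoft/phi-1_5"
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+# one extra line to enable BigDL-LLM low-bit optimization, then move the model to the GPU
+model = optimize_model(model).to('xpu')
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+```
+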
+### 1. Install
+We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
+
+After installing conda, create a Python environment for BigDL-LLM:
+```bash
+conda create -n llm python=3.9 # recommend to use Python 3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install einops # additional package required for phi-1_5 to conduct generation
+```
+
+### 2. Configures OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```bash
+python ./generate.py --prompt 'What is AI?'
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the phi-1_5 model (e.g. `microsoft/phi-1_5`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'microsoft/phi-1_5'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
+
+#### Sample Output
+#### [microsoft/phi-1_5](https://huggingface.co/microsoft/phi-1_5)
+
+```log
+Inference time: xxxx s
+-------------------- Output --------------------
+Question: What is AI?
+
+Answer: AI stands for Artificial Intelligence, which refers to the development of computer systems that can perform tasks that typically require human intelligence, such as visual perception, speech recognition,
+```
\ No newline at end of file
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/generate.py b/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/generate.py
new file mode 100644
index 00000000..f5f2bae0
--- /dev/null
+++ b/python/llm/example/GPU/PyTorch-Models/Model/phi-1_5/generate.py
@@ -0,0 +1,72 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from bigdl.llm import optimize_model
+
+# you could tune the prompt based on your own model,
+# here the prompt tuning refers to https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py
+PHI_1_5_V1_PROMPT_FORMAT = "Question: {prompt}\n\n Answer:"
+generation_config = GenerationConfig(use_cache = True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-1_5 model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/phi-1_5",
+ help='The huggingface repo id for the phi-1_5 model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--prompt', type=str, default="What is AI?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
+ # Load model
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+
+ # With only one line to enable BigDL-LLM optimization on model
+ model = optimize_model(model)
+ model = model.to('xpu')
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ prompt = PHI_1_5_V1_PROMPT_FORMAT.format(prompt=args.prompt)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
+
+ # ipex model needs a warmup, then inference time can be accurate
+ output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
+ # start inference
+ st = time.time()
+ # Note that phi-1_5 uses GenerationConfig to enable 'use_cache'
+ output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
+ torch.xpu.synchronize()
+ end = time.time()
+ output = output.cpu()
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+ print(f'Inference time: {end-st} s')
+ print('-'*20, 'Prompt', '-'*20)
+ print(prompt)
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index c162c694..4f8891e9 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -33,7 +33,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"nf4": 10,
"nf3": 11,
"fp16": 12,
- "fp8": 15}
+ "fp8": 15,
+ "fp4": 16}
_llama_quantize_type = {"q4_0": 2,
"q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
index 6c8cc780..c4716b5e 100644
--- a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
+++ b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
@@ -104,7 +104,7 @@ def load_model(
device, load_8bit, cpu_offloading
)
if device == "cpu":
- kwargs = {"torch_dtype": torch.float32}
+ kwargs = {"torch_dtype": "auto"}
if CPU_ISA in ["avx512_bf16", "amx"]:
try:
import intel_extension_for_pytorch as ipex
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index c1902d35..2acab799 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -45,6 +45,42 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
+def is_deepspeed_available():
+ return importlib.util.find_spec("deepspeed") is not None
+
+
+def is_linear_module(module):
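+    # Besides torch.nn.Linear, DeepSpeed's tensor-parallel LinearLayer / LinearAllreduce
+    # layers are also recognized; returns (is_linear, (in_features, out_features, mp_group)).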
+
+ in_features = None
+ out_features = None
+ mp_group = None
+
+ if isinstance(module, nn.Linear):
+ in_features = module.in_features
+ out_features = module.out_features
+ mp_group = None
+ result = True
+ else:
+ if is_deepspeed_available():
+ from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce
+ if isinstance(module, LinearLayer):
+ in_features = module.weight.shape[1]
+ out_features = module.weight.shape[0]
+ mp_group = None
+ result = True
+ elif isinstance(module, LinearAllreduce):
+ in_features = module.weight.shape[1]
+ out_features = module.weight.shape[0]
+ mp_group = module.mp_group
+ result = True
+ else:
+ result = False
+ else:
+ result = False
+
+ return result, (in_features, out_features, mp_group)
+
+
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
@@ -54,17 +90,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if current_key_name is None:
current_key_name = []
- if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+ is_linear, linear_args = is_linear_module(module)
+ if is_linear and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+ in_features, out_features, mp_group = linear_args
with init_empty_weights():
new_linear = None
if qtype != ggml_tensor_qtype["fp16"]:
new_linear = LowBitLinear(
- module.in_features,
- module.out_features,
+ in_features,
+ out_features,
qtype,
module.bias is not None,
+ mp_group=mp_group,
)
device_type = module.weight.data.device.type
@@ -82,10 +121,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if module.in_features in [4096, 11008]:
# esimd fp16 path
new_linear = FP16Linear(
- module.in_features,
- module.out_features,
+ in_features,
+ out_features,
qtype,
module.bias is not None,
+ mp_group=mp_group,
)
device_type = module.weight.data.device.type
@@ -226,6 +266,7 @@ def _optimize_post(model):
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+ from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
convert_forward(model,
module.SelfAttention,
chatglm2_attention_forward_8eb45c
@@ -233,6 +274,9 @@ def _optimize_post(model):
convert_forward(model,
module.CoreAttention,
core_attn_forward_8eb45c)
+ convert_forward(model,
+ module.RMSNorm,
+ chatglm_rms_norm_forward)
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
# chatglm-6b
modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index b1026ec7..f9bac244 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -65,6 +65,7 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
NF4 = ggml_tensor_qtype["nf4"]
NF3 = ggml_tensor_qtype["nf3"]
FP8 = ggml_tensor_qtype["fp8"]
+FP4 = ggml_tensor_qtype["fp4"]
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -108,7 +109,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
src = ctypes.c_void_p(tensor.data.data_ptr())
- if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
+ if qtype in [SYM_INT4, SYM_INT8, NF4, NF3, FP4]:
dst_tensor = torch.empty_like(tensor)
elif qtype == ggml_tensor_qtype["sym_int5"]:
QK = ggml.ggml_qk_size(qtype)
@@ -133,7 +134,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
src = ctypes.c_void_p(tensor.data.data_ptr())
- if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
+ if qtype in [SYM_INT4, SYM_INT8, NF4, NF3, FP4]:
dst_tensor = torch.empty_like(tensor)
elif qtype == ggml_tensor_qtype["sym_int5"]:
QK = ggml.ggml_qk_size(ggml_tensor_qtype["asym_int5"])
@@ -328,7 +329,7 @@ class MatMulLowBit(torch.autograd.Function):
class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
- conver_to_half=True):
+ conver_to_half=True, mp_group=None):
super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data,
requires_grad=False,
@@ -339,6 +340,7 @@ class LowBitLinear(nn.Linear):
self.weight_length = self.out_len * self.in_len
self.qtype = qtype
self.conver_to_half = conver_to_half
+ self.mp_group = mp_group
def forward(self, x: torch.Tensor):
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -378,13 +380,18 @@ class LowBitLinear(nn.Linear):
input_seq_size)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
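+            # when running with DeepSpeed tensor parallelism, sum the partial results across ranks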
+ if self.mp_group is not None:
+ from deepspeed import comm as dist
+ dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
else:
# CPU logic
# todo may need to set a different number on different platforms
- invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8,
- "NF3, NF4 and FP8 quantization are currently not supported on CPU")
+ invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8
+ and self.qtype != FP4,
+ "NF3, NF4, FP4 and FP8 quantization are currently not"
+ " supported on CPU")
if IS_SERVER and (not IS_SPR) and \
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
@@ -400,7 +407,7 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
- conver_to_half=True):
+ conver_to_half=True, mp_group=None):
super().__init__(input_features, output_features, bias)
self.in_len = input_features
self.out_len = output_features
@@ -408,6 +415,7 @@ class FP16Linear(nn.Linear):
self.weight_length = self.out_len * self.in_len
self.qtype = qtype
self.conver_to_half = conver_to_half
+ self.mp_group = mp_group
def forward(self, x: torch.Tensor):
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -442,6 +450,9 @@ class FP16Linear(nn.Linear):
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
+ if self.mp_group is not None:
+ from deepspeed import comm as dist
+ dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 3f34c45f..e498987f 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -60,9 +60,10 @@ class _BaseAutoModelClass:
:param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
Default to be False.
:param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
- , sym_int8, nf3, nf4 or fp16. sym_int4 means symmetric int 4,
- asym_int4 means asymmetric int 4, nf4 means 4-bit NormalFloat, etc.
- Relevant low bit optimizations will be applied to the model.
+ , sym_int8, nf3, nf4, fp4, fp8 or fp16. sym_int4 means symmetric
+ int 4, asym_int4 means asymmetric int 4, nf4 means 4-bit
+ NormalFloat, etc. Relevant low bit optimizations will be applied
+ to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be True.
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@@ -106,8 +107,8 @@ class _BaseAutoModelClass:
from .convert import ggml_convert_low_bit
invalidInputError(q_k in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:"
- f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4 "
- "or fp16.")
+ f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
+ "fp4, fp8 or fp16.")
qtype = ggml_tensor_qtype[q_k]
# In case it needs a second try,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 7dc90f86..fa54ea3e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
return torch.cat((x_out2, x_pass), dim=-1)
+def chatglm_rms_norm_forward(self, hidden_states):
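+    # use IPEX's fused rms_norm kernel on XPU during inference;
+    # otherwise fall back to the reference float32 implementation below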
+ if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+ hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+ [self.weight.size(0)], self.weight)
+ else:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ return self.weight * hidden_states.to(input_dtype)
+ return hidden_states
+
+
def chatglm2_attention_forward_8eb45c(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 0dd39ae6..94515ea0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -32,6 +32,7 @@
# limitations under the License.
import torch
+import importlib
import torch.nn as nn
from typing import Optional, Tuple
import math
@@ -58,10 +59,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+def get_ipex_version():
+
+ if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
+ import intel_extension_for_pytorch as ipex
+ return ipex.__version__
+ else:
+ return None
+
+
+ipex_version = get_ipex_version()
+
+
def llama_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
- hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
- [self.weight.size(0)], self.weight)
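+        # on IPEX versions other than 2.0.110+xpu, the fused rms_norm op takes an explicit epsilon argument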
+ if ipex_version == "2.0.110+xpu":
+ hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+ [self.weight.size(0)], self.weight)
+ else:
+ hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+ [self.weight.size(0)], self.weight,
+ self.variance_epsilon)
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
diff --git a/python/llm/test/benchmark/arc-perf-test.yaml b/python/llm/test/benchmark/arc-perf-test.yaml
index 9b5246bd..580adc50 100644
--- a/python/llm/test/benchmark/arc-perf-test.yaml
+++ b/python/llm/test/benchmark/arc-perf-test.yaml
@@ -8,6 +8,7 @@ local_model_hub: '/mnt/disk1/models'
warm_up: 1
num_trials: 3
num_beams: 1 # default to greedy search
+low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
in_out_pairs:
- '32-32'
- '1024-128'
diff --git a/python/llm/test/benchmark/csv_to_html.py b/python/llm/test/benchmark/csv_to_html.py
new file mode 100644
index 00000000..bc61fd5b
--- /dev/null
+++ b/python/llm/test/benchmark/csv_to_html.py
@@ -0,0 +1,39 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Python program to convert CSV to HTML Table
+
+import os
+import sys
+import argparse
+import pandas as pd
+
+def main():
+ parser = argparse.ArgumentParser(description="convert .csv file to .html file")
+ parser.add_argument("-f", "--folder_path", type=str, dest="folder_path",
+ help="The directory which stores the .csv file", default="../../dev/benchmark/all-in-one")
+ args = parser.parse_args()
+
+ csv_files = []
+ for file_name in os.listdir(args.folder_path):
+ file_path = os.path.join(args.folder_path, file_name)
+ if os.path.isfile(file_path) and file_name.endswith(".csv"):
+ csv_files.append(file_path)
+
+    # convert each collected .csv file into an .html table with the same base name
+    for csv_file in csv_files:
+        html_name = os.path.splitext(os.path.basename(csv_file))[0] + ".html"
+        pd.read_csv(csv_file, index_col=0).to_html(html_name)
+
+if __name__ == "__main__":
+ sys.exit(main())
\ No newline at end of file
diff --git a/python/llm/test/inference/test_optimize_mistral.py b/python/llm/test/inference/test_optimize_mistral.py
new file mode 100644
index 00000000..ce2d4325
--- /dev/null
+++ b/python/llm/test/inference/test_optimize_mistral.py
@@ -0,0 +1,53 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import pytest
+
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+
+
+mistral_model_path = os.environ.get('MISTRAL_ORIGIN_PATH')
+
+prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
+
+@pytest.mark.parametrize("Model, Tokenizer, model_path, prompt", [
+ (AutoModelForCausalLM, AutoTokenizer, mistral_model_path, prompt)
+])
+
+def test_optimize_model(Model, Tokenizer, model_path, prompt):
+ tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+ model = Model.from_pretrained(model_path,
+ load_in_4bit=True,
+ optimize_model=False,
+ trust_remote_code=True)
+ logits_base_model = (model(input_ids)).logits
+
+ model = Model.from_pretrained(model_path,
+ load_in_4bit=True,
+ optimize_model=True,
+ trust_remote_code=True)
+ logits_optimized_model = (model(input_ids)).logits
+ diff = abs(logits_base_model - logits_optimized_model).flatten()
+
+ assert any(diff) is False
+
+
+if __name__ == '__main__':
+ pytest.main([__file__])
diff --git a/python/llm/test/inference/test_optimize_model.py b/python/llm/test/inference/test_optimize_model.py
index 6d593e72..1d101027 100644
--- a/python/llm/test/inference/test_optimize_model.py
+++ b/python/llm/test/inference/test_optimize_model.py
@@ -14,8 +14,8 @@
# limitations under the License.
#
-import pytest
import os
+import pytest
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizer, AutoTokenizer
@@ -32,8 +32,9 @@ prompt = "Once upon a time, there existed a little girl who liked to have advent
(AutoModelForCausalLM, LlamaTokenizer, llama_model_path, prompt),
(AutoModelForCausalLM, AutoTokenizer, bloom_model_path, prompt),
(AutoModel, AutoTokenizer, chatglm2_6b_model_path, prompt),
- (AutoModelForCausalLM, AutoTokenizer, replit_code_model_path, prompt),
+ (AutoModelForCausalLM, AutoTokenizer, replit_code_model_path, prompt)
])
+
def test_optimize_model(Model, Tokenizer, model_path, prompt):
tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer.encode(prompt, return_tensors="pt")
diff --git a/python/llm/test/inference_gpu/test_transformers_api.py b/python/llm/test/inference_gpu/test_transformers_api.py
new file mode 100644
index 00000000..69f2578d
--- /dev/null
+++ b/python/llm/test/inference_gpu/test_transformers_api.py
@@ -0,0 +1,54 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import os
+import pytest
+
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from transformers import LlamaTokenizer, AutoTokenizer
+
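+# DEVICE is exported by run-llm-inference-tests-gpu.sh ('xpu' for Intel GPU)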
+device = os.environ['DEVICE']
+print(f'Running on {device}')
+if device == 'xpu':
+    import intel_extension_for_pytorch as ipex  # noqa: F401 - imported for its side effect of enabling the 'xpu' device
+
+@pytest.mark.parametrize('prompt, answer', [
+    ('What is the capital of France?\n\n', 'Paris')
+])
+@pytest.mark.parametrize('Model, Tokenizer, model_path', [
+    (AutoModelForCausalLM, LlamaTokenizer, os.environ.get('LLAMA2_7B_ORIGIN_PATH')),
+    (AutoModel, AutoTokenizer, os.environ.get('CHATGLM2_6B_ORIGIN_PATH')),
+    (AutoModelForCausalLM, AutoTokenizer, os.environ.get('FALCON_7B_ORIGIN_PATH')),
+])
+def test_completion(Model, Tokenizer, model_path, prompt, answer):
+ tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
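+    # Load the 4-bit optimized model, then move it to the target device before generation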
+ model = Model.from_pretrained(model_path,
+ load_in_4bit=True,
+ optimize_model=True,
+ trust_remote_code=True)
+ model = model.to(device)
+
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+ output = model.generate(input_ids, max_new_tokens=32)
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+
+ assert answer in output_str
+
+if __name__ == '__main__':
+ pytest.main([__file__])
diff --git a/python/llm/test/run-llm-inference-tests-gpu.sh b/python/llm/test/run-llm-inference-tests-gpu.sh
new file mode 100644
index 00000000..3d22c1cf
--- /dev/null
+++ b/python/llm/test/run-llm-inference-tests-gpu.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+export ANALYTICS_ZOO_ROOT=${ANALYTICS_ZOO_ROOT}
+export LLM_HOME=${ANALYTICS_ZOO_ROOT}/python/llm/src
+export LLM_INFERENCE_TEST_DIR=${ANALYTICS_ZOO_ROOT}/python/llm/test/inference_gpu
+
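+# Runtime settings for Intel GPU (XPU) execution; DEVICE is read by the tests under inference_gpu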
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+export DEVICE='xpu'
+
+set -e
+
+echo "# Start testing inference"
+start=$(date "+%s")
+
+if [ -z "$THREAD_NUM" ]; then
+ THREAD_NUM=2
+fi
+export OMP_NUM_THREADS=$THREAD_NUM
+pytest ${LLM_INFERENCE_TEST_DIR} -v -s
+
+now=$(date "+%s")
+time=$((now-start))
+
+echo "Bigdl-llm gpu tests finished"
+echo "Time used:$time seconds"
diff --git a/python/llm/test/run-llm-inference-tests.sh b/python/llm/test/run-llm-inference-tests.sh
index 1c670e6b..976ea085 100644
--- a/python/llm/test/run-llm-inference-tests.sh
+++ b/python/llm/test/run-llm-inference-tests.sh
@@ -9,13 +9,20 @@ set -e
echo "# Start testing inference"
start=$(date "+%s")
-python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers" -v
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers" -v \
+ --ignore=${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py
if [ -z "$THREAD_NUM" ]; then
THREAD_NUM=2
fi
export OMP_NUM_THREADS=$THREAD_NUM
-python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers -v
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers -v \
+ --ignore=${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py
+
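+# The Mistral test needs a newer transformers release (4.34.0); install it just for this test, then restore 4.31.0 afterwards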
+python -m pip install transformers==4.34.0
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py -v
+python -m pip install transformers==4.31.0
now=$(date "+%s")
time=$((now-start))