[LLM] Add performance tests for Windows iGPU (#9584)

* Add support for win gpu benchmark with peak gpu memory monitoring

* Add win igpu tests

* Small fix

* Forward outputs

* Small fix

* Test and small fixes

* Small fix

* Small fix and test

* Small fixes

* Add tests for 512-64 and change back to nightly tests

* Small fix
Yuwen Hu 2023-12-04 20:50:02 +08:00 committed by GitHub
parent 29d5bb8df4
commit 3f4ad97929
5 changed files with 293 additions and 5 deletions


@@ -206,3 +206,129 @@ jobs:
if [ ${{ github.event.schedule}} ]; then
curl -T ./*.csv ${LLM_FTP_URL}/llm/nightly_perf/core_${{ matrix.platform }}/
fi
llm-performance-test-on-igpu:
if: ${{ github.event.schedule || github.event.inputs.artifact == 'llm-performance-test-on-igpu' || github.event.inputs.artifact == 'all' }}
needs: llm-cpp-build
strategy:
fail-fast: false
matrix:
include:
- os: windows
python-version: "3.9"
runs-on: [self-hosted, "${{ matrix.os }}", llm, perf-igpu]
env:
ANALYTICS_ZOO_ROOT: ${{ github.workspace }}
steps:
- uses: actions/checkout@v3
# TODO: Put the bigdl-llm related install process for win gpu into an action function
- name: Download llm binary
uses: ./.github/actions/llm/download-llm-binary
- name: Prepare to install bigdl-llm from source
shell: bash
run: |
sed -i 's/"bigdl-core-xe==" + VERSION + "/"bigdl-core-xe/g' python/llm/setup.py
- name: Install bigdl-llm and other related packages
shell: cmd
run: |
call conda create -n igpu-perf python=${{ matrix.python-version }} libuv -y
call conda activate igpu-perf
pip install --upgrade pip
pip install --upgrade wheel
pip install --upgrade omegaconf pandas
pip install --upgrade tiktoken einops transformers_stream_generator
cd python\llm
python setup.py clean --all bdist_wheel --win
if not exist dist\bigdl_llm*.whl (exit /b 1)
for %%i in (dist\bigdl_llm*.whl) do set whl_name=%%i
pip install %whl_name%[xpu] -i %INTERNAL_PYPI_URL% --trusted-host %INTERNAL_PYPI_TRUSTED_HOST% -q
if %ERRORLEVEL% neq 0 (exit /b 1)
call conda deactivate
- name: Set directory envs
shell: bash
run: |
if [ ${{ github.event_name }} == 'schedule' ]; then
echo "CSV_SAVE_PATH=${CSV_NIGHTLY_PATH}" >> "$GITHUB_ENV"
else
echo "CSV_SAVE_PATH=${CSV_PR_PATH}" >> "$GITHUB_ENV"
fi
cur_date=$(date +%Y-%m-%d)
echo "LOG_FILE=${cur_date}_output.txt" >> "$GITHUB_ENV"
- name: Prepare igpu perf test
shell: bash
run: |
# hide time info
sed -i 's/str(end - st)/"xxxxxx"/g' python/llm/dev/benchmark/all-in-one/run.py
sed -i 's/{today}/{today}_test1/g' python/llm/dev/benchmark/all-in-one/run.py
sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf-test.yaml
- name: Test on igpu
shell: cmd
run: |
call conda activate igpu-perf
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
set SYCL_ENABLE_DEFAULT_CONTEXTS=1
set SYCL_CACHE_PERSISTENT=1
REM for llava
set TRANSFORMERS_OFFLINE=1
move python\llm\test\benchmark\igpu-perf-test.yaml python\llm\dev\benchmark\all-in-one\config.yaml
cd python\llm\dev\benchmark\all-in-one
python run.py >> %LOG_FILE% 2>&1
if %ERRORLEVEL% neq 0 (exit /b 1)
call conda deactivate
- name: Prepare igpu perf test for Mistral
shell: bash
run: |
sed -i 's/test1/test2/g' python/llm/dev/benchmark/all-in-one/run.py
sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf-test-434.yaml
- name: Test on igpu for Mistral
shell: cmd
run: |
call conda activate igpu-perf
pip install transformers==4.34.0
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
set SYCL_ENABLE_DEFAULT_CONTEXTS=1
set SYCL_CACHE_PERSISTENT=1
move python\llm\test\benchmark\igpu-perf-test-434.yaml python\llm\dev\benchmark\all-in-one\config.yaml
cd python\llm\dev\benchmark\all-in-one
python run.py >> %LOG_FILE% 2>&1
if %ERRORLEVEL% neq 0 (exit /b 1)
call conda deactivate
- name: Concat csv and generate html
shell: cmd
run: |
call conda activate igpu-perf
cd python\llm\dev\benchmark\all-in-one
move %LOG_FILE% %CSV_SAVE_PATH%\log\
python ..\..\..\test\benchmark\concat_csv.py
copy *.csv %CSV_SAVE_PATH%
del /q *.csv
cd ..\..\..\test\benchmark
python csv_to_html.py -f %CSV_SAVE_PATH%
if %ERRORLEVEL% neq 0 (exit /b 1)
call conda deactivate
- name: Remove conda env
if: ${{ always() }}
shell: cmd
run: |
call conda env remove -n igpu-perf -y
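
The "Concat csv and generate html" step above merges the per-run CSVs with python/llm/test/benchmark/concat_csv.py and renders them with csv_to_html.py. As a rough sketch of what that merge step amounts to (assuming the helper simply globs the *.csv files in the working directory; the committed script and its output filename may differ):

import glob
import pandas as pd

# Collect every per-run CSV produced by run.py in the current directory
csv_files = sorted(glob.glob("*.csv"))
merged = pd.concat((pd.read_csv(f, index_col=0) for f in csv_files), ignore_index=True)

# Write one combined CSV that csv_to_html.py can turn into an HTML report
merged.to_csv("merged-results.csv")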

python/llm/dev/benchmark/all-in-one/config.yaml

@@ -2,6 +2,7 @@ repo_id:
- 'THUDM/chatglm-6b'
- 'THUDM/chatglm2-6b'
- 'meta-llama/Llama-2-7b-chat-hf'
# - 'liuhaotian/llava-v1.5-7b' # requires a LLAVA_REPO_DIR env variable pointing to the llava dir; added only for gpu win related test_api now
local_model_hub: 'path to your local model hub'
warm_up: 1
num_trials: 3
@@ -19,4 +20,5 @@ test_api:
# - "transformer_int4_gpu" # on Intel GPU
# - "optimize_model_gpu" # on Intel GPU
# - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
# - "transformer_int4_gpu_win" # on Intel GPU for Windows (captures GPU peak memory)
cpu_embedding: False # whether to put embeddings on CPU (currently only available for gpu win related test_api)
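
For context, the new cpu_embedding option asks the loader to keep the embedding table on the CPU while the rest of the model runs on the iGPU, which can lower GPU peak memory on iGPUs that share system memory. A minimal sketch of the idea (not the bigdl-llm implementation; the CPUEmbedding wrapper below is hypothetical):

import torch
import torch.nn as nn

class CPUEmbedding(nn.Module):
    # Illustrative wrapper: keep the embedding weights on the CPU and move only
    # the looked-up vectors to the accelerator ('xpu' assumes intel_extension_for_pytorch is loaded).
    def __init__(self, embedding: nn.Embedding, device: str = "xpu"):
        super().__init__()
        self.embedding = embedding.cpu()
        self.device = device

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        vectors = self.embedding(input_ids.cpu())
        return vectors.to(self.device)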

python/llm/dev/benchmark/all-in-one/run.py

@@ -18,6 +18,7 @@
# this code is copied from llama2 example test, and added performance test
import torch
import time
import gc
import numpy as np
from datetime import date
@@ -37,10 +38,12 @@ LLAMA_IDS = ['meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-13b-chat-hf',
CHATGLM_IDS = ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b', 'THUDM/chatglm3-6b']
LLAVA_IDS = ['liuhaotian/llava-v1.5-7b']
results = []
def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False):
# TODO: make a parameter
result= {}
if test_api == 'transformer_int4':
@@ -59,6 +62,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
result = run_ipex_fp16_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
elif test_api == 'deepspeed_transformer_int4_cpu':
result = run_deepspeed_transformer_int4_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
elif test_api == 'transformer_int4_gpu_win':
result = run_transformer_int4_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding)
for in_out_pair in in_out_pairs:
if result and result[in_out_pair]:
@@ -70,7 +75,9 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
f'{int(np.mean(result[in_out_pair], axis=0)[3])}' +
f'-{int(np.mean(result[in_out_pair], axis=0)[4])}',
num_beams,
low_bit,
cpu_embedding if 'win' in test_api else 'N/A',
result[in_out_pair][-1][5] if 'win' in test_api else 'N/A']) # currently peak GPU memory is only captured for the win GPU test
def get_model_path(repo_id, local_model_hub):
@@ -637,6 +644,102 @@ def run_deepspeed_transformer_int4_cpu(repo_id,
actual_in_len, actual_out_len])
return result
def run_transformer_int4_gpu_win(repo_id,
local_model_hub,
in_out_pairs,
warm_up,
num_trials,
num_beams,
low_bit,
cpu_embedding):
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
import intel_extension_for_pytorch as ipex
reserved_mem_list = []
model_path = get_model_path(repo_id, local_model_hub)
# Load model in 4 bit,
# which converts the relevant layers in the model into INT4 format
st = time.perf_counter()
if repo_id in CHATGLM_IDS:
model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
use_cache=True, cpu_embedding=cpu_embedding)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif repo_id in LLAVA_IDS:
llava_repo_dir = os.environ.get('LLAVA_REPO_DIR')
sys.path.append(rf"{llava_repo_dir}")
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
else:
model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
if isinstance(model, GPTJForCausalLM):
# For gpt-j model family, this optimization can provide better performance.
model = ipex.optimize(model.eval(), inplace=True)
end = time.perf_counter()
print(">> loading of model costs {}s".format(end - st))
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
model = BenchmarkWrapper(model)
result = {}
with torch.inference_mode():
for in_out in in_out_pairs:
try:
in_out_len = in_out.split("-")
in_len = int(in_out_len[0])
out_len = int(in_out_len[1])
# As different tokenizers have different encodings,
# in_len.txt may be shorter than we need,
# so use a much longer context to make sure the input is long enough
test_length = min(in_len*2, 8192)
while test_length not in [32, 256, 1024, 2048, 8192]:
test_length = test_length * 2
input_str = open(f"prompt/{test_length}.txt", 'r').read()
# As different tokenizers have different encodings,
# slice input_ids to ensure the prompt has exactly the required length.
input_ids = tokenizer.encode(input_str, return_tensors="pt")
input_ids = input_ids[:, :in_len]
true_str = tokenizer.batch_decode(input_ids)[0]
input_ids = tokenizer.encode(true_str, return_tensors="pt").to('xpu')
actual_in_len = input_ids.shape[1]
result[in_out] = []
for i in range(num_trials + warm_up):
st = time.perf_counter()
output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
num_beams=num_beams)
torch.xpu.synchronize()
end = time.perf_counter()
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
gpu_peak_mem = max(reserved_mem_list) # track the peak GPU memory reserved so far
output_ids = output_ids.cpu()
print("model generate cost: " + str(end - st))
output = tokenizer.batch_decode(output_ids)
print(output[0])
actual_out_len = output_ids.shape[1] - actual_in_len
if i >= warm_up:
result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
actual_in_len, actual_out_len, gpu_peak_mem])
except RuntimeError:
pass
model.to('cpu')
torch.xpu.synchronize()
torch.xpu.empty_cache()
del model
gc.collect()
return result
if __name__ == '__main__':
from omegaconf import OmegaConf
conf = OmegaConf.load(f'{current_dir}/config.yaml')
@@ -645,9 +748,11 @@ if __name__ == '__main__':
import pandas as pd
for api in conf.test_api:
for model in conf.repo_id:
run_model(model, api, conf['in_out_pairs'], conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
conf['low_bit'], conf['cpu_embedding'])
df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
'input/output tokens', 'actual input/output tokens', 'num_beams', 'low_bit', 'cpu_embedding',
'peak mem (GB)'])
df.to_csv(f'{current_dir}/{api}-results-{today}.csv')
results = []
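
The core of run_transformer_int4_gpu_win is the peak-memory bookkeeping around each generation. Condensed into a standalone helper for illustration (a sketch that assumes intel_extension_for_pytorch is installed so the torch.xpu namespace is available; generate_with_peak_mem is not part of the committed code):

import time
import torch
import intel_extension_for_pytorch as ipex  # registers the torch.xpu backend

def generate_with_peak_mem(model, input_ids, max_new_tokens=64):
    # Record memory reserved on the XPU before and after generation and keep the max,
    # mirroring the reserved_mem_list bookkeeping in run.py above.
    peak_gb = torch.xpu.memory.memory_reserved() / (1024 ** 3)
    st = time.perf_counter()
    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens)
    torch.xpu.synchronize()  # wait for all kernels so timing and memory readings are accurate
    elapsed = time.perf_counter() - st
    peak_gb = max(peak_gb, torch.xpu.memory.memory_reserved() / (1024 ** 3))
    return output_ids.cpu(), elapsed, peak_gb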

python/llm/test/benchmark/igpu-perf-test-434.yaml

@@ -0,0 +1,22 @@
repo_id:
- 'mistralai/Mistral-7B-Instruct-v0.1'
local_model_hub: 'path to your local model hub'
warm_up: 3
num_trials: 5
num_beams: 1 # default to greedy search
low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
in_out_pairs:
- '32-32'
- '512-64'
# - '1024-128'
test_api:
# - "transformer_int4"
# - "native_int4"
# - "optimize_model"
# - "pytorch_autocast_bf16"
# - "ipex_fp16_gpu" # on Intel GPU
# - "transformer_int4_gpu" # on Intel GPU
# - "optimize_model_gpu" # on Intel GPU
# - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
- "transformer_int4_gpu_win" # on Intel GPU for Windows (catch GPU peak memory)
cpu_embedding: True # whether put embedding to CPU (only avaiable now for gpu win related test_api)
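
The workflow installs transformers==4.34.0 before running this config, since Mistral support landed in transformers 4.34. A small optional guard one could run locally before reusing the config (illustrative only, not part of the committed scripts):

import transformers

EXPECTED = "4.34.0"  # version the igpu Mistral perf run is pinned to in the workflow
if transformers.__version__ != EXPECTED:
    print(f"warning: this config was benchmarked with transformers=={EXPECTED}, "
          f"found {transformers.__version__}; results may not be comparable")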

python/llm/test/benchmark/igpu-perf-test.yaml

@@ -0,0 +1,33 @@
repo_id:
- 'THUDM/chatglm2-6b'
- 'THUDM/chatglm3-6b'
- 'baichuan-inc/Baichuan2-7B-Chat'
- 'internlm/internlm-chat-7b-8k'
- 'Qwen/Qwen-7B-Chat-10-12'
- 'BAAI/AquilaChat2-7B'
- '01-ai/Yi-6B'
- 'meta-llama/Llama-2-7b-chat-hf'
- 'WisdomShell/CodeShell-7B-Chat'
- 'tiiuae/falcon-7b-instruct-with-patch'
- 'mosaicml/mpt-7b-chat'
- 'liuhaotian/llava-v1.5-7b'
local_model_hub: 'path to your local model hub'
warm_up: 3
num_trials: 5
num_beams: 1 # default to greedy search
low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
in_out_pairs:
- '32-32'
- '512-64'
# - '1024-128'
test_api:
# - "transformer_int4"
# - "native_int4"
# - "optimize_model"
# - "pytorch_autocast_bf16"
# - "ipex_fp16_gpu" # on Intel GPU
# - "transformer_int4_gpu" # on Intel GPU
# - "optimize_model_gpu" # on Intel GPU
# - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
- "transformer_int4_gpu_win" # on Intel GPU for Windows (catch GPU peak memory)
cpu_embedding: True # whether put embedding to CPU (only avaiable now for gpu win related test_api)
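
run.py writes one CSV per test_api, now including the two columns added in this change ('cpu_embedding' and 'peak mem (GB)'). A quick way to inspect a run locally (the date in the filename is just an example):

import pandas as pd

# Filename follows the df.to_csv pattern in run.py: {api}-results-{today}.csv
df = pd.read_csv("transformer_int4_gpu_win-results-2023-12-04.csv", index_col=0)
print(df[["model", "1st token avg latency (ms)", "2+ avg latency (ms/token)",
          "peak mem (GB)", "cpu_embedding"]])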