Add npu benchmark all-in-one script (#11571)

* npu benchmark
Zhao Changmin 2024-07-15 10:42:37 +08:00 committed by GitHub
parent 019da6c0ab
commit 06745e5742
5 changed files with 80 additions and 2 deletions

@@ -57,6 +57,7 @@ test_api:
# - "bigdl_ipex_int8" # on Intel CPU, (qtype=int8)
# - "speculative_cpu" # on Intel CPU, inference with self-speculative decoding
# - "deepspeed_transformer_int4_cpu" # on Intel CPU, deepspeed autotp inference
# - "transformers_int4_npu_win" # on Intel NPU for Windows, transformer-like API, (qtype=int4)
cpu_embedding: False # whether to put embedding on CPU
streaming: False # whether to output in a streaming way (currently only available for gpu win related test_api)
use_fp16_torch_dtype: True # whether to use fp16 for non-linear layers (currently only available for the "pipeline_parallel_gpu" test_api)
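With the new entry uncommented, the harness picks it up like any other `test_api` value. A minimal sketch of that round trip, assuming a plain-YAML loader (the actual run.py may read the config differently):

```python
# Minimal sketch: read config.yaml and react to the new test_api entry.
# Assumes PyYAML; the real harness may use a different config loader.
import yaml

with open("config.yaml") as f:
    conf = yaml.safe_load(f)

for api in conf["test_api"]:
    if api == "transformers_int4_npu_win":
        print("NPU benchmark selected, low_bit =", conf.get("low_bit", "sym_int4"))
```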

@@ -33,6 +33,7 @@ test_api:
# - "bigdl_ipex_int8" # on Intel CPU, (qtype=int8)
# - "speculative_cpu" # on Intel CPU, inference with self-speculative decoding
# - "deepspeed_transformer_int4_cpu" # on Intel CPU, deepspeed autotp inference
# - "transformers_int4_npu_win" # on Intel NPU for Windows, transformer-like API, (qtype=int4)
cpu_embedding: False # whether to put embedding on CPU
streaming: False # whether to output in a streaming way (currently only available for gpu win related test_api)
use_fp16_torch_dtype: True # whether to use fp16 for non-linear layers (currently only available for the "pipeline_parallel_gpu" test_api)

@@ -161,6 +161,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
        result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
    elif test_api == 'pipeline_parallel_gpu':
        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype)
    elif test_api == 'transformers_int4_npu_win':
        result = transformers_int4_npu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
    else:
        invalidInputError(False, "Unknown test_api " + test_api + ", please check your config.yaml.")
@@ -567,6 +569,78 @@ def run_transformer_int4_gpu(repo_id,
    gc.collect()
    return result

def transformers_int4_npu_win(repo_id,
                              local_model_hub,
                              in_out_pairs,
                              warm_up,
                              num_trials,
                              num_beams,
                              low_bit,
                              batch_size):
    from ipex_llm.transformers.npu_model import AutoModel, AutoModelForCausalLM
    from transformers import AutoTokenizer, LlamaTokenizer

    model_path = get_model_path(repo_id, local_model_hub)
    # Load the model in low bit, which converts the relevant layers
    # in the model into the requested low-bit (e.g. INT4) format.
    st = time.perf_counter()
    if repo_id in CHATGLM_IDS:
        model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype='auto').eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    elif repo_id in LLAMA_IDS:
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
                                                     use_cache=True).eval()
        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
                                                     use_cache=True).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    end = time.perf_counter()
    load_time = end - st
    print(">> loading of model costs {}s".format(load_time))

    model = BenchmarkWrapper(model)

    result = {}
    with torch.inference_mode():
        for in_out in in_out_pairs:
            in_out_len = in_out.split("-")
            in_len = int(in_out_len[0])
            out_len = int(in_out_len[1])
            # As different tokenizers have different encodings,
            # the text in in_len.txt may tokenize to fewer tokens than we need,
            # so use a much longer context to guarantee the input length.
            test_length = min(in_len * 2, 8192)
            while test_length not in [32, 256, 1024, 2048, 8192]:
                # clamp at 8192 so the doubling always terminates
                test_length = min(test_length * 2, 8192)
            with open(f"prompt/continuation/{test_length}.txt", 'r') as f:
                input_str = f.read()
            # As different tokenizers have different encodings,
            # slice the input_ids to ensure the prompt has exactly the required length.
            input_ids = tokenizer.encode(input_str, return_tensors="pt")
            input_ids = input_ids[:, :in_len]
            true_str = tokenizer.batch_decode(input_ids)[0]
            input_list = [true_str] * batch_size
            input_ids = tokenizer(input_list, return_tensors="pt").input_ids
            actual_in_len = input_ids.shape[1]
            result[in_out] = []
            for i in range(num_trials + warm_up):
                st = time.perf_counter()
                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
                                            min_new_tokens=out_len, num_beams=num_beams)
                end = time.perf_counter()
                print("model generate cost: " + str(end - st))
                output = tokenizer.batch_decode(output_ids)
                print(output[0])
                actual_out_len = output_ids.shape[1] - actual_in_len
                if i >= warm_up:
                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
                                           actual_in_len, actual_out_len, load_time])
    del model
    gc.collect()
    return result
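Each recorded row holds per-token latencies plus token counts and the load time; a small sketch of turning one row into the usual first-token / next-token numbers (the seconds-to-ms and 1/x conversions are assumptions about the units `BenchmarkWrapper` reports, not part of this script):

```python
# Hypothetical post-processing of one row returned by transformers_int4_npu_win;
# the field order matches the append() call above.
first_cost, rest_cost_mean, encoder_time, in_len, out_len, load_time = result["32-32"][0]

print(f"1st token latency : {first_cost * 1000:.1f} ms")
print(f"next token latency: {rest_cost_mean * 1000:.1f} ms "
      f"({1.0 / rest_cost_mean:.1f} tokens/s)")
```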
def run_optimize_model_gpu(repo_id,
                           local_model_hub,
                           in_out_pairs,

@@ -1,5 +1,5 @@
# Run Large Language Model on Intel NPU
In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on LLM models on [Intel NPUs](../../../README.md).
In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on LLM models on [Intel NPUs](../../../README.md). See the table below for verified models.
## Verified Models
@@ -8,12 +8,14 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) |
| MiniCPM | [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
| Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
| Deepseek | [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) |
| Mistral | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) |
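Any row in this table can also be fed to the all-in-one benchmark added in this commit; hypothetically, as `repo_id` values in Python form (illustrative subset):

```python
# Verified entries from the table above, as repo_id strings one might
# paste into the benchmark's config.yaml.
verified_repo_ids = [
    "meta-llama/Llama-2-7b-chat-hf",
    "THUDM/chatglm3-6b",
    "Qwen/Qwen2-7B-Instruct",
    "openbmb/MiniCPM-2B-sft-bf16",
]
```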
## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
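Once the driver is in place, loading a verified model follows the same pattern as the benchmark function in this commit; a minimal sketch (the model id and `low_bit` value are illustrative, and the NPU build of ipex-llm is assumed to be installed):

```python
# Minimal generation sketch mirroring the benchmark code above.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # any verified model from the table
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4",
                                             trust_remote_code=True,
                                             use_cache=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("Once upon a time,", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids)[0])
```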

@@ -1,5 +1,5 @@
# Run Large Multimodal Model on Intel NPU
In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md).
In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md). See the table below for verified models.
## Verified Models