benchmark for native int4 (#8918)

* native4

* update

* update

* update
Xin Qiu, 2023-09-07 15:56:15 +08:00 (committed by GitHub)
parent c0797ea232, commit e9de9d9950
2 changed files with 68 additions and 15 deletions

File 1 of 2: benchmark configuration (YAML)

@@ -1,6 +1,13 @@
 repo_id:
+  - 'THUDM/chatglm-6b'
   - 'THUDM/chatglm2-6b'
   - 'meta-llama/Llama-2-7b-chat-hf'
 local_model_hub: 'path to your local model hub'
 warm_up: 1
 num_trials: 3
+in_out_pairs:
+  - '32-32'
+  - '1024-128'
+test_api:
+  - "transformer_int4"
+  - "native_int4"

File 2 of 2: benchmark script (Python)

@@ -18,7 +18,6 @@
 # this code is copied from the llama2 example test, with performance testing added
 import torch
 import time
-import argparse
 from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
 from transformers import AutoTokenizer
@@ -31,14 +30,18 @@ benchmark_util_path = os.path.join(current_dir, '..')
 import sys
 sys.path.append(benchmark_util_path)
 from benchmark_util import BenchmarkWrapper
+from bigdl.llm.utils.common.log4Error import invalidInputError

 results = []

-def run_model(repo_id, local_model_hub=None, warm_up=1, num_trials=3):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3):
     # TODO: make a parameter
-    in_out_pairs = ['32-32', '1024-128']
-    result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+    if test_api == 'transformer_int4':
+        result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+    elif test_api == 'native_int4':
+        run_native_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+        return

     for in_out_pair in in_out_pairs:
         results.append([repo_id,
                         np.mean(result[in_out_pair], axis=0)[0],
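The results.append call opened here closes in the next hunk: each row stores the model id, the three column-wise means of a per-trial timing matrix, and the in/out pair. A self-contained illustration of the np.mean(..., axis=0) reduction, with invented timings:

import numpy as np

# Invented timings for one pair: one row per trial, columns are
# [1st token latency (s), 2+ latency (s/token), encoder time (s)].
result = {'32-32': [[0.41, 0.052, 0.30],
                    [0.39, 0.050, 0.29],
                    [0.40, 0.051, 0.31]]}
first, rest, encoder = np.mean(result['32-32'], axis=0)  # mean over trials
print(first, rest, encoder)  # approximately 0.4 0.051 0.3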
@@ -46,18 +49,59 @@ def run_model(repo_id, local_model_hub=None, warm_up=1, num_trials=3):
                         np.mean(result[in_out_pair], axis=0)[2],
                         in_out_pair])

+def get_model_path(repo_id, local_model_hub):
+    if local_model_hub:
+        repo_model_name = repo_id.split("/")[1]
+        return local_model_hub + os.path.sep + repo_model_name
+    else:
+        return repo_id
+
+
+def run_native_int4(repo_id,
+                    local_model_hub,
+                    in_out_pairs,
+                    warm_up,
+                    num_trials):
+    model_path = get_model_path(repo_id, local_model_hub)
+    from bigdl.llm.transformers import BigdlNativeForCausalLM
+    from bigdl.llm import llm_convert
+    if "chatglm" in repo_id.lower():
+        family = "chatglm"
+    elif "llama" in repo_id.lower():
+        family = "llama"
+    else:
+        invalidInputError(False, "Model family unknown: " + repo_id)
+    bigdl_llm_path = llm_convert(model=model_path,
+                                 outfile="./", outtype='int4', model_family=family)
+    for in_out in in_out_pairs:
+        in_out_len = in_out.split("-")
+        in_len = int(in_out_len[0])
+        out_len = int(in_out_len[1])
+        input_str = open(f"prompt/{in_len}.txt", 'r').read()
+        # As different tokenizers have different encodings,
+        # slice the input_ids to ensure the prompt is exactly the required length.
+        n_ctx = in_len + out_len if in_len + out_len > 512 else 512
+        for i in range(num_trials + warm_up):
+            model = BigdlNativeForCausalLM.from_pretrained(bigdl_llm_path, model_family=family, n_ctx=n_ctx)
+            input_ids = model.tokenize(input_str)
+            input_ids = input_ids[:in_len]
+            true_input = model.batch_decode(input_ids)
+            st = time.perf_counter()
+            output = model(true_input, max_tokens=out_len)
+            end = time.perf_counter()
+            print("model generate cost: " + str(end - st))
+            print(output)
+    os.remove(bigdl_llm_path)
+
+
 def run_transformer_int4(repo_id,
                          local_model_hub,
                          in_out_pairs,
                          warm_up,
-                         num_trials,
-                         device='cpu'):
-    if local_model_hub:
-        repo_model_name = repo_id.split("/")[1]
-        model_path = local_model_hub + "/" + repo_model_name
-    else:
-        model_path = repo_id
+                         num_trials):
+    model_path = get_model_path(repo_id, local_model_hub)
     # Load the model in 4 bit,
     # which converts the relevant layers in the model into INT4 format
     st = time.perf_counter()
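Stripped of the trial loop, the native path added above is convert-then-load-then-generate. A condensed sketch using only the calls that appear in run_native_int4; the checkpoint path and prompt are placeholders:

from bigdl.llm import llm_convert
from bigdl.llm.transformers import BigdlNativeForCausalLM

# One-time conversion of a HF checkpoint to a native int4 binary under ./ .
bin_path = llm_convert(model='/path/to/Llama-2-7b-chat-hf',  # placeholder path
                       outfile='./', outtype='int4', model_family='llama')

# n_ctx must cover prompt plus generated tokens (the benchmark floors it at 512).
model = BigdlNativeForCausalLM.from_pretrained(bin_path, model_family='llama', n_ctx=512)
ids = model.tokenize('Once upon a time')  # prompt token ids
prompt = model.batch_decode(ids)          # round-trip back to text, as the benchmark does
print(model(prompt, max_tokens=32))       # generate at most 32 new tokens

The body of run_transformer_int4 is elided past this hunk; it presumably loads through BigDL's transformers-style AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True), the library's documented int4 flag, and times generation through the imported BenchmarkWrapper.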
@@ -102,7 +146,9 @@ if __name__ == '__main__':
     today = date.today()
     import pandas as pd
-    for model in conf.repo_id:
-        run_model(model, conf['local_model_hub'], conf['warm_up'], conf['num_trials'])
-    df = pd.DataFrame(results, columns=['model', '1st token avg latency (s)', '2+ avg latency (s/token)', 'encoder time (s)', 'input/output tokens'])
-    df.to_csv(f'{current_dir}/results-{today}.csv')
+    for api in conf.test_api:
+        for model in conf.repo_id:
+            run_model(model, api, conf['in_out_pairs'], conf['local_model_hub'], conf['warm_up'], conf['num_trials'])
+        df = pd.DataFrame(results, columns=['model', '1st token avg latency (s)', '2+ avg latency (s/token)', 'encoder time (s)', 'input/output tokens'])
+        df.to_csv(f'{current_dir}/{api}-results-{today}.csv')
+        results = []
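Only run_transformer_int4 feeds rows into results (the native branch prints its timings and returns), so the read-back below targets the transformer CSV. A small sketch, assuming it runs the same day and from the same directory as the benchmark:

import pandas as pd
from datetime import date

# File name follows the f'{api}-results-{today}.csv' pattern used above.
df = pd.read_csv(f'transformer_int4-results-{date.today()}.csv', index_col=0)
print(df.sort_values('1st token avg latency (s)'))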