diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml
index d5fd85fc..e4863154 100644
--- a/python/llm/dev/benchmark/all-in-one/config.yaml
+++ b/python/llm/dev/benchmark/all-in-one/config.yaml
@@ -1,6 +1,13 @@
 repo_id:
+  - 'THUDM/chatglm-6b'
   - 'THUDM/chatglm2-6b'
   - 'meta-llama/Llama-2-7b-chat-hf'
 local_model_hub: 'path to your local model hub'
 warm_up: 1
 num_trials: 3
+in_out_pairs:
+  - '32-32'
+  - '1024-128'
+test_api:
+  - "transformer_int4"
+  - "native_int4"
diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index fee17822..56d45796 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -18,7 +18,6 @@
 # this code is copied from llama2 example test, and added performance test
 import torch
 import time
-import argparse
 
 from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
 from transformers import AutoTokenizer
@@ -31,14 +30,18 @@ benchmark_util_path = os.path.join(current_dir, '..')
 import sys
 sys.path.append(benchmark_util_path)
 from benchmark_util import BenchmarkWrapper
+from bigdl.llm.utils.common.log4Error import invalidInputError
 
 results = []
 
 
-def run_model(repo_id, local_model_hub=None, warm_up=1, num_trials=3):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3):
     # TODO: make a parameter
-    in_out_pairs = ['32-32', '1024-128']
-    result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+    if test_api == 'transformer_int4':
+        result = run_transformer_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+    elif test_api == 'native_int4':
+        run_native_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials)
+
     for in_out_pair in in_out_pairs:
         results.append([repo_id,
                         np.mean(result[in_out_pair], axis=0)[0],
@@ -46,18 +49,59 @@ def run_model(repo_id, local_model_hub=None, warm_up=1, num_trials=3):
                         np.mean(result[in_out_pair], axis=0)[2],
                         in_out_pair])
 
 
+def get_model_path(repo_id, local_model_hub):
+    if local_model_hub:
+        repo_model_name = repo_id.split("/")[1]
+        return local_model_hub + os.path.sep + repo_model_name
+    else:
+        return repo_id
+
+
+def run_native_int4(repo_id,
+                    local_model_hub,
+                    in_out_pairs,
+                    warm_up,
+                    num_trials):
+    model_path = get_model_path(repo_id, local_model_hub)
+    from bigdl.llm.transformers import BigdlNativeForCausalLM
+    from bigdl.llm import llm_convert
+    if "chatglm" in repo_id.lower():
+        family = "chatglm"
+    elif "llama" in repo_id.lower():
+        family = "llama"
+    else:
+        invalidInputError(False, "Model family unknown: " + repo_id)
+
+    bigdl_llm_path = llm_convert(model=model_path,
+                                 outfile="./", outtype='int4', model_family=family)
+    for in_out in in_out_pairs:
+        in_out_len = in_out.split("-")
+        in_len = int(in_out_len[0])
+        out_len = int(in_out_len[1])
+        input_str = open(f"prompt/{in_len}.txt", 'r').read()
+        # As different tokenizers have different encodings,
+        # slice the input_ids to ensure the prompt length is the required length.
+        n_ctx = in_len + out_len if in_len + out_len > 512 else 512
+        for i in range(num_trials + warm_up):
+            model = BigdlNativeForCausalLM.from_pretrained(bigdl_llm_path, model_family=family, n_ctx=n_ctx)
+            input_ids = model.tokenize(input_str)
+            input_ids = input_ids[:in_len]
+            true_input = model.batch_decode(input_ids)
+            st = time.perf_counter()
+            output = model(true_input, max_tokens=out_len)
+            end = time.perf_counter()
+            print("model generate cost: " + str(end - st))
+            print(output)
+
+    os.remove(bigdl_llm_path)
+
 def run_transformer_int4(repo_id,
                          local_model_hub,
                          in_out_pairs,
                          warm_up,
-                         num_trials,
-                         device='cpu'):
-    if local_model_hub:
-        repo_model_name = repo_id.split("/")[1]
-        model_path = local_model_hub + "/" + repo_model_name
-    else:
-        model_path = repo_id
+                         num_trials):
+    model_path = get_model_path(repo_id, local_model_hub)
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
     st = time.perf_counter()
@@ -102,7 +146,9 @@ if __name__ == '__main__':
     today = date.today()
 
     import pandas as pd
-    for model in conf.repo_id:
-        run_model(model, conf['local_model_hub'], conf['warm_up'], conf['num_trials'])
-    df = pd.DataFrame(results, columns=['model', '1st token avg latency (s)', '2+ avg latency (s/token)', 'encoder time (s)', 'input/output tokens'])
-    df.to_csv(f'{current_dir}/results-{today}.csv')
\ No newline at end of file
+    for api in conf.test_api:
+        for model in conf.repo_id:
+            run_model(model, api, conf['in_out_pairs'], conf['local_model_hub'], conf['warm_up'], conf['num_trials'])
+        df = pd.DataFrame(results, columns=['model', '1st token avg latency (s)', '2+ avg latency (s/token)', 'encoder time (s)', 'input/output tokens'])
+        df.to_csv(f'{current_dir}/{api}-results-{today}.csv')
+        result = []
\ No newline at end of file
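
For reference, a standalone sketch (not part of the patch) of how an in_out_pairs entry such as '1024-128' is interpreted on the new native_int4 path: the prompt length and output length are split on '-', and the native context size is their sum with a floor of 512, as in run_native_int4 above. The helper name parse_in_out_pair is hypothetical and used only for illustration.

# Hypothetical helper mirroring the in/out-pair parsing and n_ctx choice in run_native_int4.
def parse_in_out_pair(pair: str):
    in_len, out_len = (int(x) for x in pair.split("-"))   # e.g. '1024-128' -> 1024, 128
    # n_ctx has to cover prompt plus generated tokens, with a floor of 512
    n_ctx = in_len + out_len if in_len + out_len > 512 else 512
    return in_len, out_len, n_ctx

for pair in ['32-32', '1024-128']:        # the pairs listed in config.yaml
    print(pair, parse_in_out_pair(pair))  # -> (32, 32, 512) and (1024, 128, 1152)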