diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml
index 9df48a01..447beab4 100644
--- a/python/llm/dev/benchmark/all-in-one/config.yaml
+++ b/python/llm/dev/benchmark/all-in-one/config.yaml
@@ -33,7 +33,10 @@ test_api:
   # - "bigdl_ipex_int8" # on Intel CPU, (qtype=int8)
   # - "speculative_cpu" # on Intel CPU, inference with self-speculative decoding
   # - "deepspeed_transformer_int4_cpu" # on Intel CPU, deepspeed autotp inference
+  # - "transformer_int4_fp16_lookahead_gpu" # on Intel GPU, transformer-like API, with lookahead, (qtype=int4), (dtype=fp16)
 cpu_embedding: False # whether put embedding to CPU
 streaming: False # whether output in streaming way (only avaiable now for gpu win related test_api)
 use_fp16_torch_dtype: True # whether use fp16 for non-linear layer (only avaiable now for "pipeline_parallel_gpu" test_api)
 n_gpu: 2 # number of GPUs to use (only avaiable now for "pipeline_parallel_gpu" test_api)
+lookahead: 3
+max_matching_ngram_size: 2
diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index 56dfd54d..7257a136 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -45,11 +45,15 @@ LLAVA_IDS = ['liuhaotian/llava-v1.5-7b']
 results = []
 excludes = []
 
-def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, load_time):
+def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, load_time, lookahead):
     for i in range(num_trials + warm_up):
         st = time.perf_counter()
-        output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+        if lookahead:
+            output_ids = model.generate(input_ids, lookahead=conf.lookahead, do_sample=False, max_matching_ngram_size=conf.max_matching_ngram_size, max_new_tokens=out_len,
                                     min_new_tokens=out_len, num_beams=num_beams)
+        else:
+            output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                        min_new_tokens=out_len, num_beams=num_beams)
         torch.xpu.synchronize()
         end = time.perf_counter()
         output_ids = output_ids.cpu()
@@ -59,8 +63,12 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
         torch.xpu.empty_cache()
         actual_out_len = output_ids.shape[1] - actual_in_len
         if i >= warm_up:
-            result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
-                                   actual_in_len, actual_out_len, load_time, model.peak_memory])
+            if lookahead:
+                result[in_out].append([model.first_token_time, (end - st - model.first_token_time)/model.n_token_generated, 0,
+                                       actual_in_len, actual_out_len, load_time, 0])
+            else:
+                result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
+                                       actual_in_len, actual_out_len, load_time, model.peak_memory])
 
 def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False, use_fp16_torch_dtype=False, n_gpu=2):
     # TODO: make a parameter
@@ -109,6 +117,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
         result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
     elif test_api == 'pipeline_parallel_gpu':
         result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype, n_gpu=n_gpu)
+    elif test_api == "transformer_int4_fp16_lookahead_gpu":
+        result = run_transformer_int4_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=True, lookahead=True)
     else:
         invalidInputError(False, "Unknown test_api " + test_api + ", please check your config.yaml.")
 
@@ -117,7 +127,7 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
             results.append([repo_id,
                             round(np.mean(result[in_out_pair], axis=0)[0]*1000.0, 2),
                             round(np.mean(result[in_out_pair], axis=0)[1]*1000.0, 2),
-                            round(np.mean(result[in_out_pair], axis=0)[2]*1000.0, 2),
+                            round(np.mean(result[in_out_pair], axis=0)[2]*1000.0, 2) if 'lookahead' not in test_api else 'N/A',
                             in_out_pair,
                             batch_size,
                             f'{int(np.mean(result[in_out_pair], axis=0)[3])}' +
@@ -396,7 +406,8 @@ def run_transformer_int4_gpu(repo_id,
                              low_bit,
                              batch_size,
                              cpu_embedding,
-                             fp16=False):
+                             fp16=False,
+                             lookahead=False):
     from ipex_llm.transformers import AutoModel, AutoModelForCausalLM
     from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
     import intel_extension_for_pytorch as ipex
@@ -443,7 +454,8 @@ def run_transformer_int4_gpu(repo_id,
     load_time = end - st
     print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))
 
-    model = BenchmarkWrapper(model)
+    if not lookahead:
+        model = BenchmarkWrapper(model)
 
     result = {}
     with torch.inference_mode():
@@ -460,16 +472,25 @@ def run_transformer_int4_gpu(repo_id,
             # For the sequence length not in [32, 256, 1024, 2048, 8192], it will be truncated from 8192.txt.
             test_length = min(test_length, 8192)
             input_str = open(f"prompt/{test_length}.txt", 'r').read()
-            # As different tokenizer has different encodings,
-            # slice the input_ids to ensure the prompt length is required length.
-            input_ids = tokenizer.encode(input_str, return_tensors="pt")
-            input_ids = input_ids[:, :in_len]
+            if lookahead:
+                question = "Can you please summarize this article?"
+                question_tokens = tokenizer.encode(question, return_tensors="pt")
+                max_article_len = in_len - question_tokens.size(1)
+                article_ids = tokenizer.encode(input_str, return_tensors="pt")
+                if article_ids.size(1) > max_article_len:
+                    article_ids = article_ids[:, :max_article_len]
+                input_ids = torch.cat((article_ids, question_tokens), dim=1)
+            else:
+                # As different tokenizer has different encodings,
+                # slice the input_ids to ensure the prompt length is required length.
+                input_ids = tokenizer.encode(input_str, return_tensors="pt")
+                input_ids = input_ids[:, :in_len]
             true_str = tokenizer.batch_decode(input_ids)[0]
             input_list = [true_str] * batch_size
             input_ids = tokenizer(input_list, return_tensors="pt").input_ids.to('xpu')
             actual_in_len = input_ids.shape[1]
             result[in_out] = []
-            thread = threading.Thread(target=run_model_in_thread, args=(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, load_time))
+            thread = threading.Thread(target=run_model_in_thread, args=(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, load_time, lookahead))
             thread.start()
             thread.join()
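
For reference, a minimal config.yaml sketch that exercises the new test_api. The repo_id and in_out_pairs values below are illustrative placeholders and not part of this patch; lookahead and max_matching_ngram_size are the keys this patch adds and forwards to model.generate():

    repo_id:
      - 'meta-llama/Llama-2-7b-chat-hf'         # illustrative; any model already supported by the all-in-one benchmark
    in_out_pairs:
      - '1024-128'                              # prompt length 1024 tokens, generate 128 tokens
    test_api:
      - "transformer_int4_fp16_lookahead_gpu"   # new entry: int4 weights, fp16 dtype, lookahead on Intel GPU
    lookahead: 3                                # passed through as model.generate(lookahead=...)
    max_matching_ngram_size: 2                  # passed through as model.generate(max_matching_ngram_size=...)

With this test_api selected, run.py skips BenchmarkWrapper: first-token latency comes from model.first_token_time, rest-token latency is computed as (end - st - model.first_token_time)/model.n_token_generated, and the encoder-time column in the summary is reported as 'N/A' while peak memory is recorded as 0.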