diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml
index 99685f60..30227b0e 100644
--- a/python/llm/dev/benchmark/all-in-one/config.yaml
+++ b/python/llm/dev/benchmark/all-in-one/config.yaml
@@ -3,7 +3,7 @@ repo_id:
   - 'meta-llama/Llama-2-7b-chat-hf'
   # - 'liuhaotian/llava-v1.5-7b' # requires a LLAVA_REPO_DIR env variables pointing to the llava dir; added only for gpu win related test_api now
 local_model_hub: 'path to your local model hub'
-warm_up: 1
+warm_up: 1 # must be set to >= 2 when running the "pipeline_parallel_gpu" test_api
 num_trials: 3
 num_beams: 1 # default to greedy search
 low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
@@ -21,6 +21,7 @@ test_api:
   # - "transformer_int4_fp16_gpu_win" # on Intel GPU for Windows, use fp16 for non-linear layer
   # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows using load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
   # - "deepspeed_optimize_model_gpu" # deepspeed autotp on Intel GPU
+  # - "pipeline_parallel_gpu" # pipeline-parallel inference on Intel GPU
   # - "speculative_gpu"
   # - "transformer_int4"
   # - "native_int4"
@@ -34,3 +35,5 @@ test_api:
   # - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
 cpu_embedding: False # whether put embedding to CPU
 streaming: False # whether output in streaming way (only avaiable now for gpu win related test_api)
+use_fp16_torch_dtype: True # whether to use fp16 for the non-linear layers (only available now for the "pipeline_parallel_gpu" test_api)
+n_gpu: 2 # number of GPUs to use (only available now for the "pipeline_parallel_gpu" test_api)
diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index 6bc6013e..ce1b18fd 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -50,7 +50,7 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
     for i in range(num_trials + warm_up):
         st = time.perf_counter()
         output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
-                                    min_new_tokens=out_len, num_beams=num_beams)
+                                    num_beams=num_beams)
         torch.xpu.synchronize()
         end = time.perf_counter()
         output_ids = output_ids.cpu()
@@ -63,7 +63,7 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
             result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time, actual_in_len,
                                    actual_out_len, load_time, model.peak_memory])
 
-def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False, use_fp16_torch_dtype=False, n_gpu=2):
     # TODO: make a parameter
     result= {}
     if test_api == 'transformer_int4':
@@ -108,6 +108,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
         result = run_speculative_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
     elif test_api == 'speculative_gpu':
         result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
+    elif test_api == 'pipeline_parallel_gpu':
+        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype, n_gpu=n_gpu)
 
     for in_out_pair in in_out_pairs:
         if result and result[in_out_pair]:
@@ -124,7 +126,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
                            cpu_embedding,
                            round(result[in_out_pair][-1][5], 2),
                            result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu', 'fp16_gpu', 'deepspeed_optimize_model_gpu']) else 'N/A',
-                           streaming if 'win' in test_api else 'N/A'],
+                           streaming if 'win' in test_api else 'N/A',
+                           use_fp16_torch_dtype],
                           )
 
 
@@ -1674,6 +1677,125 @@ def run_speculative_gpu(repo_id,
     return result
 
 
+
+def run_pipeline_parallel_gpu(repo_id,
+                              local_model_hub,
+                              in_out_pairs,
+                              warm_up,
+                              num_trials,
+                              num_beams,
+                              low_bit,
+                              batch_size,
+                              cpu_embedding,
+                              fp16=False,
+                              n_gpu=2):
+    from ipex_llm.transformers import AutoModel, AutoModelForCausalLM
+    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+    import intel_extension_for_pytorch as ipex
+    model_path = get_model_path(repo_id, local_model_hub)
+    # Load model in 4 bit,
+    # which converts the relevant layers in the model into INT4 format
+    st = time.perf_counter()
+    origin_repo_id = repo_id.replace("-4bit", "")
+    if origin_repo_id in CHATGLM_IDS:
+        if "4bit" in repo_id:
+            model = AutoModel.load_low_bit(model_path, optimize_model=True,
+                                           trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        else:
+            model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
+                                              trust_remote_code=True, use_cache=True).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cpu_embedding=cpu_embedding)
+    elif origin_repo_id in LLAMA_IDS:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
+                                                     use_cache=True, cpu_embedding=cpu_embedding).eval()
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        if "4bit" in repo_id:
+            model = AutoModelForCausalLM.load_low_bit(model_path, optimize_model=True,
+                                                      trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        else:
+            if 'starcoder' in repo_id:
+                # Load starcoder-15.5b model in bf16 format to avoid CPU OOM.
+                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
+                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding, torch_dtype=torch.bfloat16).eval()
+                # Convert the low-bit model back to fp32 for performance considerations.
+                model = model.float()
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
+                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    if fp16:
+        model = model.half()
+        print("Convert model to half precision")
+
+    end = time.perf_counter()
+    load_time = end - st
+    print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))
+
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // n_gpu
+    for i in range(n_gpu):
+        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
+        if i == n_gpu - 1:
+            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
+    print(f">> device map: {device_map}")
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model,
+        device_map=device_map,
+        offload_dir=None,
+        skip_keys=["past_key_value", "past_key_values"],
+    )
+
+    model = BenchmarkWrapper(model)
+    result = {}
+    with torch.inference_mode():
+        for in_out in in_out_pairs:
+            in_out_len = in_out.split("-")
+            in_len = int(in_out_len[0])
+            out_len = int(in_out_len[1])
+            # As different tokenizers have different encodings,
+            # in_len.txt may be shorter than we need,
+            # so use a much longer context to make sure the input is long enough
+            test_length = min(in_len*2, 8192)
+            while test_length not in [32, 256, 1024, 2048, 8192]:
+                test_length = test_length * 2
+            input_str = open(f"prompt/{test_length}.txt", 'r').read()
+            # As different tokenizers have different encodings,
+            # slice the input_ids to ensure the prompt is exactly the required length.
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            input_ids = input_ids[:, :in_len]
+            true_str = tokenizer.batch_decode(input_ids)[0]
+            input_list = [true_str] * batch_size
+            input_ids = tokenizer(input_list, return_tensors="pt").input_ids.to('xpu:0')
+            actual_in_len = input_ids.shape[1]
+            result[in_out] = []
+            for i in range(num_trials + warm_up):
+                st = time.perf_counter()
+                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                            num_beams=num_beams)
+                torch.xpu.synchronize()
+                end = time.perf_counter()
+                output_ids = output_ids.cpu()
+                print("model generate cost: " + str(end - st))
+                output = tokenizer.batch_decode(output_ids)
+                actual_out_len = output_ids.shape[1] - actual_in_len
+                print(output[0])
+                torch.xpu.empty_cache()
+                if i >= warm_up:
+                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
+                                           actual_in_len, actual_out_len, load_time, model.peak_memory, fp16])
+    del model
+    torch.xpu.empty_cache()
+    return result
+
+
 if __name__ == '__main__':
     from omegaconf import OmegaConf
     conf = OmegaConf.load(f'{current_dir}/config.yaml')
@@ -1698,9 +1820,9 @@ if __name__ == '__main__':
             if model_id_input in excludes or model_id_input_batch_size in excludes:
                 in_out_pairs.remove(in_out)
         run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
-                  conf['low_bit'], conf['cpu_embedding'], conf['batch_size'], streaming)
+                  conf['low_bit'], conf['cpu_embedding'], conf['batch_size'], streaming, conf['use_fp16_torch_dtype'], conf['n_gpu'])
         df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
                                             'input/output tokens', 'batch_size', 'actual input/output tokens', 'num_beams', 'low_bit', 'cpu_embedding',
-                                            'model loading time (s)', 'peak mem (GB)', 'streaming'])
+                                            'model loading time (s)', 'peak mem (GB)', 'streaming', 'use_fp16_torch_dtype'])
         df.to_csv(csv_name)
         results = []
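Note: the device-map construction added in run_pipeline_parallel_gpu can be sanity-checked without XPU hardware. The sketch below is illustrative only (the build_device_map helper is hypothetical, not part of this patch); it reproduces the same splitting logic: the flat list of stages (embeddings, each decoder layer, final norm, lm_head) is cut into n_gpu contiguous chunks, and the last device also receives the remainder. At runtime, inputs are placed on 'xpu:0' and accelerate's dispatch_model moves activations between stages.

# Illustrative sketch (hypothetical helper, not part of run.py):
# mirrors the device_map split used by run_pipeline_parallel_gpu.
def build_device_map(num_hidden_layers, n_gpu=2):
    # Flat list of pipeline stages: embeddings, each decoder layer,
    # the final norm and the LM head.
    model_layers = ['model.embed_tokens']
    for i in range(num_hidden_layers):
        model_layers.append(f'model.layers.{i}')
    model_layers = model_layers + ['model.norm', 'lm_head']

    # Assign contiguous chunks of len(model_layers) // n_gpu to xpu:0, xpu:1, ...
    device_map = {}
    split_len = len(model_layers) // n_gpu
    for i in range(n_gpu):
        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
        if i == n_gpu - 1:
            # The last device also takes the remainder of the list, so model.norm
            # and lm_head always end up on the final pipeline stage.
            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1):]})
    return device_map


if __name__ == '__main__':
    # e.g. a Llama-2-7b-style config has 32 decoder layers
    print(build_device_map(num_hidden_layers=32, n_gpu=2))

For a 32-layer model with n_gpu=2, this assigns model.embed_tokens and layers 0-15 to xpu:0, and layers 16-31 plus model.norm and lm_head to xpu:1.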