From 6ea1e71af044c780ed9a1e956e27e047b69a45ad Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Mon, 17 Jun 2024 09:59:36 +0800
Subject: [PATCH] Update PP inference benchmark script (#11323)

---
 python/llm/dev/benchmark/all-in-one/README.md | 50 ++++++------
 .../llm/dev/benchmark/all-in-one/config.yaml  |  1 -
 .../all-in-one/run-pipeline-parallel-arc.sh   | 13 ++++
 python/llm/dev/benchmark/all-in-one/run.py    | 77 +++++--------------
 .../transformers/pipeline_parallel.py         |  4 +-
 5 files changed, 63 insertions(+), 82 deletions(-)
 create mode 100644 python/llm/dev/benchmark/all-in-one/run-pipeline-parallel-arc.sh

diff --git a/python/llm/dev/benchmark/all-in-one/README.md b/python/llm/dev/benchmark/all-in-one/README.md
index b1740eea..78682951 100644
--- a/python/llm/dev/benchmark/all-in-one/README.md
+++ b/python/llm/dev/benchmark/all-in-one/README.md
@@ -27,7 +27,7 @@ repo_id:
   - 'meta-llama/Llama-2-7b-chat-hf'
   # - 'liuhaotian/llava-v1.5-7b' # requires a LLAVA_REPO_DIR env variables pointing to the llava dir; added only for gpu win related test_api now
 local_model_hub: 'path to your local model hub'
-warm_up: 1
+warm_up: 1 # must set >=2 when run "pipeline_parallel_gpu" test_api
 num_trials: 3
 num_beams: 1 # default to greedy search
 low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
@@ -36,29 +36,33 @@ in_out_pairs:
   - '32-32'
   - '1024-128'
 test_api:
-  - "transformer_int4_gpu" # on Intel GPU
-  # - "transformer_int4_fp16_gpu" # on Intel GPU, use fp16 for non-linear layer
-  # - "ipex_fp16_gpu" # on Intel GPU
-  # - "bigdl_fp16_gpu" # on Intel GPU
-  # - "optimize_model_gpu" # on Intel GPU
-  # - "transformer_int4_gpu_win" # on Intel GPU for Windows
-  # - "transformer_int4_fp16_gpu_win" # on Intel GPU for Windows, use fp16 for non-linear layer
-  # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows using load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
-  # - "deepspeed_optimize_model_gpu" # deepspeed autotp on Intel GPU
-  # - "speculative_gpu"
-  # - "transformer_int4"
-  # - "native_int4"
-  # - "optimize_model"
-  # - "pytorch_autocast_bf16"
-  # - "transformer_autocast_bf16"
-  # - "bigdl_ipex_bf16"
-  # - "bigdl_ipex_int4"
-  # - "bigdl_ipex_int8"
-  # - "speculative_cpu"
-  # - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
+  - "transformer_int4_fp16_gpu" # on Intel GPU, transformer-like API, (qtype=int4), (dtype=fp16)
+  # - "transformer_int4_fp16_gpu_win" # on Intel GPU for Windows, transformer-like API, (qtype=int4), (dtype=fp16)
+  # - "transformer_int4_gpu" # on Intel GPU, transformer-like API, (qtype=int4), (dtype=fp32)
+  # - "transformer_int4_gpu_win" # on Intel GPU for Windows, transformer-like API, (qtype=int4), (dtype=fp32)
+  # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows, transformer-like API, (qtype=int4), use load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
+  # - "bigdl_fp16_gpu" # on Intel GPU, use ipex-llm transformers API, (dtype=fp16), (qtype=fp16)
+  # - "optimize_model_gpu" # on Intel GPU, can optimize any pytorch models include transformer model
+  # - "deepspeed_optimize_model_gpu" # on Intel GPU, deepspeed autotp inference
+  # - "pipeline_parallel_gpu" # on Intel GPU, pipeline parallel inference
+  # - "speculative_gpu" # on Intel GPU, inference with self-speculative decoding
+  # - "transformer_int4" # on Intel CPU, transformer-like API, (qtype=int4)
+  # - "native_int4" # on Intel CPU
+  # - "optimize_model" # on Intel CPU, can optimize any pytorch models include transformer model
+  # - "pytorch_autocast_bf16" # on Intel CPU
+  # - "transformer_autocast_bf16" # on Intel CPU
+  # - "bigdl_ipex_bf16" # on Intel CPU, (qtype=bf16)
+  # - "bigdl_ipex_int4" # on Intel CPU, (qtype=int4)
+  # - "bigdl_ipex_int8" # on Intel CPU, (qtype=int8)
+  # - "speculative_cpu" # on Intel CPU, inference with self-speculative decoding
+  # - "deepspeed_transformer_int4_cpu" # on Intel CPU, deepspeed autotp inference
+  # - "transformer_int4_fp16_lookahead_gpu" # on Intel GPU, transformer-like API, with lookahead, (qtype=int4), (dtype=fp16)
 cpu_embedding: False # whether put embedding to CPU
-streaming: False # whether output in streaming way (only avaiable now for gpu win related test_api)
-
+streaming: False # whether output in streaming way (only available now for gpu win related test_api)
+use_fp16_torch_dtype: True # whether use fp16 for non-linear layer (only available now for "pipeline_parallel_gpu" test_api)
+lookahead: 3
+max_matching_ngram_size: 2
+task: 'continuation' # when test_api is "transformer_int4_fp16_lookahead_gpu", task could be 'QA', 'continuation' or 'summarize'
 ```
diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml
index 2ff2b3f2..8d46cbf6 100644
--- a/python/llm/dev/benchmark/all-in-one/config.yaml
+++ b/python/llm/dev/benchmark/all-in-one/config.yaml
@@ -36,7 +36,6 @@ test_api:
 cpu_embedding: False # whether put embedding to CPU
 streaming: False # whether output in streaming way (only available now for gpu win related test_api)
 use_fp16_torch_dtype: True # whether use fp16 for non-linear layer (only available now for "pipeline_parallel_gpu" test_api)
-n_gpu: 2 # number of GPUs to use (only available now for "pipeline_parallel_gpu" test_api)
 lookahead: 3
 max_matching_ngram_size: 2
 task: 'continuation' # when test_api is "transformer_int4_fp16_lookahead_gpu", task could be 'QA', 'continuation' or 'summarize'
diff --git a/python/llm/dev/benchmark/all-in-one/run-pipeline-parallel-arc.sh b/python/llm/dev/benchmark/all-in-one/run-pipeline-parallel-arc.sh
new file mode 100644
index 00000000..236b8d0b
--- /dev/null
+++ b/python/llm/dev/benchmark/all-in-one/run-pipeline-parallel-arc.sh
@@ -0,0 +1,13 @@
+source /opt/intel/oneapi/setvars.sh
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=8080
+export FI_PROVIDER=tcp
+export USE_XETLA=OFF
+export OMP_NUM_THREADS=6
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+    export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
+export TORCH_LLM_ALLREDUCE=0
+
+NUM_GPUS=2 # number of used GPU
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS run.py
diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index af254d1a..aaa7dffb 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -106,7 +106,7 @@ def preprocess_prompt(tokenizer, in_len, task):
         input_ids = tokenizer.encode(input_str, return_tensors="pt")
     return input_ids
 
-def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False, use_fp16_torch_dtype=False, n_gpu=2):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False, use_fp16_torch_dtype=False):
     # TODO: make a parameter
     result= {}
     if test_api == 'transformer_int4':
@@ -152,7 +152,7 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
     elif test_api == 'speculative_gpu':
         result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
     elif test_api == 'pipeline_parallel_gpu':
-        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype, n_gpu=n_gpu)
+        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype)
     elif test_api == "transformer_int4_fp16_lookahead_gpu":
         result = run_transformer_int4_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=True, lookahead=True)
     else:
@@ -1747,41 +1747,30 @@ def run_pipeline_parallel_gpu(repo_id,
                               low_bit,
                               batch_size,
                               cpu_embedding,
-                              fp16=False,
-                              n_gpu=2):
-    from ipex_llm.transformers import AutoModel, AutoModelForCausalLM
+                              fp16=False):
+    from ipex_llm.transformers import AutoModel, AutoModelForCausalLM, init_pipeline_parallel
     from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+    init_pipeline_parallel()
     model_path = get_model_path(repo_id, local_model_hub)
+    pipeline_parallel_stages = torch.distributed.get_world_size()
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
     st = time.perf_counter()
     origin_repo_id = repo_id.replace("-4bit", "")
     if origin_repo_id in CHATGLM_IDS:
-        if "4bit" in repo_id:
-            model = AutoModel.load_low_bit(model_path, optimize_model=True,
-                                           trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
-        else:
-            model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
-                                              trust_remote_code=True, use_cache=True).eval()
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cpu_embedding=cpu_embedding)
+        model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
+                                          trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                          pipeline_parallel_stages=pipeline_parallel_stages).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     elif origin_repo_id in LLAMA_IDS:
-        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
-                                                     use_cache=True, cpu_embedding=cpu_embedding).eval()
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
+                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                     pipeline_parallel_stages=pipeline_parallel_stages).eval()
         tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
     else:
-        if "4bit" in repo_id:
-            model = AutoModelForCausalLM.load_low_bit(model_path, optimize_model=True,
-                                                      trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
-        else:
-            if 'starcoder' in repo_id:
-                # Load starcoder-15.5b model in bf16 format to avoid CPU OOM.
-                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
-                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding, torch_dtype=torch.bfloat16).eval()
-                # Convert the low-bit model back to fp32 for performance considerations.
-                model = model.float()
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
-                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
+                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                     pipeline_parallel_stages=pipeline_parallel_stages).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     if fp16:
@@ -1792,29 +1781,8 @@ def run_pipeline_parallel_gpu(repo_id,
     load_time = end - st
     print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))
 
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // n_gpu
-    for i in range(n_gpu):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == n_gpu - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-    print(f">> device map: {device_map}")
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
-
-    model = BenchmarkWrapper(model)
     result = {}
+    local_rank = torch.distributed.get_rank()
     with torch.inference_mode():
         for in_out in in_out_pairs:
             in_out_len = in_out.split("-")
@@ -1833,7 +1801,7 @@ def run_pipeline_parallel_gpu(repo_id,
                 input_ids = input_ids[:, :in_len]
                 true_str = tokenizer.batch_decode(input_ids)[0]
                 input_list = [true_str] * batch_size
-                input_ids = tokenizer(input_list, return_tensors="pt").input_ids.to('xpu:0')
+                input_ids = tokenizer(input_list, return_tensors="pt").input_ids.to(f'xpu:{local_rank}')
                 actual_in_len = input_ids.shape[1]
                 result[in_out] = []
                 for i in range(num_trials + warm_up):
@@ -1849,8 +1817,8 @@ def run_pipeline_parallel_gpu(repo_id,
                         print(output[0])
                     torch.xpu.empty_cache()
                     if i >= warm_up:
-                        result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
-                                               actual_in_len, actual_out_len, load_time, model.peak_memory, fp16])
+                        result[in_out].append([model.first_token_time, model.rest_cost_mean, 0,
+                                               actual_in_len, actual_out_len, load_time,])
     del model
     torch.xpu.empty_cache()
     return result
@@ -1865,13 +1833,10 @@ if __name__ == '__main__':
         excludes = conf['exclude']
     streaming = False
     use_fp16_torch_dtype = False
-    n_gpu = 2
    if 'streaming' in conf:
         streaming = conf['streaming']
     if 'use_fp16_torch_dtype' in conf:
         use_fp16_torch_dtype = conf['use_fp16_torch_dtype']
-    if 'n_gpu' in conf:
-        n_gpu = conf['n_gpu']
 
     import pandas as pd
     for api in conf.test_api:
@@ -1891,7 +1856,7 @@ if __name__ == '__main__':
                    if model_id_input in excludes or model_id_input_batch_size in excludes:
                        in_out_pairs.remove(in_out)
            run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
-                      conf['low_bit'], conf['cpu_embedding'], batch_size, streaming, use_fp16_torch_dtype, n_gpu)
+                      conf['low_bit'], conf['cpu_embedding'], batch_size, streaming, use_fp16_torch_dtype)
         df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
                                             'input/output tokens', 'batch_size', 'actual input/output tokens', 'num_beams', 'low_bit', 'cpu_embedding',
                                             'model loading time (s)', 'peak mem (GB)', 'streaming', 'use_fp16_torch_dtype'])
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 9b7d0fa3..42e0c631 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -172,7 +172,7 @@ def pipeline_parallel_generate(self,
                                past_key_values=_past_key_values, use_cache=True)
             else:
                 inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
-                                            device=f'xpu:{local_rank}', dtype=torch.float32)
+                                            device=f'xpu:{local_rank}', dtype=self.dtype)
                 dist.recv(inputs_embeds, src=pre_rank)
                 outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
                                past_key_values=_past_key_values, use_cache=True)
@@ -182,7 +182,7 @@ def pipeline_parallel_generate(self,
             next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
             dist.broadcast(next_ids, src=local_rank)
         else:
-            dist.send(outputs[0], dst=next_rank)
+            dist.send(outputs[0].to(self.dtype), dst=next_rank)
             next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
             dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)
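
For reference, after this patch the `pipeline_parallel_gpu` benchmark is launched with `torchrun` (see `run-pipeline-parallel-arc.sh` above): each rank loads its own pipeline stage, with the number of stages taken from the process-group world size, instead of relying on the removed `n_gpu` config key and `accelerate.dispatch_model`. The sketch below illustrates that flow outside the benchmark harness. It is a minimal example, not the benchmark itself: the model id, prompt, generation length, the `.half()` call (standing in for `use_fp16_torch_dtype: True`), and printing only on rank 0 are all assumptions; the `init_pipeline_parallel` / `pipeline_parallel_stages` APIs and the `f'xpu:{local_rank}'` input placement come from the patch.

```python
# sketch.py -- minimal pipeline-parallel generation flow mirroring the patched run.py.
# Assumed launch: CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node 2 sketch.py
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()                      # set up the distributed process group
stages = torch.distributed.get_world_size()   # one pipeline stage per launched rank
local_rank = torch.distributed.get_rank()

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; point to a local model path in practice
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4",
                                             optimize_model=True,
                                             use_cache=True,
                                             pipeline_parallel_stages=stages).eval()
model = model.half()                          # assumed equivalent of use_fp16_torch_dtype: True

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Inputs go to this rank's device, matching the run.py change from 'xpu:0' to f'xpu:{local_rank}'.
input_ids = tokenizer(["Once upon a time"], return_tensors="pt").input_ids.to(f"xpu:{local_rank}")

with torch.inference_mode():
    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=32, num_beams=1)
    torch.xpu.synchronize()
if local_rank == 0:                           # all ranks hold the broadcast token ids; print once
    print(tokenizer.batch_decode(output_ids.cpu())[0])
```

The `pipeline_parallel.py` hunks follow the same dtype discipline: the hidden states handed to `dist.send` are cast to `self.dtype`, and the `dist.recv` buffer on the next stage is allocated with that same dtype instead of hard-coded `torch.float32`, so the point-to-point buffers agree when the model runs in fp16.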