Add benchmark script for pipeline parallel inference (#10873)
parent 46ba962168
commit f51bf018eb

2 changed files with 131 additions and 6 deletions
@@ -3,7 +3,7 @@ repo_id:
   - 'meta-llama/Llama-2-7b-chat-hf'
   # - 'liuhaotian/llava-v1.5-7b' # requires a LLAVA_REPO_DIR env variable pointing to the llava dir; added only for gpu win related test_api for now
 local_model_hub: 'path to your local model hub'
-warm_up: 1
+warm_up: 1 # must be set >= 2 when running the "pipeline_parallel_gpu" test_api
 num_trials: 3
 num_beams: 1 # default to greedy search
 low_bit: 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
@@ -21,6 +21,7 @@ test_api:
   # - "transformer_int4_fp16_gpu_win" # on Intel GPU for Windows, use fp16 for non-linear layer
   # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows using load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
   # - "deepspeed_optimize_model_gpu" # deepspeed autotp on Intel GPU
+  # - "pipeline_parallel_gpu" # pipeline parallel inference on Intel GPU
   # - "speculative_gpu"
   # - "transformer_int4"
   # - "native_int4"
@@ -34,3 +35,5 @@ test_api:
   # - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
 cpu_embedding: False # whether to put the embedding layer on CPU
 streaming: False # whether to output in a streaming way (only available now for gpu win related test_api)
+use_fp16_torch_dtype: True # whether to use fp16 for the non-linear layers (only available now for the "pipeline_parallel_gpu" test_api)
+n_gpu: 2 # number of GPUs to use (only available now for the "pipeline_parallel_gpu" test_api)
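
For reference, a minimal sketch (not part of this commit) of how the new keys can be read back on the Python side with OmegaConf, as the benchmark's main block does; the warm_up check simply mirrors the comment added above:

    # sketch: read the new pipeline-parallel settings from config.yaml
    from omegaconf import OmegaConf

    conf = OmegaConf.load('config.yaml')
    if 'pipeline_parallel_gpu' in conf['test_api']:
        # config.yaml asks for warm_up >= 2 for this test_api
        assert conf['warm_up'] >= 2, "set warm_up >= 2 for pipeline_parallel_gpu"
    print(conf['use_fp16_torch_dtype'], conf['n_gpu'])
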
@@ -50,7 +50,7 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
     for i in range(num_trials + warm_up):
         st = time.perf_counter()
         output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
-                                    min_new_tokens=out_len, num_beams=num_beams)
+                                    num_beams=num_beams)
         torch.xpu.synchronize()
         end = time.perf_counter()
         output_ids = output_ids.cpu()
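
For context, the call above is the standard Hugging Face generate API; without min_new_tokens, greedy decoding may stop at EOS before producing out_len tokens, which is why the benchmark records the actual output length rather than assuming out_len. A standalone sketch using a tiny public model on CPU (the benchmark itself runs ipex-llm models on XPU):

    # sketch: the decoding settings used in the benchmark loop, on a tiny CPU model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    input_ids = tok("hello world", return_tensors="pt").input_ids
    out_len, num_beams = 16, 1
    output_ids = model.generate(input_ids, do_sample=False,
                                max_new_tokens=out_len,   # upper bound; no min_new_tokens any more
                                num_beams=num_beams)      # 1 = greedy search
    actual_out_len = output_ids.shape[1] - input_ids.shape[1]
    print(actual_out_len)
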
@@ -63,7 +63,7 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
             result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
                                    actual_in_len, actual_out_len, load_time, model.peak_memory])
 
-def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False, use_fp16_torch_dtype=False, n_gpu=2):
     # TODO: make a parameter
     result= {}
     if test_api == 'transformer_int4':
@@ -108,6 +108,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
         result = run_speculative_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
     elif test_api == 'speculative_gpu':
         result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
+    elif test_api == 'pipeline_parallel_gpu':
+        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype, n_gpu=n_gpu)
 
     for in_out_pair in in_out_pairs:
         if result and result[in_out_pair]:
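
As an illustration, a hypothetical direct call that exercises the new branch; the argument values below are examples only, and in practice they come from config.yaml via the main block:

    # hypothetical invocation of the new test_api through run_model
    run_model('meta-llama/Llama-2-7b-chat-hf', 'pipeline_parallel_gpu',
              in_out_pairs=['32-32', '1024-128'],
              local_model_hub='path to your local model hub',
              warm_up=2, num_trials=3, num_beams=1, low_bit='sym_int4',
              cpu_embedding=False, batch_size=1, streaming=False,
              use_fp16_torch_dtype=True, n_gpu=2)
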
@@ -124,7 +126,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
                             cpu_embedding,
                             round(result[in_out_pair][-1][5], 2),
                             result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu', 'fp16_gpu', 'deepspeed_optimize_model_gpu']) else 'N/A',
-                            streaming if 'win' in test_api else 'N/A'],
+                            streaming if 'win' in test_api else 'N/A',
+                            use_fp16_torch_dtype],
                             )
 
 
@@ -1674,6 +1677,125 @@ def run_speculative_gpu(repo_id,
     return result
 
 
+def run_pipeline_parallel_gpu(repo_id,
+                              local_model_hub,
+                              in_out_pairs,
+                              warm_up,
+                              num_trials,
+                              num_beams,
+                              low_bit,
+                              batch_size,
+                              cpu_embedding,
+                              fp16=False,
+                              n_gpu=2):
+    from ipex_llm.transformers import AutoModel, AutoModelForCausalLM
+    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+    import intel_extension_for_pytorch as ipex
+    model_path = get_model_path(repo_id, local_model_hub)
+    # Load the model in 4 bit,
+    # which converts the relevant layers in the model into INT4 format
+    st = time.perf_counter()
+    origin_repo_id = repo_id.replace("-4bit", "")
+    if origin_repo_id in CHATGLM_IDS:
+        if "4bit" in repo_id:
+            model = AutoModel.load_low_bit(model_path, optimize_model=True,
+                                            trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        else:
+            model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
+                                            trust_remote_code=True, use_cache=True).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cpu_embedding=cpu_embedding)
+    elif origin_repo_id in LLAMA_IDS:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
+                                                     use_cache=True, cpu_embedding=cpu_embedding).eval()
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        if "4bit" in repo_id:
+            model = AutoModelForCausalLM.load_low_bit(model_path, optimize_model=True,
+                                            trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        else:
+            if 'starcoder' in repo_id:
+                # Load the starcoder-15.5b model in bf16 format to avoid CPU OOM.
+                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
+                                                            trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding, torch_dtype=torch.bfloat16).eval()
+                # Convert the low-bit model back to fp32 for performance considerations.
+                model = model.float()
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
+                                                            trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    if fp16:
+        model = model.half()
+        print("Convert model to half precision")
+
+    end = time.perf_counter()
+    load_time = end - st
+    print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))
+
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // n_gpu
+    for i in range(n_gpu):
+        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
+        if i == n_gpu - 1:
+            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
+    print(f">> device map: {device_map}")
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model,
+        device_map=device_map,
+        offload_dir=None,
+        skip_keys=["past_key_value", "past_key_values"],
+    )
+
+    model = BenchmarkWrapper(model)
+    result = {}
+    with torch.inference_mode():
+        for in_out in in_out_pairs:
+            in_out_len = in_out.split("-")
+            in_len = int(in_out_len[0])
+            out_len = int(in_out_len[1])
+            # As different tokenizers have different encodings,
+            # in_len.txt may be shorter than we need,
+            # so use a much longer context to make sure the input is long enough
+            test_length = min(in_len*2, 8192)
+            while test_length not in [32, 256, 1024, 2048, 8192]:
+                test_length = test_length * 2
+            input_str = open(f"prompt/{test_length}.txt", 'r').read()
+            # As different tokenizers have different encodings,
+            # slice the input_ids to ensure the prompt is exactly the required length.
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            input_ids = input_ids[:, :in_len]
+            true_str = tokenizer.batch_decode(input_ids)[0]
+            input_list = [true_str] * batch_size
+            input_ids = tokenizer(input_list, return_tensors="pt").input_ids.to('xpu:0')
+            actual_in_len = input_ids.shape[1]
+            result[in_out] = []
+            for i in range(num_trials + warm_up):
+                st = time.perf_counter()
+                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                            num_beams=num_beams)
+                torch.xpu.synchronize()
+                end = time.perf_counter()
+                output_ids = output_ids.cpu()
+                print("model generate cost: " + str(end - st))
+                output = tokenizer.batch_decode(output_ids)
+                actual_out_len = output_ids.shape[1] - actual_in_len
+                print(output[0])
+                torch.xpu.empty_cache()
+                if i >= warm_up:
+                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
+                                           actual_in_len, actual_out_len, load_time, model.peak_memory, fp16])
+    del model
+    torch.xpu.empty_cache()
+    return result
+
+
 if __name__ == '__main__':
     from omegaconf import OmegaConf
     conf = OmegaConf.load(f'{current_dir}/config.yaml')
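
To make the split concrete, here is a standalone worked example of the device-map construction above for a model with 32 hidden layers (e.g. Llama-2-7B) and n_gpu = 2; dispatch_model then places each listed module on its assigned XPU device:

    # worked example of the layer split (assumes 32 hidden layers, n_gpu = 2)
    n_gpu = 2
    model_layers = ['model.embed_tokens'] + [f'model.layers.{i}' for i in range(32)] \
                   + ['model.norm', 'lm_head']          # 35 entries in total
    split_len = len(model_layers) // n_gpu              # 35 // 2 = 17
    device_map = {}
    for i in range(n_gpu):
        device_map.update({k: f'xpu:{i}' for k in model_layers[split_len * i: split_len * (i + 1)]})
        if i == n_gpu - 1:
            # the last device also takes the remainder (here just 'lm_head')
            device_map.update({k: f'xpu:{i}' for k in model_layers[split_len * (i + 1):]})
    # xpu:0 -> 'model.embed_tokens' + layers 0..15           (17 modules)
    # xpu:1 -> layers 16..31 + 'model.norm' + 'lm_head'      (18 modules)
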
@@ -1698,9 +1820,9 @@ if __name__ == '__main__':
                     if model_id_input in excludes or model_id_input_batch_size in excludes:
                         in_out_pairs.remove(in_out)
             run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
-                      conf['low_bit'], conf['cpu_embedding'], conf['batch_size'], streaming)
+                      conf['low_bit'], conf['cpu_embedding'], conf['batch_size'], streaming, conf['use_fp16_torch_dtype'], conf['n_gpu'])
         df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
                                             'input/output tokens', 'batch_size', 'actual input/output tokens', 'num_beams', 'low_bit', 'cpu_embedding',
-                                            'model loading time (s)', 'peak mem (GB)', 'streaming'])
+                                            'model loading time (s)', 'peak mem (GB)', 'streaming', 'use_fp16_torch_dtype'])
         df.to_csv(csv_name)
         results = []
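
Based on the row-building logic in run_model above, a pipeline_parallel_gpu run fills the trailing CSV columns roughly as sketched below; the latency and memory numbers are made up, and 'peak mem (GB)' and 'streaming' fall back to 'N/A' because 'pipeline_parallel_gpu' matches neither the peak-memory keyword list nor 'win':

    # sketch of one summary row for a pipeline_parallel_gpu run (made-up numbers)
    import pandas as pd

    columns = ['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
               'input/output tokens', 'batch_size', 'actual input/output tokens', 'num_beams', 'low_bit',
               'cpu_embedding', 'model loading time (s)', 'peak mem (GB)', 'streaming', 'use_fp16_torch_dtype']
    row = ['meta-llama/Llama-2-7b-chat-hf', 123.4, 45.6, 0.0,
           '1024-128', 1, '1024-128', 1, 'sym_int4', False,
           56.7, 'N/A', 'N/A', True]
    pd.DataFrame([row], columns=columns).to_csv('pipeline-parallel-example.csv')  # hypothetical file name
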