[LLM] all-in-one update: memory optimization and streaming output (#10302)
* Memory saving for continuous in-out pair runs, and add support for streaming output on MTL iGPU
* Small fix
* Small fix
* Add things back
parent 367b1db4f7
commit 27d9a14989

3 changed files with 45 additions and 18 deletions
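The streaming half of the change uses the Transformers TextStreamer, which prints tokens to stdout as generate() produces them instead of waiting for the full sequence. A minimal, self-contained sketch of that mechanism; the model id and prompt are placeholders, not part of this commit, and the benchmark applies the same pattern inside its gpu-win runners as shown in the diffs below:

```python
# Minimal sketch of streaming generation with TextStreamer.
# Model and prompt are placeholders chosen only for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# skip_prompt=True keeps the echoed input out of the streamed output,
# matching how the benchmark constructs its streamer.
streamer = TextStreamer(tokenizer, skip_prompt=True)

inputs = tokenizer("Streaming means tokens appear", return_tensors="pt")
# Tokens are decoded and printed incrementally while generation runs,
# rather than only after generate() returns the full sequence.
model.generate(**inputs, max_new_tokens=32, do_sample=False, streamer=streamer)
```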
				
			
First changed file (the documented config example): a streaming option is added.

@@ -50,6 +50,7 @@ test_api:
   # - "transformer_int4_gpu_win" # on Intel GPU for Windows
   # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows using load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
 cpu_embedding: False # whether put embedding to CPU (only avaiable now for gpu win related test_api)
+streaming: False # whether output in streaming way (only avaiable now for gpu win related test_api)
Second changed file (the benchmark config): transformer_int4_loadlowbit_gpu_win is now listed and a streaming option is added.

@@ -23,5 +23,7 @@ test_api:
   # - "transformer_int4_gpu"  # on Intel GPU
   # - "optimize_model_gpu"  # on Intel GPU
   # - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
-  # - "transformer_int4_gpu_win" # on Intel GPU for Windows (catch GPU peak memory)
+  # - "transformer_int4_gpu_win" # on Intel GPU for Windows
+  # - "transformer_int4_loadlowbit_gpu_win" # on Intel GPU for Windows using load_low_bit API. Please make sure you have used the save.py to save the converted low bit model
 cpu_embedding: False # whether put embedding to CPU (only avaiable now for gpu win related test_api)
+streaming: False # whether output in streaming way (only avaiable now for gpu win related test_api)
Third changed file (the benchmark script): run_model gains a streaming parameter, forwards it to the gpu-win runners, and records it in each result row.

@@ -63,7 +63,7 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
             result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
                                    actual_in_len, actual_out_len, load_time, model.peak_memory])

-def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1):
+def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False, batch_size=1, streaming=False):
     # TODO: make a parameter
     result= {}
     if test_api == 'transformer_int4':

@@ -85,11 +85,11 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
     elif test_api == 'deepspeed_transformer_int4_cpu':
         result = run_deepspeed_transformer_int4_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
     elif test_api == 'transformer_int4_gpu_win':
-        result = run_transformer_int4_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size)
+        result = run_transformer_int4_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
     elif test_api == 'transformer_int4_loadlowbit_gpu_win':
         # drop the results of the first time for better performance
-        run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size)
-        result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size)
+        run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
+        result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
     elif test_api == 'transformer_autocast_bf16':
         result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)

@@ -107,7 +107,9 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
                             low_bit,
                             cpu_embedding if 'win' in test_api else 'N/A',
                             round(result[in_out_pair][-1][5], 2),
-                            result[in_out_pair][-1][6] if 'int4_gpu' in test_api or 'int4_loadlowbit_gpu' in test_api else 'N/A']) # currently only peak mem for transformer_int4_gpu is caught here
+                            result[in_out_pair][-1][6] if 'int4_gpu' in test_api or 'int4_loadlowbit_gpu' in test_api else 'N/A',
+                            streaming if 'win' in test_api else 'N/A'],
+                            )


 def get_model_path(repo_id, local_model_hub):
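Each result row now carries the streaming setting, and the summary CSV gains a matching streaming column (see the column list change near the end of the diff). An illustrative sketch of the resulting table shape, using pandas as the script does; every value below is invented:

```python
# Illustrative only: shape of the summary table after this commit, with the
# new 'streaming' column appended. All numbers and model names are made up.
import pandas as pd

results = [
    ["llama-2-7b", 85.3, 28.1, 0.0, "32-32", 1, "32-30", 1,
     "sym_int4", True, 12.4, 5.8, True],      # gpu-win API: real flag recorded
    ["llama-2-7b", 95.0, 31.7, 0.0, "32-32", 1, "32-30", 1,
     "sym_int4", "N/A", 10.9, "N/A", "N/A"],  # non-win API: reported as 'N/A'
]
df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)',
                                    '2+ avg latency (ms/token)', 'encoder time (ms)',
                                    'input/output tokens', 'batch_size',
                                    'actual input/output tokens', 'num_beams',
                                    'low_bit', 'cpu_embedding',
                                    'model loading time (s)', 'peak mem (GB)',
                                    'streaming'])
df.to_csv("all-in-one-results.csv")  # placeholder file name
```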
@@ -800,9 +802,10 @@ def run_transformer_int4_gpu_win(repo_id,
                                  num_beams,
                                  low_bit,
                                  cpu_embedding,
-                                 batch_size):
+                                 batch_size,
+                                 streaming):
     from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
-    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer, TextStreamer
     import intel_extension_for_pytorch as ipex
     model_path = get_model_path(repo_id, local_model_hub)
     # Load model in 4 bit,

@@ -839,6 +842,7 @@ def run_transformer_int4_gpu_win(repo_id,
     print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))

     model = BenchmarkWrapper(model)
+    streamer = TextStreamer(tokenizer, skip_prompt=True)

     result = {}
     with torch.inference_mode():

@@ -865,14 +869,19 @@ def run_transformer_int4_gpu_win(repo_id,
                 result[in_out] = []
                 for i in range(num_trials + warm_up):
                     st = time.perf_counter()
-                    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
-                                                num_beams=num_beams)
+                    if streaming:
+                        output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                                    num_beams=num_beams, streamer=streamer)
+                    else:
+                        output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                                    num_beams=num_beams)
                     torch.xpu.synchronize()
                     end = time.perf_counter()
                     output_ids = output_ids.cpu()
                     print("model generate cost: " + str(end - st))
                     output = tokenizer.batch_decode(output_ids)
-                    print(output[0])
+                    if not streaming:
+                        print(output[0])
                     actual_out_len = output_ids.shape[1] - actual_in_len
                     if i >= warm_up:
                         result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
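The commit streams straight to stdout through TextStreamer, which is why the runner skips the final print(output[0]) when streaming is on. For context, a hedged sketch of the related TextIteratorStreamer pattern, which is not used by this commit but yields the same tokens programmatically (for a UI or a log consumer); everything other than the Transformers classes is a placeholder:

```python
# Not part of this commit: TextIteratorStreamer variant for consuming
# streamed tokens in code instead of printing them to stdout.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Benchmarks can also stream", return_tensors="pt")

# generate() blocks until the sequence is done, so it runs in a worker
# thread while the main thread iterates over decoded chunks as they arrive.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, max_new_tokens=32,
                            do_sample=False, streamer=streamer))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```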
@@ -881,6 +890,8 @@ def run_transformer_int4_gpu_win(repo_id,
             except RuntimeError:
                 traceback.print_exc()
                 pass
+            torch.xpu.synchronize()
+            torch.xpu.empty_cache()
     model.to('cpu')
     torch.xpu.synchronize()
     torch.xpu.empty_cache()
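This hunk is the memory-saving half of the commit: after each in-out pair the runner now synchronizes and drops the XPU allocator cache, so memory held for one pair's inputs, KV cache, and outputs no longer accumulates across the following pairs. A minimal sketch of the pattern, assuming an XPU-enabled PyTorch with intel_extension_for_pytorch installed; run_one_pair and the pair list are placeholders:

```python
# Sketch of the per-pair cleanup added in this commit; run_one_pair stands
# in for the tokenize/generate/measure loop of the real benchmark.
import torch
import intel_extension_for_pytorch as ipex  # registers the torch.xpu backend

def run_one_pair(in_out: str) -> None:
    ...  # tokenize the prompt, run model.generate, record latencies

for in_out in ["32-32", "1024-128", "2048-256"]:  # illustrative pair list
    try:
        run_one_pair(in_out)
    except RuntimeError:
        pass  # keep benchmarking the remaining pairs, as the script does
    # Wait for queued XPU work, then return cached allocator blocks so the
    # next (often larger) pair starts from a low memory watermark.
    torch.xpu.synchronize()
    torch.xpu.empty_cache()
```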
@@ -897,9 +908,10 @@ def run_transformer_int4_loadlowbit_gpu_win(repo_id,
                                             num_beams,
                                             low_bit,
                                             cpu_embedding,
-                                            batch_size):
+                                            batch_size,
+                                            streaming):
     from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
-    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+    from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer, TextStreamer
     import intel_extension_for_pytorch as ipex
     model_path = get_model_path(repo_id, local_model_hub)
     # Load BigDL-LLM optimized low bit model

@@ -935,6 +947,7 @@ def run_transformer_int4_loadlowbit_gpu_win(repo_id,
     print(">> loading of model costs {}s and {}GB".format(load_time, torch.xpu.memory.memory_reserved()/(1024**3)))

     model = BenchmarkWrapper(model)
+    streamer = TextStreamer(tokenizer, skip_prompt=True)

     result = {}
     with torch.inference_mode():

@@ -961,14 +974,19 @@ def run_transformer_int4_loadlowbit_gpu_win(repo_id,
                 result[in_out] = []
                 for i in range(num_trials + warm_up):
                     st = time.perf_counter()
-                    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
-                                                num_beams=num_beams)
+                    if streaming:
+                        output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                                    num_beams=num_beams, streamer=streamer)
+                    else:
+                        output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                                    num_beams=num_beams)
                     torch.xpu.synchronize()
                     end = time.perf_counter()
                     output_ids = output_ids.cpu()
                     print("model generate cost: " + str(end - st))
                     output = tokenizer.batch_decode(output_ids)
-                    print(output[0])
+                    if not streaming:
+                        print(output[0])
                     actual_out_len = output_ids.shape[1] - actual_in_len
                     if i >= warm_up:
                         result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,

@@ -977,6 +995,8 @@ def run_transformer_int4_loadlowbit_gpu_win(repo_id,
             except RuntimeError:
                 traceback.print_exc()
                 pass
+            torch.xpu.synchronize()
+            torch.xpu.empty_cache()
     model.to('cpu')
     torch.xpu.synchronize()
     torch.xpu.empty_cache()
@@ -1059,6 +1079,10 @@ if __name__ == '__main__':
     today = date.today()
     if 'exclude' in conf:
         excludes = conf['exclude']
+    streaming = False
+    if 'streaming' in conf:
+        streaming = conf['streaming']
+

     import pandas as pd
     for api in conf.test_api:
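At the driver level the new flag is optional: the script only overrides the False default when the config actually defines streaming, so existing configs keep working unchanged. A hedged sketch of that lookup using a plain YAML load; the real benchmark's config loader may differ, and the file name here simply mirrors the config diff above:

```python
# Sketch of the optional-config lookup; assumes yaml.safe_load into a plain
# dict, which may not match the loader the benchmark actually uses.
import yaml

with open("config.yaml") as f:  # placeholder path
    conf = yaml.safe_load(f)

cpu_embedding = conf.get("cpu_embedding", False)
# New in this commit: fall back to non-streaming when the key is absent.
streaming = conf.get("streaming", False)

print(f"cpu_embedding={cpu_embedding}, streaming={streaming}")
```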
@@ -1073,9 +1097,9 @@ if __name__ == '__main__':
                     if model_id_input in excludes or model_id_input_batch_size in excludes:
                         in_out_pairs.remove(in_out)
             run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
-                      conf['low_bit'], conf['cpu_embedding'], conf['batch_size'])
+                      conf['low_bit'], conf['cpu_embedding'], conf['batch_size'], streaming)
         df = pd.DataFrame(results, columns=['model', '1st token avg latency (ms)', '2+ avg latency (ms/token)', 'encoder time (ms)',
                                             'input/output tokens', 'batch_size', 'actual input/output tokens', 'num_beams', 'low_bit', 'cpu_embedding',
-                                            'model loading time (s)', 'peak mem (GB)'])
+                                            'model loading time (s)', 'peak mem (GB)', 'streaming'])
         df.to_csv(csv_name)
         results = []