move reserved memory to benchmark_utils.py (#9907)
* move reserved memory to benchmark_utils.py * meet code review
This commit is contained in:
parent
7032a2ad73
commit
610b5226be
2 changed files with 32 additions and 22 deletions
|
|
@ -45,23 +45,22 @@ LLAVA_IDS = ['liuhaotian/llava-v1.5-7b']
|
|||
results = []
|
||||
excludes = []
|
||||
|
||||
def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, reserved_mem_list=[]):
|
||||
def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials):
|
||||
for i in range(num_trials + warm_up):
|
||||
st = time.perf_counter()
|
||||
output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
|
||||
num_beams=num_beams)
|
||||
torch.xpu.synchronize()
|
||||
end = time.perf_counter()
|
||||
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
|
||||
gpu_peak_mem = max(reserved_mem_list) # always keep the peak gpu mem at current stage
|
||||
output_ids = output_ids.cpu()
|
||||
print("model generate cost: " + str(end - st))
|
||||
output = tokenizer.batch_decode(output_ids)
|
||||
print(output[0])
|
||||
torch.xpu.empty_cache()
|
||||
actual_out_len = output_ids.shape[1] - actual_in_len
|
||||
if i >= warm_up:
|
||||
result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
|
||||
actual_in_len, actual_out_len, gpu_peak_mem])
|
||||
actual_in_len, actual_out_len, model.peak_memory])
|
||||
|
||||
def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, num_trials=3, num_beams=1, low_bit='sym_int4', cpu_embedding=False):
|
||||
# TODO: make a parameter
|
||||
|
|
@ -360,7 +359,6 @@ def run_transformer_int4_gpu(repo_id,
|
|||
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
|
||||
import intel_extension_for_pytorch as ipex
|
||||
reserved_mem_list = []
|
||||
model_path = get_model_path(repo_id, local_model_hub)
|
||||
# Load model in 4 bit,
|
||||
# which convert the relevant layers in the model into INT4 format
|
||||
|
|
@ -396,8 +394,7 @@ def run_transformer_int4_gpu(repo_id,
|
|||
# For gpt-j model family, this optimization can provide a better performance.
|
||||
model = ipex.optimize(model.eval(), inplace=True)
|
||||
end = time.perf_counter()
|
||||
print(">> loading of model costs {}s".format(end - st))
|
||||
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
|
||||
print(">> loading of model costs {}s and {}GB".format(end - st, torch.xpu.memory.memory_reserved()/(1024**3)))
|
||||
|
||||
model = BenchmarkWrapper(model)
|
||||
|
||||
|
|
@ -424,7 +421,7 @@ def run_transformer_int4_gpu(repo_id,
|
|||
input_ids = tokenizer.encode(true_str, return_tensors="pt").to('xpu')
|
||||
actual_in_len = input_ids.shape[1]
|
||||
result[in_out] = []
|
||||
thread = threading.Thread(target=run_model_in_thread, args=(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials, reserved_mem_list))
|
||||
thread = threading.Thread(target=run_model_in_thread, args=(model, in_out, tokenizer, result, warm_up, num_beams, input_ids, out_len, actual_in_len, num_trials))
|
||||
thread.start()
|
||||
thread.join()
|
||||
model.to('cpu')
|
||||
|
|
@ -760,7 +757,6 @@ def run_transformer_int4_gpu_win(repo_id,
|
|||
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
|
||||
import intel_extension_for_pytorch as ipex
|
||||
reserved_mem_list = []
|
||||
model_path = get_model_path(repo_id, local_model_hub)
|
||||
# Load model in 4 bit,
|
||||
# which convert the relevant layers in the model into INT4 format
|
||||
|
|
@ -792,8 +788,7 @@ def run_transformer_int4_gpu_win(repo_id,
|
|||
# For gpt-j model family, this optimization can provide a better performance.
|
||||
model = ipex.optimize(model.eval(), inplace=True)
|
||||
end = time.perf_counter()
|
||||
print(">> loading of model costs {}s".format(end - st))
|
||||
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
|
||||
print(">> loading of model costs {}s and {}GB".format(end - st, torch.xpu.memory.memory_reserved()/(1024**3)))
|
||||
|
||||
model = BenchmarkWrapper(model)
|
||||
|
||||
|
|
@ -825,8 +820,6 @@ def run_transformer_int4_gpu_win(repo_id,
|
|||
num_beams=num_beams)
|
||||
torch.xpu.synchronize()
|
||||
end = time.perf_counter()
|
||||
reserved_mem_list.append(torch.xpu.memory.memory_reserved()/(1024**3))
|
||||
gpu_peak_mem = max(reserved_mem_list) # always keep the peak gpu mem at current stage
|
||||
output_ids = output_ids.cpu()
|
||||
print("model generate cost: " + str(end - st))
|
||||
output = tokenizer.batch_decode(output_ids)
|
||||
|
|
@ -834,7 +827,7 @@ def run_transformer_int4_gpu_win(repo_id,
|
|||
actual_out_len = output_ids.shape[1] - actual_in_len
|
||||
if i >= warm_up:
|
||||
result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
|
||||
actual_in_len, actual_out_len, gpu_peak_mem])
|
||||
actual_in_len, actual_out_len, model.peak_memory])
|
||||
# torch.xpu.empty_cache() # this may make first token slower
|
||||
except RuntimeError:
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -516,6 +516,7 @@ class BenchmarkWrapper:
|
|||
self.encoder_time = 0.0
|
||||
self.first_cost = 0.0
|
||||
self.rest_cost_mean = 0.0
|
||||
self.peak_memory = 0.0
|
||||
print(self.model.__class__)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
|
|
@ -2363,6 +2364,7 @@ class BenchmarkWrapper:
|
|||
|
||||
first_token_time = None
|
||||
last_token_time = []
|
||||
memory_every_token = []
|
||||
while True:
|
||||
st = time.perf_counter()
|
||||
if synced_gpus:
|
||||
|
|
@ -2440,6 +2442,7 @@ class BenchmarkWrapper:
|
|||
|
||||
if self.device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024**3))
|
||||
end = time.perf_counter()
|
||||
if first_token_time is None:
|
||||
first_token_time = end - st
|
||||
|
|
@ -2454,13 +2457,15 @@ class BenchmarkWrapper:
|
|||
break
|
||||
|
||||
if self.do_print:
|
||||
print(f"=========First token cost {first_token_time:.4f} s=========")
|
||||
print(f"=========First token cost {first_token_time:.4f} s and {memory_every_token[0]} GB=========")
|
||||
if len(last_token_time) > 1:
|
||||
self.first_cost = first_token_time
|
||||
self.rest_cost_mean = np.mean(last_token_time)
|
||||
self.peak_memory = np.max(memory_every_token[1:])
|
||||
if self.do_print:
|
||||
print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
|
||||
f" tokens in all)=========")
|
||||
f" tokens in all) and {self.peak_memory} GB=========")
|
||||
print(f"Peak memory for every token: {memory_every_token}")
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
|
@ -2662,6 +2667,7 @@ class BenchmarkWrapper:
|
|||
|
||||
first_token_time = None
|
||||
last_token_time = []
|
||||
memory_every_token = []
|
||||
# auto-regressive generation
|
||||
while True:
|
||||
st = time.perf_counter()
|
||||
|
|
@ -2743,6 +2749,7 @@ class BenchmarkWrapper:
|
|||
|
||||
if self.device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
|
||||
end = time.perf_counter()
|
||||
if first_token_time is None:
|
||||
first_token_time = end - st
|
||||
|
|
@ -2757,13 +2764,15 @@ class BenchmarkWrapper:
|
|||
break
|
||||
|
||||
if self.do_print:
|
||||
print(f"=========First token cost {first_token_time:.4f} s=========")
|
||||
print(f"=========First token cost {first_token_time:.4f} s and {memory_every_token[0]} GB=========")
|
||||
if len(last_token_time) > 1:
|
||||
self.first_cost = first_token_time
|
||||
self.rest_cost_mean = np.mean(last_token_time)
|
||||
self.peak_memory = np.max(memory_every_token[1:])
|
||||
if self.do_print:
|
||||
print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
|
||||
f" tokens in all)=========")
|
||||
f" tokens in all) and {self.peak_memory} GB=========")
|
||||
print(f"Peak memory for every token: {memory_every_token}")
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
|
@ -2975,6 +2984,7 @@ class BenchmarkWrapper:
|
|||
first_token_time = None
|
||||
last_token_time = []
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
memory_every_token = []
|
||||
while True:
|
||||
st = time.perf_counter()
|
||||
if synced_gpus:
|
||||
|
|
@ -3072,6 +3082,7 @@ class BenchmarkWrapper:
|
|||
|
||||
if self.device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
|
||||
end = time.perf_counter()
|
||||
if first_token_time is None:
|
||||
first_token_time = end - st
|
||||
|
|
@ -3096,13 +3107,15 @@ class BenchmarkWrapper:
|
|||
)
|
||||
|
||||
if self.do_print:
|
||||
print(f"=========First token cost {first_token_time:.4f} s=========")
|
||||
print(f"=========First token cost {first_token_time:.4f} s and {memory_every_token[0]} GB=========")
|
||||
if len(last_token_time) > 1:
|
||||
self.first_cost = first_token_time
|
||||
self.rest_cost_mean = np.mean(last_token_time)
|
||||
self.peak_memory = np.max(memory_every_token[1:])
|
||||
if self.do_print:
|
||||
print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
|
||||
f" tokens in all)=========")
|
||||
f" tokens in all) and {self.peak_memory} GB=========")
|
||||
print(f"Peak memory for every token: {memory_every_token}")
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
|
|
@ -3322,6 +3335,7 @@ class BenchmarkWrapper:
|
|||
|
||||
first_token_time = None
|
||||
last_token_time = []
|
||||
memory_every_token = []
|
||||
while True:
|
||||
st = time.perf_counter()
|
||||
if synced_gpus:
|
||||
|
|
@ -3432,6 +3446,7 @@ class BenchmarkWrapper:
|
|||
|
||||
if self.device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
|
||||
end = time.perf_counter()
|
||||
if first_token_time is None:
|
||||
first_token_time = end - st
|
||||
|
|
@ -3450,13 +3465,15 @@ class BenchmarkWrapper:
|
|||
)
|
||||
|
||||
if self.do_print:
|
||||
print(f"=========First token cost {first_token_time:.4f} s=========")
|
||||
print(f"=========First token cost {first_token_time:.4f} s and {memory_every_token[0]} GB=========")
|
||||
if len(last_token_time) > 1:
|
||||
self.first_cost = first_token_time
|
||||
self.rest_cost_mean = np.mean(last_token_time)
|
||||
self.peak_memory = np.max(memory_every_token[1:])
|
||||
if self.do_print:
|
||||
print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
|
||||
f" tokens in all)=========")
|
||||
f" tokens in all) and {self.peak_memory} GB=========")
|
||||
print(f"Peak memory for every token: {memory_every_token}")
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
|
|
|
|||
Loading…
Reference in a new issue