LLM: update benchmark_utils.py to handle do_sample=True (#8903)

Ruonan Wang 2023-09-07 14:20:47 +08:00 committed by GitHub
parent 3d1c7e7082
commit 057e77e229

benchmark_utils.py

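Note: the hunks below add first-token / rest-token latency measurement to the sampling-based generation loops (the `sample` and `beam_sample` paths), so benchmarks run with `do_sample=True` also report timings. A minimal, self-contained sketch of that timing pattern (not code from benchmark_utils.py; `decode_step` and `sync` are illustrative placeholders):

```python
# Minimal sketch of the timing pattern added by this commit; not part of benchmark_utils.py.
import time
import numpy as np

def timed_decode(decode_step, max_new_tokens, sync=None):
    first_token_time = None
    last_token_time = []
    for _ in range(max_new_tokens):
        st = time.perf_counter()
        decode_step()                        # one forward pass + token selection (placeholder)
        if sync is not None:
            sync()                           # e.g. torch.xpu.synchronize(), so async kernels finish
        end = time.perf_counter()
        if first_token_time is None:
            first_token_time = end - st      # first iteration: prompt prefill + first new token
        else:
            last_token_time.append(end - st) # subsequent iterations: single decode steps
    rest_mean = np.mean(last_token_time) if len(last_token_time) > 1 else None
    return first_token_time, rest_mean
```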
@@ -2662,8 +2662,12 @@ class BenchmarkWrapper:
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         this_peer_finished = False  # used by synced_gpus only
+        first_token_time = None
+        last_token_time = []
         # auto-regressive generation
         while True:
+            st = time.perf_counter()
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence
@@ -2740,6 +2744,14 @@ class BenchmarkWrapper:
                 if unfinished_sequences.max() == 0:
                     this_peer_finished = True
+            if self.device.type == "xpu":
+                torch.xpu.synchronize()
+            end = time.perf_counter()
+            if first_token_time is None:
+                first_token_time = end - st
+            else:
+                last_token_time.append(end - st)
             # stop if we exceed the maximum length
             if stopping_criteria(input_ids, scores):
                 this_peer_finished = True
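The `torch.xpu.synchronize()` call above matters because XPU kernels are launched asynchronously: without a synchronization barrier, `time.perf_counter()` would largely measure kernel launch overhead rather than the actual token computation. A hedged, device-agnostic variant of that guard (not part of this diff; the CUDA branch is added only for illustration):

```python
import torch

def sync_device(device: torch.device) -> None:
    # Block until queued kernels on the device have finished, so that wall-clock
    # timestamps taken around a decode step reflect real work.
    if device.type == "xpu":
        # Assumes an XPU-enabled build (e.g. intel_extension_for_pytorch), as in this repo.
        torch.xpu.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()
    # CPU ops run synchronously, so no barrier is needed otherwise.
```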
@@ -2747,6 +2759,15 @@ class BenchmarkWrapper:
             if this_peer_finished and not synced_gpus:
                 break
+        if self.do_print:
+            print(f"=========First token cost {first_token_time:.4f} s=========")
+        if len(last_token_time) > 1:
+            self.first_cost = first_token_time
+            self.rest_cost_mean = np.mean(last_token_time)
+            if self.do_print:
+                print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
+                      f" tokens in all)=========")
         if streamer is not None:
             streamer.end()
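To make the reported numbers concrete, here is a small worked example of the bookkeeping above, using made-up latencies; the first measured iteration covers prompt prefill plus the first generated token, while the rest are single decode steps:

```python
import numpy as np

# Hypothetical per-iteration wall-clock times in seconds.
first_token_time = 0.3512                         # prefill + first generated token
last_token_time = [0.0490, 0.0512, 0.0498, 0.0505]

first_cost = first_token_time                     # -> 0.3512
rest_cost_mean = np.mean(last_token_time)         # -> ~0.0501

print(f"=========First token cost {first_cost:.4f} s=========")
print(f"=========Rest tokens cost average {rest_cost_mean:.4f} s "
      f"({len(last_token_time)} tokens in all)=========")
```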
@@ -3301,7 +3322,11 @@ class BenchmarkWrapper:
         beam_scores = beam_scores.view((batch_size * num_beams,))
         this_peer_finished = False  # used by synced_gpus only
+        first_token_time = None
+        last_token_time = []
         while True:
+            st = time.perf_counter()
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence
@@ -3408,6 +3433,14 @@ class BenchmarkWrapper:
                 else:
                     this_peer_finished = True
+            if self.device.type == "xpu":
+                torch.xpu.synchronize()
+            end = time.perf_counter()
+            if first_token_time is None:
+                first_token_time = end - st
+            else:
+                last_token_time.append(end - st)
         sequence_outputs = beam_scorer.finalize(
             input_ids,
             beam_scores,
@@ -3419,6 +3452,15 @@ class BenchmarkWrapper:
             beam_indices=beam_indices,
         )
+        if self.do_print:
+            print(f"=========First token cost {first_token_time:.4f} s=========")
+        if len(last_token_time) > 1:
+            self.first_cost = first_token_time
+            self.rest_cost_mean = np.mean(last_token_time)
+            if self.do_print:
+                print(f"=========Rest tokens cost average {self.rest_cost_mean:.4f} s ({len(last_token_time)}"
+                      f" tokens in all)=========")
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
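Finally, a hedged usage sketch showing how these timings are typically consumed: wrap a Hugging Face model in `BenchmarkWrapper` and generate with `do_sample=True`. The import path, the model id, and the exact constructor signature are assumptions for illustration, not confirmed by this diff:

```python
# Hedged usage sketch; import path, constructor details and model id are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from benchmark_utils import BenchmarkWrapper   # assumed location of the wrapper

model_id = "meta-llama/Llama-2-7b-hf"          # any causal LM should work here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model = BenchmarkWrapper(model)                # the do_print switch seen in the diff is assumed
                                               # to be configurable on the wrapper

inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, do_sample=True, max_new_tokens=32)

# Per this diff, sampling runs now populate the measured costs as well:
print(model.first_cost)        # first-token latency in seconds
print(model.rest_cost_mean)    # mean per-token latency of the remaining tokens
```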