LLM: update benchmark_util.py for beam search (#9126)
* update reorder_cache * fix
This commit is contained in:
parent
e8c5645067
commit
1363e666fc
1 changed files with 1 additions and 4 deletions
|
|
@ -821,10 +821,7 @@ class BenchmarkWrapper:
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def _reorder_cache(self, past_key_values, beam_idx):
|
def _reorder_cache(self, past_key_values, beam_idx):
|
||||||
raise NotImplementedError(
|
return self.model._reorder_cache(past_key_values, beam_idx)
|
||||||
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
|
|
||||||
f" enable beam search for {self.__class__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_logits_warper(
|
def _get_logits_warper(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue