diff --git a/python/llm/dev/benchmark/benchmark_util.py b/python/llm/dev/benchmark/benchmark_util.py index e2a202da..984bc762 100644 --- a/python/llm/dev/benchmark/benchmark_util.py +++ b/python/llm/dev/benchmark/benchmark_util.py @@ -821,10 +821,7 @@ class BenchmarkWrapper: return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): - raise NotImplementedError( - f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" - f" enable beam search for {self.__class__}" - ) + return self.model._reorder_cache(past_key_values, beam_idx) def _get_logits_warper( self,