LLM: fix n_batch in bloom pybinding (#8454)

This commit is contained in:
binbin Deng 2023-07-04 15:10:32 +08:00 committed by GitHub
parent 372c775cb4
commit e54e52b438

View file

@ -359,7 +359,7 @@ class Bloom(GenerationMixin):
input_ids=input_ids, input_ids=input_ids,
seed=self.seed, seed=self.seed,
n_threads=self.n_threads, n_threads=self.n_threads,
n_batch=len(input_ids)) n_batch=self.n_batch)
def _generate( def _generate(
self, self,
@ -426,4 +426,4 @@ class Bloom(GenerationMixin):
input_ids=input_ids, input_ids=input_ids,
seed=self.seed, seed=self.seed,
n_threads=self.n_threads, n_threads=self.n_threads,
n_batch=len(input_ids)) n_batch=self.n_batch)