LLM: fix n_batch in starcoder pybinding (#8461)

This commit is contained in:
binbin Deng 2023-07-05 17:06:50 +08:00 committed by GitHub
parent f2bb469847
commit 77808fa124

View file

@ -363,7 +363,7 @@ class Starcoder(GenerationMixin):
input_ids=input_ids, input_ids=input_ids,
seed=self.seed, seed=self.seed,
n_threads=self.n_threads, n_threads=self.n_threads,
n_batch=len(input_ids)) n_batch=self.n_batch)
def _generate( def _generate(
self, self,
@ -432,4 +432,4 @@ class Starcoder(GenerationMixin):
input_ids=input_ids, input_ids=input_ids,
seed=self.seed, seed=self.seed,
n_threads=self.n_threads, n_threads=self.n_threads,
n_batch=len(input_ids)) n_batch=self.n_batch)