LLM: fix n_batch in bloom pybinding (#8454)
This commit is contained in:
parent
372c775cb4
commit
e54e52b438
1 changed files with 2 additions and 2 deletions
|
|
@ -359,7 +359,7 @@ class Bloom(GenerationMixin):
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
n_threads=self.n_threads,
|
n_threads=self.n_threads,
|
||||||
n_batch=len(input_ids))
|
n_batch=self.n_batch)
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -426,4 +426,4 @@ class Bloom(GenerationMixin):
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
n_threads=self.n_threads,
|
n_threads=self.n_threads,
|
||||||
n_batch=len(input_ids))
|
n_batch=self.n_batch)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue