LLM: fix n_batch in bloom pybinding (#8454)
This commit is contained in:
parent
372c775cb4
commit
e54e52b438
1 changed files with 2 additions and 2 deletions
|
|
@ -359,7 +359,7 @@ class Bloom(GenerationMixin):
|
|||
input_ids=input_ids,
|
||||
seed=self.seed,
|
||||
n_threads=self.n_threads,
|
||||
n_batch=len(input_ids))
|
||||
n_batch=self.n_batch)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
|
@ -426,4 +426,4 @@ class Bloom(GenerationMixin):
|
|||
input_ids=input_ids,
|
||||
seed=self.seed,
|
||||
n_threads=self.n_threads,
|
||||
n_batch=len(input_ids))
|
||||
n_batch=self.n_batch)
|
||||
|
|
|
|||
Loading…
Reference in a new issue