LLM: fix n_batch in starcoder pybinding (#8461)
This commit is contained in:
parent
f2bb469847
commit
77808fa124
1 changed files with 2 additions and 2 deletions
|
|
@ -363,7 +363,7 @@ class Starcoder(GenerationMixin):
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
n_threads=self.n_threads,
|
n_threads=self.n_threads,
|
||||||
n_batch=len(input_ids))
|
n_batch=self.n_batch)
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -432,4 +432,4 @@ class Starcoder(GenerationMixin):
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
n_threads=self.n_threads,
|
n_threads=self.n_threads,
|
||||||
n_batch=len(input_ids))
|
n_batch=self.n_batch)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue