diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py index ff480b9a..ed00177f 100644 --- a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py +++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py @@ -363,7 +363,7 @@ class Starcoder(GenerationMixin): input_ids=input_ids, seed=self.seed, n_threads=self.n_threads, - n_batch=len(input_ids)) + n_batch=self.n_batch) def _generate( self, @@ -432,4 +432,4 @@ class Starcoder(GenerationMixin): input_ids=input_ids, seed=self.seed, n_threads=self.n_threads, - n_batch=len(input_ids)) + n_batch=self.n_batch)