From e54e52b4389cbd3fcf48f1b25f8e419b0cf7d1f7 Mon Sep 17 00:00:00 2001 From: binbin Deng <108676127+plusbang@users.noreply.github.com> Date: Tue, 4 Jul 2023 15:10:32 +0800 Subject: [PATCH] LLM: fix n_batch in bloom pybinding (#8454) --- python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py index 9e7af7a7..e312c73a 100644 --- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py +++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py @@ -359,7 +359,7 @@ class Bloom(GenerationMixin): input_ids=input_ids, seed=self.seed, n_threads=self.n_threads, - n_batch=len(input_ids)) + n_batch=self.n_batch) def _generate( self, @@ -426,4 +426,4 @@ class Bloom(GenerationMixin): input_ids=input_ids, seed=self.seed, n_threads=self.n_threads, - n_batch=len(input_ids)) + n_batch=self.n_batch)