From 916c338772deb69f4efed413952f61ee1e25f9b3 Mon Sep 17 00:00:00 2001 From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Date: Tue, 28 Nov 2023 11:09:54 +0800 Subject: [PATCH] fix bugs in vllm length check (#9543) --- .../src/bigdl/llm/vllm/model_executor/models/bigdl_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py index 4bc95d13..6ab9e109 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py @@ -91,7 +91,7 @@ class BigDLModelForCausalLM(nn.Module): ): max_seq_limit = self.max_seq_limit if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids: - if self.last_kv_cache[0][0].size(2) < max_seq_limit: + if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5: bigdl_kv_cache = self.last_kv_cache else: bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2) @@ -117,7 +117,7 @@ class BigDLModelForCausalLM(nn.Module): views = [_pad_kv_cache_view(v, max_len, self.device) for v in views] cur_view = torch.cat(views, dim=0) - if cur_view.size(2) > max_seq_limit * 1.5: + if cur_view.size(2) > max_seq_limit: cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device) cur_list.append(cur_view)