fix bugs in vllm length check (#9543)

commit 916c338772
parent 5098bc3544
Author: Xiangyu Tian (committed by GitHub)
Date: 2023-11-28 11:09:54 +08:00


@@ -91,7 +91,7 @@ class BigDLModelForCausalLM(nn.Module):
     ):
         max_seq_limit = self.max_seq_limit
         if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
-            if self.last_kv_cache[0][0].size(2) < max_seq_limit:
+            if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                 bigdl_kv_cache = self.last_kv_cache
             else:
                 bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
@@ -117,7 +117,7 @@ class BigDLModelForCausalLM(nn.Module):
             views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
             cur_view = torch.cat(views, dim=0)
-            if cur_view.size(2) > max_seq_limit * 1.5:
+            if cur_view.size(2) > max_seq_limit:
                 cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
             cur_list.append(cur_view)
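
Taken together, the two hunks move the 1.5x slack from the clamp check to the reuse check: a cached kv tensor may now grow to 1.5x max_seq_limit along its sequence dimension before being trimmed, while a concatenated view is clamped as soon as it exceeds max_seq_limit (previously it was only clamped past 1.5x, letting oversized views through). Below is a minimal, self-contained sketch of the corrected logic. The helper names (reuse_or_trim_kv_cache, clamp_view), the assumed tensor layout (batch, heads, seq_len, head_dim), and the trimming strategy are illustrative assumptions, and plain slicing stands in for BigDL's _pad_kv_cache_view helper.

import torch

def reuse_or_trim_kv_cache(last_kv_cache, max_seq_limit):
    # Corrected reuse check (first hunk): the cached sequence
    # dimension, size(2), may grow up to 1.5x max_seq_limit
    # before the cache is trimmed back down.
    if last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
        return last_kv_cache
    # Assumed trimming behavior: keep only the most recent
    # max_seq_limit positions of every layer's key/value tensor.
    return [
        [t.narrow(2, t.size(2) - max_seq_limit, max_seq_limit) for t in layer]
        for layer in last_kv_cache
    ]

def clamp_view(cur_view, max_seq_limit):
    # Corrected clamp check (second hunk): clamp as soon as the
    # concatenated view exceeds max_seq_limit; the old threshold of
    # 1.5x max_seq_limit let oversized views pass through unclamped.
    if cur_view.size(2) > max_seq_limit:
        cur_view = cur_view[:, :, -max_seq_limit:, :]  # stand-in for _pad_kv_cache_view
    return cur_view

For example, a single-layer cache of shape (1, 8, 96, 64) with max_seq_limit=64 is trimmed (96 >= 1.5 * 64), whereas a cache of length 95 would be reused as-is:

cache = [[torch.zeros(1, 8, 96, 64), torch.zeros(1, 8, 96, 64)]]
trimmed = reuse_or_trim_kv_cache(cache, max_seq_limit=64)
assert trimmed[0][0].size(2) == 64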