fix bugs in vllm length check (#9543)
parent 5098bc3544
commit 916c338772
1 changed file with 2 additions and 2 deletions
@@ -91,7 +91,7 @@ class BigDLModelForCausalLM(nn.Module):
     ):
         max_seq_limit = self.max_seq_limit
         if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
-            if self.last_kv_cache[0][0].size(2) < max_seq_limit:
+            if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                 bigdl_kv_cache = self.last_kv_cache
             else:
                 bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
@@ -117,7 +117,7 @@ class BigDLModelForCausalLM(nn.Module):
             views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
             cur_view = torch.cat(views, dim=0)

-            if cur_view.size(2) > max_seq_limit * 1.5:
+            if cur_view.size(2) > max_seq_limit:
                 cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
             cur_list.append(cur_view)

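Taken together, the two hunks move the 1.5x slack from the post-concatenation check to the cache-reuse check: the cached KV tensors are now reused while their sequence dimension (dim 2) stays under 1.5 * max_seq_limit, and a freshly concatenated view is clamped back down as soon as it exceeds max_seq_limit itself. The following is a minimal, self-contained sketch of the corrected conditions, not the repository code; it assumes a 4-D KV-cache layout of [batch, num_heads, seq_len, head_dim] and uses a hypothetical stand-in for _pad_kv_cache_view (the real helper is defined elsewhere in the file and may behave differently):

import torch


def _pad_kv_cache_view(view, target_len, device):
    # Hypothetical stand-in: zero-pad dim 2 up to target_len, or keep only the
    # most recent target_len positions when the view is already longer.
    cur_len = view.size(2)
    if cur_len < target_len:
        pad = torch.zeros(view.size(0), view.size(1), target_len - cur_len,
                          view.size(3), dtype=view.dtype, device=device)
        return torch.cat([view.to(device), pad], dim=2)
    return view[:, :, cur_len - target_len:, :]


def can_reuse_last_cache(last_kv_len, max_seq_limit):
    # First hunk after the fix: the previous KV cache is reused while its
    # sequence length stays under 1.5x the limit.
    return last_kv_len < max_seq_limit * 1.5


def clamp_cur_view(cur_view, max_seq_limit, device):
    # Second hunk after the fix: a freshly concatenated view is clamped back
    # to max_seq_limit as soon as it exceeds the limit itself (no 1.5x slack).
    if cur_view.size(2) > max_seq_limit:
        cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, device)
    return cur_view


if __name__ == "__main__":
    device = torch.device("cpu")
    limit = 8
    assert can_reuse_last_cache(11, limit)        # 11 < 12 -> reuse
    assert not can_reuse_last_cache(12, limit)    # 12 >= 12 -> rebuild
    view = torch.randn(1, 2, 13, 4)               # seq_len 13 > limit 8
    print(clamp_cur_view(view, limit, device).shape)  # torch.Size([1, 2, 8, 4])

Read this way, the swap lets the reused cache grow somewhat past the limit before it has to be rebuilt, while anything that leaves this code path stays at or below max_seq_limit.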