[vLLM] Add option to adjust KV_CACHE_ALLOC_BLOCK_LENGTH (#9782)

* add an option to configure KV_CACHE_ALLOC_BLOCK_LENGTH via the VLLM_KV_CACHE_ALLOC_BLOCK environment variable

* rename the variable to VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH
Guancheng Fu authored 2023-12-28 14:41:47 +08:00, committed by GitHub
parent 99bddd3ab4
commit 5857a38321

@@ -36,6 +36,7 @@ import importlib
 import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
+import os
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
@@ -319,7 +320,7 @@ def llama_attention_selective_batching_forward_4_31(
         **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # Minimize this value to reduce memory allocation.
-    KV_CACHE_ALLOC_BLOCK_LENGTH = 64
+    VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', 64))
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     # for flash attention
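
With this hunk the allocation block length is read from the environment on every attention forward instead of being hard-coded to 64, so it can be tuned per deployment without code changes. A minimal sketch of exercising the knob; the value 256 is purely illustrative:

    import os

    # Set in the environment of the serving process; the attention forward
    # re-reads it on every call and falls back to 64 when it is unset.
    os.environ["VLLM_KV_CACHE_ALLOC_BLOCK"] = "256"   # illustrative value

    # Same default-handling pattern as in the patch above.
    block_length = int(os.environ.get("VLLM_KV_CACHE_ALLOC_BLOCK", 64))
    assert block_length == 256
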
@@ -359,7 +360,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                            self.head_dim,
                                                            kv_seq_len,
                                                            kv_seq_len +
-                                                           KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                            dtype=past_k.dtype,
                                                            device=device)
                 new_cache_k[:] = past_k
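
The extra VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH slots requested above are headroom: the cache is allocated larger than the current sequence so that subsequent decode steps append in place rather than reallocating and copying on every step. A minimal sketch of that growth pattern; the helper below is hypothetical and not bigdl-llm's extend_kv_cache:

    import torch

    def extend_with_headroom(past: torch.Tensor, block: int) -> torch.Tensor:
        """Hypothetical helper: copy a K or V cache of shape
        (bsz, n_heads, seq_len, head_dim) into a buffer with `block` extra
        token slots, so the next `block` appends need no reallocation."""
        bsz, n_heads, seq_len, head_dim = past.shape
        new_cache = torch.empty(bsz, n_heads, seq_len + block, head_dim,
                                dtype=past.dtype, device=past.device)
        new_cache[:, :, :seq_len] = past   # pay one copy now, none for a while
        return new_cache

    k = torch.zeros(1, 32, 17, 128)            # toy prefill-sized K cache
    k = extend_with_headroom(k, block=64)      # capacity 17 + 64 token slots
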
@@ -421,7 +422,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                        self.head_dim,
                                                        past_k.size(2),
                                                        current_kv_len +
-                                                       KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                        dtype=past_k.dtype,
                                                        device=device)
             new_cache_k[:] = past_k
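
The decode path reserves the same headroom, so the block length is a speed/memory trade-off: a larger block means fewer reallocation-and-copy cycles but more reserved KV memory per sequence. A rough estimate of that overhead, assuming Llama-2-7B-like shapes (32 layers, 32 KV heads, head_dim 128, fp16) purely for illustration:

    layers, kv_heads, head_dim, bytes_fp16 = 32, 32, 128, 2
    block = 64                                  # default VLLM_KV_CACHE_ALLOC_BLOCK

    # K and V each reserve `block` extra token slots in every layer.
    overhead = 2 * layers * kv_heads * head_dim * bytes_fp16 * block
    print(f"{overhead / 2**20:.0f} MiB of headroom per sequence")   # ~32 MiB
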
@@ -635,7 +636,8 @@ def llama_model_selective_batching_forward_4_31(
     # TODO: validate correctness
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        invalidInputError("vLLM: position_ids should never be None")
+        invalidInputError(False,
+                          "vLLM: position_ids should never be None")
     else:
         # print(f"Original position_ids is {position_ids}")
         position_ids = position_ids.view(-1, seq_length)
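
The last hunk fixes a misuse of invalidInputError: judging from the corrected call, the helper expects a boolean condition first and the message second, so the old single-argument call could never report the error as intended. A simplified stand-in, not the actual bigdl-llm implementation, to show why the condition-first form matters:

    def invalidInputError(condition: bool, errMsg: str) -> None:
        # Simplified stand-in for bigdl.llm.utils.common.invalidInputError,
        # assumed from the corrected call: raise only when `condition` is False.
        if not condition:
            raise RuntimeError(errMsg)

    # Old call shape: the message fills the `condition` slot and no message is
    # supplied, so the call fails with a TypeError instead of a clean error.
    #     invalidInputError("vLLM: position_ids should never be None")
    # New call shape: always raises with the intended message.
    #     invalidInputError(False, "vLLM: position_ids should never be None")
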