[vLLM] Add option to adjust KV_CACHE_ALLOC_BLOCK_LENGTH (#9782)
* add option kv_cache_block * change var name
This commit is contained in:
parent
99bddd3ab4
commit
5857a38321
1 changed files with 6 additions and 4 deletions
|
|
@ -36,6 +36,7 @@ import importlib
|
|||
import torch.nn as nn
|
||||
from typing import Optional, Tuple, Union, List
|
||||
import math
|
||||
import os
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
|
@ -319,7 +320,7 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Minimize this value to reduce memory allocation.
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 64
|
||||
VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', 64))
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
# for flash attention
|
||||
|
|
@ -359,7 +360,7 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
self.head_dim,
|
||||
kv_seq_len,
|
||||
kv_seq_len +
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=past_k.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = past_k
|
||||
|
|
@ -421,7 +422,7 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
self.head_dim,
|
||||
past_k.size(2),
|
||||
current_kv_len +
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=past_k.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = past_k
|
||||
|
|
@ -635,7 +636,8 @@ def llama_model_selective_batching_forward_4_31(
|
|||
# TODO: validate correctness
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if position_ids is None:
|
||||
invalidInputError("vLLM: position_ids should never be None")
|
||||
invalidInputError(False,
|
||||
"vLLM: position_ids should never be None")
|
||||
else:
|
||||
# print(f"Original position_ids is {position_ids}")
|
||||
position_ids = position_ids.view(-1, seq_length)
|
||||
|
|
|
|||
Loading…
Reference in a new issue