[vLLM] Add option to adjust KV_CACHE_ALLOC_BLOCK_LENGTH (#9782)
* add option kv_cache_block
* change var name
parent 99bddd3ab4
commit 5857a38321
1 changed file with 6 additions and 4 deletions
@@ -36,6 +36,7 @@ import importlib
 import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
+import os
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
@@ -319,7 +320,7 @@ def llama_attention_selective_batching_forward_4_31(
         **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # Minimize this value to reduce memory allocation.
-    KV_CACHE_ALLOC_BLOCK_LENGTH = 64
+    VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', 64))
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     # for flash attention
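
The block length that pads each KV cache extension now comes from the VLLM_KV_CACHE_ALLOC_BLOCK environment variable (default 64) and, per this hunk, is read inside the forward pass, so it can be set before the serving process loads the model. A minimal sketch of overriding it; the value 128 is illustrative, not a recommendation from this commit:

# Sketch: raise the allocation block so the KV cache is extended less often.
import os

os.environ['VLLM_KV_CACHE_ALLOC_BLOCK'] = '128'  # default is 64

# ...then create the engine / load the model as usual; each attention forward
# reads this variable and grows the cache in blocks of 128 positions.
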
@@ -359,7 +360,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                            self.head_dim,
                                                            kv_seq_len,
                                                            kv_seq_len +
-                                                           KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                            dtype=past_k.dtype,
                                                            device=device)
                new_cache_k[:] = past_k
@@ -421,7 +422,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                            self.head_dim,
                                                            past_k.size(2),
                                                            current_kv_len +
-                                                           KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                            dtype=past_k.dtype,
                                                            device=device)
                new_cache_k[:] = past_k
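
Both call sites above pass the block length as extra headroom when the existing cache runs out of room, so a larger value means fewer reallocations (and fewer full copies of past_k / past_v) at the cost of more reserved memory. A toy illustration of that trade-off, assuming a (batch, heads, seq, head_dim) cache layout; this is not the real extend_kv_cache API:

# Grow the cache in blocks so appending a token rarely triggers a copy.
import torch

def toy_extend(cache: torch.Tensor, needed_len: int, block: int) -> torch.Tensor:
    if needed_len <= cache.size(2):
        return cache  # enough room left, no allocation
    bsz, heads, _, head_dim = cache.shape
    new_cache = torch.empty(bsz, heads, needed_len + block, head_dim,
                            dtype=cache.dtype, device=cache.device)
    new_cache[:, :, :cache.size(2)] = cache  # one copy per block, not per token
    return new_cache
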
@@ -635,7 +636,8 @@ def llama_model_selective_batching_forward_4_31(
     # TODO: validate correctness
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        invalidInputError("vLLM: position_ids should never be None")
+        invalidInputError(False,
+                          "vLLM: position_ids should never be None")
     else:
         # print(f"Original position_ids is {position_ids}")
         position_ids = position_ids.view(-1, seq_length)
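
For context on the last hunk: invalidInputError appears to follow a (condition, message) convention, so the old single-argument call passed the message where the condition belongs and did not express the intended check; passing False as the condition makes the branch raise as intended. A toy stand-in for that assumed contract:

# Stand-in mirroring the assumed invalidInputError(condition, errMsg) contract.
def invalid_input_error(condition: bool, err_msg: str) -> None:
    if not condition:
        raise ValueError(err_msg)

position_ids = None
invalid_input_error(position_ids is not None,
                    "vLLM: position_ids should never be None")  # raises here
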