From 5857a38321e878c793efffc2c2962ee3e2cb9112 Mon Sep 17 00:00:00 2001
From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Date: Thu, 28 Dec 2023 14:41:47 +0800
Subject: [PATCH] [vLLM] Add option to adjust KV_CACHE_ALLOC_BLOCK_LENGTH (#9782)

* add option kv_cache_block

* change var name
---
 python/llm/src/bigdl/llm/transformers/models/llama.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 2300ea18..12a5d21d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -36,6 +36,7 @@ import importlib
 import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
+import os
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
@@ -319,7 +320,7 @@ def llama_attention_selective_batching_forward_4_31(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # Minimize this value to reduce memory allocation.
-    KV_CACHE_ALLOC_BLOCK_LENGTH = 64
+    VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', 64))
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     # for flash attention
@@ -359,7 +360,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                                self.head_dim,
                                                                kv_seq_len,
                                                                kv_seq_len +
-                                                               KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                               VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                                dtype=past_k.dtype,
                                                                device=device)
                 new_cache_k[:] = past_k
@@ -421,7 +422,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                        self.head_dim,
                                                        past_k.size(2),
                                                        current_kv_len +
-                                                       KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                        dtype=past_k.dtype,
                                                        device=device)
             new_cache_k[:] = past_k
@@ -635,7 +636,8 @@ def llama_model_selective_batching_forward_4_31(
     # TODO: validate correctness
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        invalidInputError("vLLM: position_ids should never be None")
+        invalidInputError(False,
+                          "vLLM: position_ids should never be None")
     else:
         # print(f"Original position_ids is {position_ids}")
         position_ids = position_ids.view(-1, seq_length)
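
Usage note (a minimal sketch, not part of the patch): with this change, the KV cache allocation block length used by the selective-batching llama forward can be tuned through the VLLM_KV_CACHE_ALLOC_BLOCK environment variable, falling back to the previous hard-coded value of 64 when it is unset. The snippet below shows one assumed way to set and read the variable before model code runs; the surrounding setup is illustrative and not taken from the repository.

    import os

    # Larger blocks mean fewer KV cache re-allocations as a sequence grows,
    # at the cost of reserving more cache memory on each extension.
    # Set this before bigdl.llm.transformers.models.llama is invoked.
    os.environ["VLLM_KV_CACHE_ALLOC_BLOCK"] = "128"

    # Mirrors how llama_attention_selective_batching_forward_4_31 reads the value:
    block_length = int(os.environ.get("VLLM_KV_CACHE_ALLOC_BLOCK", 64))
    print(f"KV cache will be extended in blocks of {block_length} tokens")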