[vLLM] Add option to adjust KV_CACHE_ALLOC_BLOCK_LENGTH (#9782)
* add option kv_cache_block
* change var name
parent 99bddd3ab4
commit 5857a38321

1 changed file with 6 additions and 4 deletions
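This commit replaces the hard-coded KV_CACHE_ALLOC_BLOCK_LENGTH = 64 in the selective-batching LLaMA forward paths with a value read from the VLLM_KV_CACHE_ALLOC_BLOCK environment variable (still defaulting to 64). A minimal usage sketch, assuming the variable is set in the serving process before the patched forward code reads it; the value 32 is only an illustration:

    import os

    # Ask for smaller 32-token KV cache allocation blocks; any positive
    # integer string is accepted, and 64 remains the default when unset.
    os.environ['VLLM_KV_CACHE_ALLOC_BLOCK'] = '32'

Exporting the variable in the shell that launches the server has the same effect, since os.environ inherits the process environment.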
@@ -36,6 +36,7 @@ import importlib
 import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
+import os
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
@@ -319,7 +320,7 @@ def llama_attention_selective_batching_forward_4_31(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # Minimize this value to reduce memory allocation.
-    KV_CACHE_ALLOC_BLOCK_LENGTH = 64
+    VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', 64))
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     # for flash attention
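The new read falls back to the default when the variable is unset; a quick sketch of that behavior (illustration only, not part of the patch):

    import os

    def read_block_length(default: int = 64) -> int:
        # Mirrors the patched line: returns `default` when the variable is
        # unset, and raises ValueError if it is set to a non-integer string.
        return int(os.environ.get('VLLM_KV_CACHE_ALLOC_BLOCK', default))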
@@ -359,7 +360,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                        self.head_dim,
                                                        kv_seq_len,
                                                        kv_seq_len +
-                                                       KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                        dtype=past_k.dtype,
                                                        device=device)
             new_cache_k[:] = past_k
@@ -421,7 +422,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                                self.head_dim,
                                                                past_k.size(2),
                                                                current_kv_len +
-                                                               KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                               VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                                dtype=past_k.dtype,
                                                                device=device)
                     new_cache_k[:] = past_k
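In both hunks above the cache is grown to the current length plus VLLM_KV_CACHE_ALLOC_BLOCK_LENGTH rather than to the exact length needed, so the allocate-and-copy happens once per block of generated tokens instead of on every decoding step. A minimal sketch of that amortized growth, using a hypothetical grow_cache helper in place of the real extend_kv_cache from bigdl.llm.transformers.models.utils (whose exact signature is not shown here):

    import torch

    def grow_cache(past: torch.Tensor, needed_len: int, block: int) -> torch.Tensor:
        # Hypothetical stand-in for extend_kv_cache: allocate `block` extra
        # slots along the sequence dimension so the next `block` appended
        # tokens fit without another reallocation.
        bsz, num_heads, cur_len, head_dim = past.shape
        new_cache = torch.empty(bsz, num_heads, needed_len + block, head_dim,
                                dtype=past.dtype, device=past.device)
        new_cache[:, :, :cur_len, :] = past  # copy over the existing entries
        return new_cache

A larger block means fewer copies but more memory held per sequence, which is exactly the trade-off the new environment variable exposes.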
@@ -635,7 +636,8 @@ def llama_model_selective_batching_forward_4_31(
     # TODO: validate correctness
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        invalidInputError("vLLM: position_ids should never be None")
+        invalidInputError(False,
+                          "vLLM: position_ids should never be None")
     else:
         # print(f"Original position_ids is {position_ids}")
         position_ids = position_ids.view(-1, seq_length)
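The last hunk also fixes how invalidInputError is invoked: the message string had been passed in the position of the condition argument, so the call now passes False explicitly and the error is always raised when position_ids is missing. A sketch of the assumed (condition, message) convention, with a hypothetical wrapper as illustration:

    from bigdl.llm.utils.common import invalidInputError

    def check_position_ids(position_ids):
        # Assumed convention, inferred from the fixed call above: the first
        # argument is a condition and the second is the message reported
        # when that condition is falsy.
        invalidInputError(position_ids is not None,
                          "vLLM: position_ids should never be None")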