[VLLM] Change padding patterns for vLLM & clean code (#9609)
* optimize
* fix minor error
* optimizations
* fix style
parent 89069d6173
commit 6978b2c316

2 changed files with 78 additions and 47 deletions
@@ -26,6 +26,7 @@ from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausal
 from bigdl.llm.vllm.logger import init_logger
 import math
 import time
+from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -39,7 +40,16 @@ logger = init_logger(__name__)
 
 
 def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
-    return x + [padding_id] * (max_len - len(x))
+    return [padding_id] * (max_len - len(x)) + x
+
+
+def _get_attention_mask_for_prompts(
+    input_ids: List[List[int]], max_prompt_len: int
+) -> List[List[int]]:
+    attention_mask = [
+        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
+    ]
+    return attention_mask
 
 
 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
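
The core padding change: prompts are now padded on the left instead of the right, and a new helper builds the matching attention mask (0 over the padding, 1 over the real tokens). A standalone sketch of the two helpers above, using made-up token ids; this is for illustration only and is not part of the commit:

# Illustrative sketch only -- toy token ids, helpers copied from the hunk above.
from typing import List

def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
    # New behaviour: pad on the LEFT so the real tokens stay right-aligned.
    return [padding_id] * (max_len - len(x)) + x

def _get_attention_mask_for_prompts(
    input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
    # 0 marks padding positions, 1 marks real prompt tokens.
    return [[0] * (max_prompt_len - len(p)) + [1] * len(p) for p in input_ids]

prompts = [[5, 6, 7], [8, 9]]            # two toy prompts of length 3 and 2
max_prompt_len = max(len(p) for p in prompts)
padded = [_pad_to_max(p, max_prompt_len, padding_id=0) for p in prompts]
masks = _get_attention_mask_for_prompts(prompts, max_prompt_len)
assert padded == [[5, 6, 7], [0, 8, 9]]
assert masks == [[1, 1, 1], [0, 1, 1]]
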
@@ -98,49 +108,53 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
-        kv_cache: Optional = None,
-        input_metadata: Optional = None,
+        # kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
+        kv_cache: Optional[List[List[Dict]]] = None,
+        input_metadata: Optional[InputMetadata] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        kv_cache_size_0 = self.model.config.num_hidden_layers
-        kv_cache_size_1 = 2
         seq_len = len(seq_group_meta_data_lists)
+        num_layers = self.model.config.num_hidden_layers
+        # One for key, one for value
+        decoder_kv_size = 2
 
         bigdl_input_ids = []
         bigdl_position_ids = []
         bigdl_attention_mask = []
 
         cur_seq_ids = []
         bigdl_sampling_params = {}
-        max_context_len = 0
-        all_decoding = True
+        max_prompt_len = 0
+
+        # 0. Verify is_prompt or is_decoding
+        is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt
+
+        # 1. Assemble bigdl_input_ids
         for seq_group_meta_data in seq_group_meta_data_lists:
-            req_id = seq_group_meta_data.request_id
-            all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
+            # req_id = seq_group_meta_data.request_id
+            # is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
            seq_ids = list(seq_group_meta_data.seq_data.keys())
            seq_id = seq_ids[0]
            cur_seq_ids.append(seq_id)
            seq_data = seq_group_meta_data.seq_data[seq_id]
 
            cur_seq_input_ids = seq_data.get_token_ids()
-            context_len = seq_data.get_len()
+            # context_len = seq_data.get_len()
            if seq_group_meta_data.is_prompt:
                bigdl_input_ids.append(cur_seq_input_ids)
-                max_context_len = max(max_context_len, context_len)
+                max_prompt_len = max(max_prompt_len, seq_data.get_len())
            else:
                bigdl_input_ids.append([cur_seq_input_ids[-1]])
+        # 1. Assemble bigdl_input_ids end
 
            bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
 
-        if all_decoding:
+        if is_decoding_stage:
            bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, kv_cache_size_0, kv_cache_size_1)
+                                                   kv_cache, num_layers, decoder_kv_size)
        else:
+            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
            bigdl_input_ids = [
-                _pad_to_max(input_ids, max_context_len, self.pad_token_id)
+                _pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
                for input_ids in bigdl_input_ids
            ]
 
-        if all_decoding:
+        if is_decoding_stage:
            cur_seq_len = bigdl_kv_cache[0][0].size(2)
            for seq_group_meta_data in seq_group_meta_data_lists:
                seq_ids = list(seq_group_meta_data.seq_data.keys())
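
The new signature comment describes the kv_cache container handed to forward(): a list with one entry per decoder layer, each entry holding two dicts (keys and values) keyed by sequence id. Below is a rough sketch, not part of the commit; the per-sequence tensor shape [num_heads, seq_len, head_size] is an assumption consistent with the size(dim=1) and size(2) calls elsewhere in this diff, not something the diff states:

# Rough sketch of the kv_cache container (tensor shapes are assumptions).
import torch

num_layers, decoder_kv_size = 32, 2     # 2 = one dict for keys, one for values
kv_cache = [[dict() for _ in range(decoder_kv_size)] for _ in range(num_layers)]

# Store the cached keys/values of sequence id 0 for layer 0 (toy shapes).
num_heads, seq_len, head_size = 32, 5, 128
kv_cache[0][0][0] = torch.zeros(num_heads, seq_len, head_size)   # keys of seq_id 0
kv_cache[0][1][0] = torch.zeros(num_heads, seq_len, head_size)   # values of seq_id 0

# prepare_kv_cache / update_kv_cache index it as kv_cache[layer][kv][seq_id].
assert kv_cache[0][0][0].size(dim=1) == seq_len
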
@@ -152,7 +166,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                bigdl_attention_mask.append(cur_attention_mask)
 
        bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
-        if all_decoding:
+
+        if is_decoding_stage:
            bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
            kwargs = {
@@ -166,6 +181,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
        else:
            kwargs = {
                "input_ids": bigdl_input_ids,
+                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
                # "position_ids": bigdl_position_ids,
                "past_key_values": None,
                "use_cache": True,
@@ -190,7 +206,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
        # logger.info(f"before: {tmp['allocated_bytes.all.current']}")
 
        self.update_kv_cache(cur_seq_ids,
-                             kv_cache, kv_cache_size_0, kv_cache_size_1)
+                             kv_cache, num_layers, decoder_kv_size)
 
        # tmp = torch.xpu.memory_stats()
        # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
@@ -90,45 +90,60 @@ class BigDLModelForCausalLM(nn.Module):
        cur_seq_ids: List[int],
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
        kv_cache: Dict,
-        kv_cache_size_0: int,
+        num_layers: int,
        kv_cache_size_1: int,
    ):
        max_seq_limit = self.max_seq_limit
        if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
            if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                bigdl_kv_cache = self.last_kv_cache
+                # Immediately set it to None to decrease ref-count
+                self.last_kv_cache = None
            else:
                bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
                                   - max_seq_limit, max_seq_limit)
                                   for tmp in tmp_list] for tmp_list in self.last_kv_cache]
                del self.last_kv_cache
+                self.last_kv_cache = None
        else:
+            del self.last_kv_cache
            bigdl_kv_cache = []
-            for i in range(kv_cache_size_0):
+            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
+                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = min(max_kv_len, max_seq_limit)
+            for layer in range(num_layers):
                cur_list = []
-                for j in range(kv_cache_size_1):
-                    views = []
-                    max_len = 0
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        seq_data = seq_group_meta_data.seq_data[seq_id]
-                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
-                        views.append(kv_cache[i][j][seq_id].view(view_size))
-                        max_len = max(max_len, view_size[2])
+                for kv in range(kv_cache_size_1):
+                    kv_list = []
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
+                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
+                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
+                    for seq_id in cur_seq_ids:
+                        processed_kv_cache = kv_cache[layer][kv][seq_id]
+                        # Clean
+                        kv_cache[layer][kv][processed_kv_cache] = None
+                        if processed_kv_cache.size(dim=1) != max_kv_len:
+                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
+                                                                    self.device, 1)
+                            # Do padding
+                        kv_list.append(processed_kv_cache)
+                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
+                    kv_list.clear()
 
-                    views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
-                    cur_view = torch.cat(views, dim=0)
+                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
+                    # cur_view = torch.cat(kv_list, dim=0)
 
-                    if cur_view.size(2) > max_seq_limit:
-                        cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(cur_view)
+                    # if cur_view.size(2) > max_seq_limit:
+                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
+                    cur_list.append(current_layer_kv_cache)
 
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        del kv_cache[i][j][seq_id]
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     del kv_cache[layer][kv][seq_id]
                bigdl_kv_cache.append(cur_list)
 
        return bigdl_kv_cache
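
In the rewritten prepare_kv_cache, all per-sequence caches for a layer are padded along their sequence dimension to a common max_kv_len and then stacked into one batched tensor with torch.stack. The standalone sketch below illustrates only that batching step; the shapes are assumptions, batch_seq_caches is a hypothetical helper name, and F.pad stands in for _pad_kv_cache_view, whose implementation and padding side are not shown in this diff:

# Sketch of the per-layer batching step (assumed shapes, not the repo's helper).
import torch
import torch.nn.functional as F

def batch_seq_caches(seq_caches, max_kv_len):
    # seq_caches: list of [num_heads, seq_len, head_size] tensors, one per sequence.
    kv_list = []
    for cache in seq_caches:
        pad = max_kv_len - cache.size(dim=1)
        if pad > 0:
            # Pad along the sequence dimension (dim=1); padding on the left here is
            # an assumption chosen to mirror the new left-padding of input ids.
            cache = F.pad(cache, (0, 0, pad, 0))
        kv_list.append(cache)
    # Stack into [batch, num_heads, max_kv_len, head_size].
    return torch.stack(kv_list, dim=0)

caches = [torch.randn(32, 4, 128), torch.randn(32, 6, 128)]
batched = batch_seq_caches(caches, max_kv_len=6)
assert batched.shape == (2, 32, 6, 128)
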
@@ -139,15 +154,15 @@ class BigDLModelForCausalLM(nn.Module):
        self,
        cur_seq_ids: List[int],
        kv_cache,
-        kv_cache_size_0: int,
+        layer: int,
        kv_cache_size_1: int,
    ) -> None:
-        for i in range(kv_cache_size_0):
+        for i in range(layer):
            for j in range(kv_cache_size_1):
-                index = 0
+                batch_dim = 0
                for seq_id in cur_seq_ids:
-                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
-                    index = index + 1
+                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
+                    batch_dim = batch_dim + 1
 
    def forward(
        self,
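
update_kv_cache walks the same [layer][kv] nesting in the opposite direction: row batch_dim of the batched cache produced during the forward pass is written back as the per-sequence entry for the matching seq_id. A minimal sketch with made-up sequence ids and assumed shapes, not part of the commit:

# Sketch of scattering the batched cache back per sequence (assumed shapes).
import torch

num_layers, decoder_kv_size = 2, 2
cur_seq_ids = [11, 42]                                   # made-up sequence ids
batched = [[torch.randn(len(cur_seq_ids), 32, 6, 128)    # [batch, heads, seq, head_size]
            for _ in range(decoder_kv_size)] for _ in range(num_layers)]
kv_cache = [[dict() for _ in range(decoder_kv_size)] for _ in range(num_layers)]

for i in range(num_layers):
    for j in range(decoder_kv_size):
        for batch_dim, seq_id in enumerate(cur_seq_ids):
            kv_cache[i][j][seq_id] = batched[i][j][batch_dim]

assert kv_cache[0][0][42].shape == (32, 6, 128)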