diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index 6268d96d..233c0216 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -26,6 +26,7 @@ from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausal
 from bigdl.llm.vllm.logger import init_logger
 import math
 import time
+from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -39,7 +40,16 @@ logger = init_logger(__name__)
 
 
 def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
-    return x + [padding_id] * (max_len - len(x))
+    return [padding_id] * (max_len - len(x)) + x
+
+
+def _get_attention_mask_for_prompts(
+    input_ids: List[List[int]], max_prompt_len: int
+) -> List[List[int]]:
+    attention_mask = [
+        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
+    ]
+    return attention_mask
 
 
 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
@@ -98,49 +108,53 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
-        kv_cache: Optional = None,
-        input_metadata: Optional = None,
+        # kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
+        kv_cache: Optional[List[List[Dict]]] = None,
+        input_metadata: Optional[InputMetadata] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        kv_cache_size_0 = self.model.config.num_hidden_layers
-        kv_cache_size_1 = 2
-        seq_len = len(seq_group_meta_data_lists)
+        num_layers = self.model.config.num_hidden_layers
+        # One for key, one for value
+        decoder_kv_size = 2
 
         bigdl_input_ids = []
         bigdl_position_ids = []
         bigdl_attention_mask = []
 
         cur_seq_ids = []
-        bigdl_sampling_params = {}
-        max_context_len = 0
-        all_decoding = True
+        max_prompt_len = 0
+
+        # 0. Verify is_prompt or is_decoding
+        is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt
+
+        # 1. Assemble bigdl_input_ids
         for seq_group_meta_data in seq_group_meta_data_lists:
-            req_id = seq_group_meta_data.request_id
-            all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
+            # req_id = seq_group_meta_data.request_id
+            # is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
            seq_ids = list(seq_group_meta_data.seq_data.keys())
             seq_id = seq_ids[0]
             cur_seq_ids.append(seq_id)
             seq_data = seq_group_meta_data.seq_data[seq_id]
 
             cur_seq_input_ids = seq_data.get_token_ids()
-            context_len = seq_data.get_len()
+            # context_len = seq_data.get_len()
             if seq_group_meta_data.is_prompt:
                 bigdl_input_ids.append(cur_seq_input_ids)
-                max_context_len = max(max_context_len, context_len)
+                max_prompt_len = max(max_prompt_len, seq_data.get_len())
             else:
                 bigdl_input_ids.append([cur_seq_input_ids[-1]])
+        # 1. Assemble bigdl_input_ids end
 
-        bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
-
-        if all_decoding:
+        if is_decoding_stage:
             bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, kv_cache_size_0, kv_cache_size_1)
+                                                   kv_cache, num_layers, decoder_kv_size)
         else:
+            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
             bigdl_input_ids = [
-                _pad_to_max(input_ids, max_context_len, self.pad_token_id)
+                _pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
                 for input_ids in bigdl_input_ids
             ]
 
-        if all_decoding:
+        if is_decoding_stage:
             cur_seq_len = bigdl_kv_cache[0][0].size(2)
             for seq_group_meta_data in seq_group_meta_data_lists:
                 seq_ids = list(seq_group_meta_data.seq_data.keys())
@@ -152,7 +166,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 bigdl_attention_mask.append(cur_attention_mask)
 
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
-        if all_decoding:
+
+        if is_decoding_stage:
             bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
             bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
             kwargs = {
@@ -166,6 +181,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         else:
             kwargs = {
                 "input_ids": bigdl_input_ids,
+                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
                 # "position_ids": bigdl_position_ids,
                 "past_key_values": None,
                 "use_cache": True,
@@ -190,7 +206,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 
         # logger.info(f"before: {tmp['allocated_bytes.all.current']}")
         self.update_kv_cache(cur_seq_ids,
-                             kv_cache, kv_cache_size_0, kv_cache_size_1)
+                             kv_cache, num_layers, decoder_kv_size)
         # tmp = torch.xpu.memory_stats()
         # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
 
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
index 88e94728..46e76432 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
@@ -90,45 +90,60 @@ class BigDLModelForCausalLM(nn.Module):
         cur_seq_ids: List[int],
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
         kv_cache: Dict,
-        kv_cache_size_0: int,
+        num_layers: int,
         kv_cache_size_1: int,
     ):
         max_seq_limit = self.max_seq_limit
         if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
             if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                 bigdl_kv_cache = self.last_kv_cache
+                # Immediately set it to None to decrease ref-count
+                self.last_kv_cache = None
             else:
                 bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2) - max_seq_limit,
                                               max_seq_limit) for tmp in tmp_list]
                                   for tmp_list in self.last_kv_cache]
                 del self.last_kv_cache
+                self.last_kv_cache = None
         else:
             del self.last_kv_cache
             bigdl_kv_cache = []
-            for i in range(kv_cache_size_0):
+            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
+                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = min(max_kv_len, max_seq_limit)
+            for layer in range(num_layers):
                 cur_list = []
-                for j in range(kv_cache_size_1):
-                    views = []
-                    max_len = 0
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        seq_data = seq_group_meta_data.seq_data[seq_id]
-                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
-                        views.append(kv_cache[i][j][seq_id].view(view_size))
-                        max_len = max(max_len, view_size[2])
+                for kv in range(kv_cache_size_1):
+                    kv_list = []
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
+                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
+                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
+                    for seq_id in cur_seq_ids:
+                        processed_kv_cache = kv_cache[layer][kv][seq_id]
+                        # Clean up the per-sequence entry to drop its reference
+                        kv_cache[layer][kv][seq_id] = None
+                        # Do padding
+                        if processed_kv_cache.size(dim=1) != max_kv_len:
+                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
                                                                     self.device, 1)
+                        kv_list.append(processed_kv_cache)
+                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
+                    kv_list.clear()
 
-                    views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
-                    cur_view = torch.cat(views, dim=0)
+                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
+                    # cur_view = torch.cat(kv_list, dim=0)
 
-                    if cur_view.size(2) > max_seq_limit:
-                        cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(cur_view)
+                    # if cur_view.size(2) > max_seq_limit:
+                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
+                    cur_list.append(current_layer_kv_cache)
 
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        del kv_cache[i][j][seq_id]
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     del kv_cache[layer][kv][seq_id]
                 bigdl_kv_cache.append(cur_list)
 
         return bigdl_kv_cache
@@ -139,15 +154,15 @@ class BigDLModelForCausalLM(nn.Module):
         self,
         cur_seq_ids: List[int],
         kv_cache,
-        kv_cache_size_0: int,
+        layer: int,
         kv_cache_size_1: int,
     ) -> None:
-        for i in range(kv_cache_size_0):
+        for i in range(layer):
             for j in range(kv_cache_size_1):
-                index = 0
+                batch_dim = 0
                 for seq_id in cur_seq_ids:
-                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
-                    index = index + 1
+                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
+                    batch_dim = batch_dim + 1
 
     def forward(
         self,
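
A quick standalone illustration of the two prompt-batching helpers added to bigdl_llama.py: prompts are now left-padded and paired with a mask that is 0 over the padding and 1 over the real tokens. The helper bodies are copied from the diff; the token ids and pad id below are made up for the example.

    from typing import List

    def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
        # Left-pad so the real tokens end up right-aligned.
        return [padding_id] * (max_len - len(x)) + x

    def _get_attention_mask_for_prompts(
        input_ids: List[List[int]], max_prompt_len: int
    ) -> List[List[int]]:
        # 0 over padding positions, 1 over real prompt tokens.
        return [
            [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
        ]

    prompts = [[5, 6, 7], [8, 9]]          # illustrative token ids
    max_prompt_len = max(len(p) for p in prompts)
    mask = _get_attention_mask_for_prompts(prompts, max_prompt_len)
    padded = [_pad_to_max(p, max_prompt_len, padding_id=0) for p in prompts]
    assert padded == [[5, 6, 7], [0, 8, 9]]
    assert mask == [[1, 1, 1], [0, 1, 1]]

Left padding keeps every sequence's newest token in the last position, which is the position the decoding stage keeps appending to.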
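
prepare_kv_cache now builds the batched cache by padding each sequence's cached tensor along its sequence dimension (dim 1) to a shared max_kv_len and stacking along a new batch dimension, instead of reshaping and concatenating views. The sketch below mirrors only that step: _pad_kv_cache_view is not shown in the diff, so pad_along_dim is a hypothetical zero-padding stand-in, and the tensor shapes are invented for the example.

    import torch

    def pad_along_dim(t: torch.Tensor, target_len: int, dim: int) -> torch.Tensor:
        # Hypothetical stand-in for _pad_kv_cache_view: zero-pad `dim` up to `target_len`.
        pad_len = target_len - t.size(dim)
        if pad_len <= 0:
            return t
        pad_shape = list(t.shape)
        pad_shape[dim] = pad_len
        return torch.cat([torch.zeros(pad_shape, dtype=t.dtype), t], dim=dim)

    # Per-sequence caches with different kv lengths: [num_heads, kv_len, head_dim].
    layer0_key_cache = {0: torch.randn(32, 5, 128), 1: torch.randn(32, 7, 128)}
    cur_seq_ids = [0, 1]

    max_kv_len = max(layer0_key_cache[s].size(dim=1) for s in cur_seq_ids)
    kv_list = [pad_along_dim(layer0_key_cache[s], max_kv_len, dim=1) for s in cur_seq_ids]
    batched = torch.stack(kv_list, dim=0)  # [batch, num_heads, max_kv_len, head_dim]
    assert batched.shape == (2, 32, 7, 128)

In the real code this runs per layer and per key/value, and max_kv_len is additionally capped by max_seq_limit.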
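
update_kv_cache goes the other way after the model call: for every layer and key/value entry it slices the batched tensor back into one cache entry per seq_id, using the batch row position as the index. A toy version with invented shapes, where the kv_cache layout [layer][kv][seq_id] follows the comment added to forward:

    import torch

    num_layers, kv_size = 2, 2   # illustrative; the model uses config.num_hidden_layers and 2
    last_kv_cache = [[torch.randn(2, 32, 7, 128) for _ in range(kv_size)]
                     for _ in range(num_layers)]
    kv_cache = [[dict() for _ in range(kv_size)] for _ in range(num_layers)]
    cur_seq_ids = [0, 1]

    for i in range(num_layers):
        for j in range(kv_size):
            batch_dim = 0
            for seq_id in cur_seq_ids:
                # Row `batch_dim` of the batched tensor holds sequence `seq_id`'s cache.
                kv_cache[i][j][seq_id] = last_kv_cache[i][j][batch_dim]
                batch_dim += 1

    assert kv_cache[0][0][1].shape == (32, 7, 128)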