diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index 233c0216..331e740e 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -35,7 +35,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-
 logger = init_logger(__name__)
 
 
@@ -161,18 +160,18 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             seq_id = seq_ids[0]
             seq_data = seq_group_meta_data.seq_data[seq_id]
             cur_pos = seq_data.get_len()
-            bigdl_position_ids.append([cur_pos - 1])
+            # bigdl_position_ids.append([cur_pos - 1])
             cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
             bigdl_attention_mask.append(cur_attention_mask)
 
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
         if is_decoding_stage:
-            bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
+            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
             bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
 
         kwargs = {
             "input_ids": bigdl_input_ids,
-            "position_ids": bigdl_position_ids,
+            # "position_ids": bigdl_position_ids,
             "attention_mask": bigdl_attention_mask,
             "past_key_values": bigdl_kv_cache,
             "use_cache": True,
@@ -199,6 +198,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         # self.last_kv_cache = outputs.past_key_values
         self._set_last_seq_ids(cur_seq_ids[:])
         self._set_last_kv_cache(outputs.past_key_values)
+        # pdb.set_trace()
         logits = outputs.logits[:, -1, :]
 
         bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
index 46e76432..6694a3f1 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
@@ -108,42 +108,31 @@ class BigDLModelForCausalLM(nn.Module):
         else:
             del self.last_kv_cache
             bigdl_kv_cache = []
-            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
-                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = max(
+                seq_group_meta_data.seq_data[next(iter(seq_group_meta_data.seq_data))].get_len()
+                for seq_group_meta_data in seq_group_meta_data_lists
+            )
             max_kv_len = min(max_kv_len, max_seq_limit)
-            for layer in range(num_layers):
+
+            for i in range(num_layers):
                 cur_list = []
-                for kv in range(kv_cache_size_1):
-                    kv_list = []
-                    # for seq_group_meta_data in seq_group_meta_data_lists:
-                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
-                    #     seq_id = seq_ids[0]
-                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
-                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
-                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
-                    for seq_id in cur_seq_ids:
-                        processed_kv_cache = kv_cache[layer][kv][seq_id]
-                        # Clean
-                        kv_cache[layer][kv][processed_kv_cache] = None
-                        if processed_kv_cache.size(dim=1) != max_kv_len:
-                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
-                                                                    self.device, 1)
-                        # Do padding
-                        kv_list.append(processed_kv_cache)
-                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
-                    kv_list.clear()
+                for j in range(kv_cache_size_1):
+                    views = []
+                    for seq_group_meta_data in seq_group_meta_data_lists:
+                        seq_ids = list(seq_group_meta_data.seq_data.keys())
+                        seq_id = seq_ids[0]
+                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
+                        views.append(kv_cache[i][j][seq_id].view(view_size))
-                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
-                    # cur_view = torch.cat(kv_list, dim=0)
+                    views = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in views]
+                    cur_view = torch.cat(views, dim=0)
+                    cur_list.append(cur_view)
-                    # if cur_view.size(2) > max_seq_limit:
-                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(current_layer_kv_cache)
+                    for seq_group_meta_data in seq_group_meta_data_lists:
+                        seq_ids = list(seq_group_meta_data.seq_data.keys())
+                        seq_id = seq_ids[0]
+                        del kv_cache[i][j][seq_id]
-                # for seq_group_meta_data in seq_group_meta_data_lists:
-                #     seq_ids = list(seq_group_meta_data.seq_data.keys())
-                #     seq_id = seq_ids[0]
-                #     del kv_cache[layer][kv][seq_id]
                 bigdl_kv_cache.append(cur_list)
 
         return bigdl_kv_cache
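
For context on the rewritten KV-cache preparation in bigdl_model.py: the new code batches per-sequence caches by adding a leading batch dimension with view, padding every sequence out to a common max_kv_len, and concatenating along dim 0; the trailing del loop then drops the per-sequence entries, which is safe because torch.cat has already copied their data into the batched tensor. Below is a minimal, self-contained sketch of that pad-and-concat pattern. It assumes per-sequence tensors of shape [num_heads, kv_len, head_dim] and left-padding (consistent with the attention mask built as [0] * pad + [1] * cur_pos in bigdl_llama.py); pad_left and per_seq are hypothetical stand-ins for illustration, and the real _pad_kv_cache_view helper's signature and padding side may differ.

import torch


def pad_left(view: torch.Tensor, max_kv_len: int) -> torch.Tensor:
    # view: [1, num_heads, kv_len, head_dim]. Left-pad the kv_len dim (dim 2)
    # with zeros so every sequence reaches max_kv_len. Hypothetical stand-in
    # for _pad_kv_cache_view; the real helper may pad differently.
    pad = max_kv_len - view.size(2)
    if pad <= 0:
        return view
    return torch.nn.functional.pad(view, (0, 0, pad, 0))


# Per-sequence KV tensors with different kv_len, standing in for the
# kv_cache[i][j][seq_id] entries (keys here are made-up seq_ids).
per_seq = {
    7: torch.randn(32, 5, 128),
    9: torch.randn(32, 3, 128),
}
max_kv_len = max(t.size(1) for t in per_seq.values())

# Add a leading batch dim of 1, pad to a common length, then concatenate,
# mirroring the views / _pad_kv_cache_view / torch.cat sequence in the diff.
views = [t.view([1] + list(t.shape)) for t in per_seq.values()]
views = [pad_left(v, max_kv_len) for v in views]
batched = torch.cat(views, dim=0)   # [batch, num_heads, max_kv_len, head_dim]
print(batched.shape)                # torch.Size([2, 32, 5, 128])

One consequence of this layout is that the batched cache no longer aliases the per-sequence tensors, which is why the diff can delete kv_cache[i][j][seq_id] immediately after the concatenation without corrupting the batch.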