diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index e3c8cfa2..45e8db50 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -225,6 +225,7 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
+    past_key_values1 = []
     tmp_matchness = 0
     e2e_tic = 0.0
 
@@ -271,7 +272,71 @@ def speculative_generate(self,
         else:
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
-            draft_past_key_values = past_key_values
+
+            # Init past_key_values1, the draft model's fp32 copy of the target KV cache
+            if self.device.type == 'cpu' and step == 1:
+                for i in range(len(past_key_values)):
+                    len0 = past_key_values[i][0].size(0)
+                    len1 = past_key_values[i][0].size(1)
+                    len2 = past_key_values[i][0].size(2)
+                    len3 = past_key_values[i][0].size(3)
+                    if self.config.model_type == "qwen":
+                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.transpose(1, 2)
+                        v0 = v0.transpose(1, 2)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.permute(2, 0, 1, 3)
+                        v0 = v0.permute(2, 0, 1, 3)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    else:
+                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                            torch.float32)
+
+            # Each iter, slice the first cur_len positions out of past_key_values1
+            if self.device.type == 'cpu':
+                tmp_past_key_values = []
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        len1 = past_key_values[0][0].size(1)
+                        k0 = past_key_values1[i][0][:, :len1, :, :]
+                        v0 = past_key_values1[i][1][:, :len1, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    elif self.config.model_type == "chatglm":
+                        len0 = past_key_values[0][0].size(0)
+                        k0 = past_key_values1[i][0][:len0, :, :, :]
+                        v0 = past_key_values1[i][1][:len0, :, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    else:
+                        len2 = past_key_values[0][0].size(2)
+                        k0 = past_key_values1[i][0][:, :, :len2, :]
+                        v0 = past_key_values1[i][1][:, :, :len2, :]
+                        tmp_past_key_values.append((k0, v0))
+                draft_past_key_values = tmp_past_key_values
+            else:
+                draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
             tic = time.time()
             # Draft model auto-regressively generate k tokens
@@ -392,6 +457,31 @@ def speculative_generate(self,
                          v[:, :, :-(max_of_max_matched - max_matched)])
                        for k, v in past_key_values
                    ]
+            # Each iter, copy the newly matched KV entries back into past_key_values1
+            if self.device.type == 'cpu':
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        size = tmp_past_key_values[i][0].size(1)
+                        size1 = past_key_values[i][0].size(1)
+                        past_key_values1[i][0][:, size:size1, :, :] = \
+                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                        past_key_values1[i][1][:, size:size1, :, :] = \
+                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        size = tmp_past_key_values[i][0].size(0)
+                        size1 = past_key_values[i][0].size(0)
+                        past_key_values1[i][0][size:size1, :, :, :] = \
+                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                        past_key_values1[i][1][size:size1, :, :, :] = \
+                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+                    else:
+                        size = tmp_past_key_values[i][0].size(2)
+                        size1 = past_key_values[i][0].size(2)
+                        past_key_values1[i][0][:, :, size:size1, :] = \
+                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                        past_key_values1[i][1][:, :, size:size1, :] = \
+                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
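
Taken together, the three hunks keep past_key_values1 as a float32 mirror of the target model's KV cache on CPU: the buffers are preallocated once with room for max_new_tokens extra positions, each draft pass reads a cheap slice view of the mirror, and after verification only the newly accepted positions are copied back in fp32. Below is a minimal standalone sketch of that bookkeeping, covering only the default (batch, num_heads, seq_len, head_dim) cache layout (the patch additionally special-cases Qwen's and ChatGLM's layouts). The class and method names (Fp32KVMirror, views, append) are illustrative, not from the patch, and torch.empty stands in for the patch's torch.ones since every slot is written before it is read.

import torch


class Fp32KVMirror:
    """Preallocated fp32 mirror of a target model's KV cache (illustrative)."""

    def __init__(self, past_key_values, max_new_tokens):
        # One (key, value) fp32 buffer pair per layer, sized once so that
        # growing the cache later is a slice view plus a small copy.
        self.storage = []
        for k, v in past_key_values:
            b, h, s, d = k.shape
            k0 = torch.empty(b, h, s + max_new_tokens, d, dtype=torch.float32)
            v0 = torch.empty(b, h, s + max_new_tokens, d, dtype=torch.float32)
            k0[:, :, :s, :] = k.to(torch.float32)
            v0[:, :, :s, :] = v.to(torch.float32)
            self.storage.append((k0, v0))

    def views(self, cur_len):
        # The draft model's cache for this iteration: views, no allocation.
        return [(k0[:, :, :cur_len, :], v0[:, :, :cur_len, :])
                for k0, v0 in self.storage]

    def append(self, past_key_values, old_len):
        # After verification, copy only the newly accepted positions back.
        for (k0, v0), (k, v) in zip(self.storage, past_key_values):
            new_len = k.size(2)
            k0[:, :, old_len:new_len, :] = k[:, :, old_len:new_len, :].to(torch.float32)
            v0[:, :, old_len:new_len, :] = v[:, :, old_len:new_len, :].to(torch.float32)


if __name__ == "__main__":
    # Toy cache: 2 layers, batch 1, 4 heads, 5 prompt tokens, head_dim 8.
    pkv = [(torch.randn(1, 4, 5, 8, dtype=torch.bfloat16),
            torch.randn(1, 4, 5, 8, dtype=torch.bfloat16)) for _ in range(2)]
    mirror = Fp32KVMirror(pkv, max_new_tokens=16)
    draft_pkv = mirror.views(cur_len=5)
    assert draft_pkv[0][0].dtype == torch.float32
    # Pretend the target model verified 3 more tokens (cache grew to 8)...
    pkv = [(torch.randn(1, 4, 8, 8, dtype=torch.bfloat16),
            torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)) for _ in range(2)]
    mirror.append(pkv, old_len=5)
    assert mirror.views(cur_len=8)[0][0].size(2) == 8

The preallocation is what makes the patch's per-iteration tmp_past_key_values slicing cheap: no reallocation or concatenation happens per step, and the dtype conversion is paid only once per newly accepted token rather than over the whole cache.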