diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index 45e8db50..ec02646a 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -57,7 +57,8 @@ def generate(
         new_speculative_kwargs = {}
         for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                     'top_k', 'top_p', 'temperature', 'hf_adjust',
-                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty']:
+                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
+                    'attention_mask']:
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
@@ -115,6 +116,133 @@ def clear_benchmarks(self):
     self.n_matched = 0
 
 
+def _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                         max_new_tokens, _enable_ipex=False):
+    past_key_values_storage = []
+    if _enable_ipex:
+        ipex_past_key_values = []
+        cur_len = past_key_values[0][0].size(1)
+        ipex_past_key_values = [
+            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
+             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
+            for pkv in past_key_values
+        ]
+
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            len0 = past_key_values[i][0].size(0)
+            len1 = past_key_values[i][0].size(1)
+            len2 = past_key_values[i][0].size(2)
+            len3 = past_key_values[i][0].size(3)
+        else:
+            len0 = past_key_values[i][1].size(1)
+            len1 = past_key_values[i][1].size(2)
+            len2 = past_key_values[i][0].size(2)  # seq length
+            len3 = past_key_values[i][1].size(3)
+        if self.config.model_type == "qwen":
+            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.transpose(1, 2)
+            v0 = v0.transpose(1, 2)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        elif self.config.model_type == "chatglm":
+            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.permute(2, 0, 1, 3)
+            v0 = v0.permute(2, 0, 1, 3)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        else:
+            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            past_key_values_storage.append((k0, v0))
+            if not _enable_ipex:
+                past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                    torch.float32)
+            else:
+                past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
+                    torch.float32)
+
+    return past_key_values_storage
+
+
+def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
+    tmp_past_key_values = []
+    for i in range(len(past_key_values)):
+        if self.config.model_type == "qwen":
+            len1 = past_key_values[0][0].size(1)
+            k0 = past_key_values_storage[i][0][:, :len1, :, :]
+            v0 = past_key_values_storage[i][1][:, :len1, :, :]
+            tmp_past_key_values.append((k0, v0))
+        elif self.config.model_type == "chatglm":
+            len0 = past_key_values[0][0].size(0)
+            k0 = past_key_values_storage[i][0][:len0, :, :, :]
+            v0 = past_key_values_storage[i][1][:len0, :, :, :]
+            tmp_past_key_values.append((k0, v0))
+        else:
+            len2 = past_key_values[0][0].size(2)
+            k0 = past_key_values_storage[i][0][:, :, :len2, :]
+            v0 = past_key_values_storage[i][1][:, :, :len2, :]
+            tmp_past_key_values.append((k0, v0))
+    return tmp_past_key_values
+
+
+def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                        original_draft_past_key_values, _enable_ipex=False):
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            if self.config.model_type == "qwen":
+                size = original_draft_past_key_values[i][0].size(1)
+                size1 = past_key_values[i][0].size(1)
+                past_key_values_storage[i][0][:, size:size1, :, :] = \
+                    past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                past_key_values_storage[i][1][:, size:size1, :, :] = \
+                    past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+            elif self.config.model_type == "chatglm":
+                size = original_draft_past_key_values[i][0].size(0)
+                size1 = past_key_values[i][0].size(0)
+                past_key_values_storage[i][0][size:size1, :, :, :] = \
+                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                past_key_values_storage[i][1][size:size1, :, :, :] = \
+                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+            else:
+                size = original_draft_past_key_values[i][0].size(2)
+                size1 = past_key_values[i][0].size(2)
+                past_key_values_storage[i][0][:, :, size:size1, :] = \
+                    past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                past_key_values_storage[i][1][:, :, size:size1, :] = \
+                    past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+        else:
+            size = original_draft_past_key_values[i][0].size(2)
+            size1 = past_key_values[i][0].size(1)
+            delta_past_key = \
+                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
+            delta_past_value = \
+                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
+
+            past_key_values_storage[i][0][:, :, size:size1, :] = \
+                delta_past_key.to(torch.float32)
+            past_key_values_storage[i][1][:, :, size:size1, :] = \
+                delta_past_value.to(torch.float32)
+
+
 @torch.no_grad()
 def speculative_generate(self,
                          inputs: Optional[torch.Tensor] = None,
@@ -126,6 +254,7 @@ def speculative_generate(self,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                          hf_adjust=False,
                          generation_config: Optional[GenerationConfig] = None,
+                         attention_mask=None,
                          **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
@@ -225,7 +354,15 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
-    past_key_values1 = []
+    past_key_values_storage = []
+
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    if _enable_ipex:
+        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
+                ('llama' in self.config.model_type)):
+            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
+                              Llama and Baichuan2-13b models currently.")
 
     tmp_matchness = 0
     e2e_tic = 0.0
@@ -252,6 +389,7 @@ def speculative_generate(self,
                 tic = time.time()
                 output = self(input_ids=current_input_ids,
                               past_key_values=past_key_values,
+                              attention_mask=attention_mask,
                               return_dict=True,
                               use_cache=True)
                 logits = output['logits']
@@ -273,68 +411,17 @@ def speculative_generate(self,
                 draft_current_input_ids = current_input_ids
 
                 # Target model KV cache to draft model
-                # init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
-                if self.device.type == 'cpu' and step == 1:
-                    for i in range(len(past_key_values)):
-                        len0 = past_key_values[i][0].size(0)
-                        len1 = past_key_values[i][0].size(1)
-                        len2 = past_key_values[i][0].size(2)
-                        len3 = past_key_values[i][0].size(3)
-                        if self.config.model_type == "qwen":
-                            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            k0 = k0.transpose(1, 2)
-                            v0 = v0.transpose(1, 2)
-                            past_key_values1.append((k0, v0))
-                            past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
-                                torch.float32)
-                            past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
-                                torch.float32)
-                        elif self.config.model_type == "chatglm":
-                            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            k0 = k0.permute(2, 0, 1, 3)
-                            v0 = v0.permute(2, 0, 1, 3)
-                            past_key_values1.append((k0, v0))
-                            past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
-                                torch.float32)
-                            past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
-                                torch.float32)
-                        else:
-                            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                            dtype=torch.float32)
-                            past_key_values1.append((k0, v0))
-                            past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
-                                torch.float32)
-                            past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
-                                torch.float32)
-
-                # each iter cut off cur_len kv_cache from past_key_values1
                 if self.device.type == 'cpu':
-                    tmp_past_key_values = []
-                    for i in range(len(past_key_values)):
-                        if self.config.model_type == "qwen":
-                            len1 = past_key_values[0][0].size(1)
-                            k0 = past_key_values1[i][0][:, :len1, :, :]
-                            v0 = past_key_values1[i][1][:, :len1, :, :]
-                            tmp_past_key_values.append((k0, v0))
-                        elif self.config.model_type == "chatglm":
-                            len0 = past_key_values[0][0].size(0)
-                            k0 = past_key_values1[i][0][:len0, :, :, :]
-                            v0 = past_key_values1[i][1][:len0, :, :, :]
-                            tmp_past_key_values.append((k0, v0))
-                        else:
-                            len2 = past_key_values[0][0].size(2)
-                            k0 = past_key_values1[i][0][:, :, :len2, :]
-                            v0 = past_key_values1[i][1][:, :, :len2, :]
-                            tmp_past_key_values.append((k0, v0))
-                    draft_past_key_values = tmp_past_key_values
+                    # init past_key_values_storage and assign initial fp32 value
+                    if step == 1:
+                        past_key_values_storage = \
+                            _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                                                 max_new_tokens, _enable_ipex)
+                    # each iter cut off cur_len kv_cache from past_key_values_storage
+                    draft_past_key_values = \
+                        _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                                           past_key_values_storage)
+                    original_draft_past_key_values = draft_past_key_values
                 else:
                     draft_past_key_values = past_key_values
                 draft_generate_ids[:, 0] = current_input_ids
@@ -342,17 +429,25 @@ def speculative_generate(self,
                 # Draft model auto-regressively generate k tokens
                 # Early stop when prob less then th_stop_draft
                 for step_draft in range(max_step_draft):
+                    if attention_mask is None:
+                        draft_attention_mask = None
+                    else:
+                        appended_len = step_draft + step
+                        ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                        draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
                     if self.config.model_type == "chatglm":
                         past_key_value_len = past_key_values[0][0].shape[0]
                         position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                         draft_output = draft_model(input_ids=draft_current_input_ids,
                                                    past_key_values=draft_past_key_values,
+                                                   attention_mask=draft_attention_mask,
                                                    return_dict=True,
                                                    use_cache=True,
                                                    position_ids=position_ids)
                     else:
                         draft_output = draft_model(input_ids=draft_current_input_ids,
                                                    past_key_values=draft_past_key_values,
+                                                   attention_mask=draft_attention_mask,
                                                    return_dict=True,
                                                    use_cache=True)
                     temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
@@ -388,27 +483,56 @@ def speculative_generate(self,
                 # input.size is k + 1, 1 previous token + k drafts
                 # verified output.size is k + 1, k token + 1 final
                 # Final token is always accepted
-                if self.config.model_type == "chatglm":
-                    past_key_value_len = past_key_values[0][0].shape[0]
-                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
-                                                device=drafted_input_ids.device)
-                    position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
-                    output = self(input_ids=drafted_input_ids,
-                                  past_key_values=past_key_values,
-                                  return_dict=True,
-                                  use_cache=True,
-                                  position_ids=position_ids)
+                if attention_mask is None:
+                    cur_attention_mask = None
                 else:
-                    output = self(input_ids=drafted_input_ids,
-                                  past_key_values=past_key_values,
-                                  return_dict=True,
-                                  use_cache=True)
-                logits = output['logits']
+                    appended_len = drafted_input_ids.size(1) + step - 1
+                    ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                    cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
+                if _enable_ipex and hasattr(self, "trace_graph"):
+                    if self.config.model_type == "baichuan":
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  past_key_values=past_key_values,
+                                                  )
+                    elif "llama" in self.config.model_type:
+                        past_key_value_len = past_key_values[0][0].shape[2]
+                        position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                    device=drafted_input_ids.device).unsqueeze(0)
+                        position_ids = position_ids.repeat(1, 1) + past_key_value_len
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  position_ids=position_ids,
+                                                  past_key_values=past_key_values,
+                                                  )
+                    logits = output[0]
+                    past_key_values = output[1]
+                else:
+                    if self.config.model_type == "chatglm":
+                        past_key_value_len = past_key_values[0][0].shape[0]
+                        position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                    device=drafted_input_ids.device)
+                        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
+                        output = self(input_ids=drafted_input_ids,
+                                      past_key_values=past_key_values,
+                                      attention_mask=cur_attention_mask,
+                                      return_dict=True,
+                                      use_cache=True,
+                                      position_ids=position_ids)
+                    else:
+                        output = self(input_ids=drafted_input_ids,
+                                      past_key_values=past_key_values,
+                                      attention_mask=cur_attention_mask,
+                                      return_dict=True,
+                                      use_cache=True)
+                    if isinstance(output, dict):
+                        logits = output['logits']
+                        past_key_values = output['past_key_values']
                 temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                             draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
                 for i in range(logits.size(1)):
                     logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                       output['logits'][:, i, :])
+                                                       logits[:, i, :])
                 output_ids = sample(logits, do_sample=generation_config.do_sample,
                                     top_k=generation_config.top_k, top_p=generation_config.top_p,
                                     temperature=generation_config.temperature)
@@ -418,7 +542,6 @@ def speculative_generate(self,
                 self.verify_time.append(toc - tic)
                 self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
 
-                past_key_values = output['past_key_values']
                 # Compare drafts with target verified outputs
                 # Drafts start from [1, k]
                 # Verified output start from [0, k - 1]
@@ -432,55 +555,46 @@ def speculative_generate(self,
                 if max_of_max_matched != max_matched:
                     output_ids = output_ids[:, :max_matched]
                     # For Qwen
-                    if self.config.model_type == "qwen":
-                        past_key_values = [
-                            (k[:, :-(max_of_max_matched - max_matched), :],
-                             v[:, :-(max_of_max_matched - max_matched), :])
-                            for k, v in past_key_values
-                        ]
-                    elif self.config.model_type == "chatglm":
-                        # for chatglm, cache shape is [sl, bs, nh, hn]
-                        past_key_values = [
-                            (k[:-(max_of_max_matched - max_matched), :, :, :],
-                             v[:-(max_of_max_matched - max_matched), :, :, :])
-                            for k, v in past_key_values
-                        ]
-                    elif self.config.model_type == "baichuan":
-                        past_key_values = [
-                            (k[:, :, :-(max_of_max_matched - max_matched), :],
-                             v[:, :, :-(max_of_max_matched - max_matched), :])
-                            for k, v in past_key_values
-                        ]
+                    if _enable_ipex:
+                        cur_len = past_key_values[0][0].size(1)
+                        delta = max_of_max_matched - max_matched
+                        tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
+                                          dtype=torch.long,
+                                          ).contiguous()
+                        past_key_values = [[tmp, key_cache, value_cache, beam_idx]
+                                           for _, key_cache, value_cache, beam_idx in past_key_values]
                     else:
-                        past_key_values = [
-                            (k[:, :, :-(max_of_max_matched - max_matched)],
-                             v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
-                        ]
+                        if self.config.model_type == "qwen":
+                            past_key_values = [
+                                (k[:, :-(max_of_max_matched - max_matched), :],
+                                 v[:, :-(max_of_max_matched - max_matched), :])
+                                for k, v in past_key_values
+                            ]
+                        elif self.config.model_type == "chatglm":
+                            # for chatglm, cache shape is [sl, bs, nh, hn]
+                            past_key_values = [
+                                (k[:-(max_of_max_matched - max_matched), :, :, :],
+                                 v[:-(max_of_max_matched - max_matched), :, :, :])
+                                for k, v in past_key_values
+                            ]
+                        elif self.config.model_type == "baichuan":
+                            past_key_values = [
+                                (k[:, :, :-(max_of_max_matched - max_matched), :],
+                                 v[:, :, :-(max_of_max_matched - max_matched), :])
+                                for k, v in past_key_values
+                            ]
+                        else:
+                            past_key_values = [
+                                (k[:, :, :-(max_of_max_matched - max_matched)],
+                                 v[:, :, :-(max_of_max_matched - max_matched)])
+                                for k, v in past_key_values
+                            ]
 
                 # Each iter assign new_matched kv_cache to past_key_values1
                 if self.device.type == 'cpu':
-                    for i in range(len(past_key_values)):
-                        if self.config.model_type == "qwen":
-                            size = tmp_past_key_values[i][0].size(1)
-                            size1 = past_key_values[i][0].size(1)
-                            past_key_values1[i][0][:, size:size1, :, :] = \
-                                past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
-                            past_key_values1[i][1][:, size:size1, :, :] = \
-                                past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
-                        elif self.config.model_type == "chatglm":
-                            size = tmp_past_key_values[i][0].size(0)
-                            size1 = past_key_values[i][0].size(0)
-                            past_key_values1[i][0][size:size1, :, :, :] = \
-                                past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
-                            past_key_values1[i][1][size:size1, :, :, :] = \
-                                past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
-                        else:
-                            size = tmp_past_key_values[i][0].size(2)
-                            size1 = past_key_values[i][0].size(2)
-                            past_key_values1[i][0][:, :, size:size1, :] = \
-                                past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
-                            past_key_values1[i][1][:, :, size:size1, :] = \
-                                past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+                    _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                                        original_draft_past_key_values,
+                                                        _enable_ipex)
 
                 generate_ids[:, step:step+output_ids.size(1)] = output_ids
                 current_input_ids = output_ids[:, -1:]
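Usage sketch (not part of the patch): the only behavior this diff relies on is that `generate` now whitelists `attention_mask` among the speculative kwargs and forwards it into `speculative_generate`, so the tokenizer's mask is carried through both the draft and verify forwards. The `model_path` and `prompts` names below are placeholders, and the snippet assumes `model` has already been set up for BigDL-LLM speculative decoding (i.e. its `generate` routes into `speculative_generate`); it is a minimal illustration, not the library's documented example.

```python
import torch
from transformers import AutoTokenizer

# model_path / prompts are hypothetical placeholders.
tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.inference_mode():
    output_ids = model.generate(inputs.input_ids,
                                attention_mask=inputs.attention_mask,  # newly accepted kwarg
                                max_new_tokens=128,   # whitelisted speculative kwarg
                                th_stop_draft=0.6,    # whitelisted speculative kwarg
                                do_sample=False)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```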