From 36c665667d9c27e0a5f714b83a23fa0dd9d70af6 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Tue, 23 Jan 2024 15:57:28 +0800
Subject: [PATCH] Add logits processor & qwen eos stop in speculative decoding
 (#9963)

* add logits processor & qwen eos

* fix style

* fix

* fix

* fix style

* fix style

* support transformers 4.31

* fix style

* fix style

---------

Co-authored-by: rnwang04
---
 .../src/bigdl/llm/transformers/speculative.py | 175 ++++++++++++++----
 1 file changed, 140 insertions(+), 35 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index 61117682..53f2bdf7 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
+# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /utils.py
 #
 
 import torch
@@ -33,6 +35,8 @@
 from bigdl.llm.utils.common import invalidInputError
 from transformers import GenerationMixin
 original_generate = GenerationMixin.generate
 
+logger = logging.getLogger("bigdl.llm.speculative")
+
 @torch.no_grad()
 def generate(
@@ -57,7 +61,7 @@
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
-        return self.speculative_generate(input_ids=inputs,
+        return self.speculative_generate(inputs=inputs,
                                          draft_model=self.draft_model,
                                          **new_speculative_kwargs)
     else:
@@ -113,20 +117,123 @@ def clear_benchmarks(self):
 
 @torch.no_grad()
 def speculative_generate(self,
-                         input_ids: Optional[torch.Tensor] = None,
+                         inputs: Optional[torch.Tensor] = None,
                          draft_model=None,
                          max_new_tokens=10,
                          max_step_draft=8,
                          th_stop_draft=0.8,
                          auto_th_stop_draft=True,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
-                         do_sample=False,
-                         top_k=0,
-                         top_p=0.85,
-                         temperature=0.2,
-                         hf_adjust=False):
+                         hf_adjust=False,
+                         generation_config: Optional[GenerationConfig] = None,
+                         **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
+
+    if generation_config is None:
+        # legacy: users may modify the model configuration to control generation.
+        # To trigger this legacy behavior, two conditions must be met
+        # 1) the generation config must have been created from the
+        #    model config (`_from_model_config` field);
+        # 2) the generation config must have seen no modification
+        #    since its creation (the hash is the same).
+        if self.generation_config._from_model_config \
+                and self.generation_config._original_object_hash == hash(
+                    self.generation_config):
+            new_generation_config = GenerationConfig.from_model_config(self.config)
+            if new_generation_config != self.generation_config:
+                warnings.warn(
+                    "You have modified the pretrained model configuration to control "
+                    "generation. This is a deprecated strategy to control generation "
+                    "and will be removed soon, in a future version. Please use and "
+                    "modify the model generation configuration (see"
+                    " https://huggingface.co/docs/transformers/generation_strategies"
+                    "#default-text-generation-configuration )"
+                )
+                self.generation_config = new_generation_config
+        generation_config = self.generation_config
+
+    generation_config = copy.deepcopy(generation_config)
+    # All unused kwargs must be model kwargs
+    model_kwargs = generation_config.update(**sampling_kwargs)
+    generation_config.validate()
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if model_kwargs.get("attention_mask", None) is None:
+            logger.warning(
+                "The attention mask and the pad token id were not set. As a consequence, "
+                "you may observe unexpected behavior. Please pass your input's "
+                "`attention_mask` to obtain reliable results."
+            )
+        eos_token_id = generation_config.eos_token_id
+        if isinstance(eos_token_id, list):
+            eos_token_id = eos_token_id[0]
+        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:"
+                       f"{eos_token_id} for open-end generation.")
+        generation_config.pad_token_id = eos_token_id
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = LogitsProcessorList()
+    stopping_criteria = StoppingCriteriaList()
+
+    # 3. Define model inputs
+    # inputs_tensor has to be defined
+    # model_input_name is defined if model-specific keyword input is passed
+    # otherwise model_input_name is None
+    # all model-specific keyword inputs are removed from `model_kwargs`
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    # 4. Define other model kwargs
+    # Removed unused code here
+
+    # decoder-only models should use left-padding for generation
+    if not self.config.is_encoder_decoder:
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: if using `inputs_embeds`, this check does not work,
+        # because we want to be more hands-off.
+        if (
+            generation_config.pad_token_id is not None
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding "
+                "was detected! For correct generation results, please set "
+                "`padding_side='left'` when initializing the tokenizer."
+            )
+    else:
+        invalidInputError(False, "encoder-decoder models are not supported now.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    # if streamer is not None:
+    #     streamer.put(input_ids.cpu())
+
+    input_ids_length = input_ids.shape[-1]
+
+    # Here we use sample generation mode
+    # 8. prepare distribution pre_processing samplers
+    logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=None,
+        logits_processor=logits_processor,
+    )
+
+    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+    input_ids, model_kwargs = self._expand_inputs_for_generation(
+        input_ids=input_ids,
+        expand_size=generation_config.num_return_sequences,
+        is_encoder_decoder=self.config.is_encoder_decoder,
+        **model_kwargs,
+    )
+
     step = 0
     step_draft = 0
     step_verify = 0
@@ -144,10 +251,6 @@ def speculative_generate(self,
 
     self.clear_benchmarks()
 
-    if self.config.model_type == "qwen":
-        from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
-        logit_processor = RepetitionPenaltyLogitsProcessor(
-            penalty=self.generation_config.repetition_penalty)
     # Example:
     # Target model forward for the first token
     # Step 1. target_model(prompt) -> a
@@ -172,11 +275,10 @@
                           use_cache=True)
             logits = output['logits']
             logits = logits[:, -1:]
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step]), dim=-1)
-                logits[:, -1, :] = logit_processor(temp_input_ids, logits[:, -1, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             generate_ids[:, step] = output_ids
             current_input_ids = output_ids
             past_key_values = output['past_key_values']
@@ -208,15 +310,18 @@
                                            past_key_values=draft_past_key_values,
                                            return_dict=True,
                                            use_cache=True)
-                if self.config.model_type == "qwen":
-                    temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                                draft_generate_ids[:, 1:step_draft+1]), dim=-1)
-                    draft_output['logits'][:, -1, :] = logit_processor(
-                        temp_input_ids,
-                        draft_output['logits'][:, -1, :])
+                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                            draft_generate_ids[:, 1:step_draft+1]), dim=-1)
+                logits = draft_output['logits']
+                logits[:, -1, :] = logits_processor(temp_input_ids,
+                                                    draft_output['logits'][:, -1, :])
                 draft_output_ids, draft_output_probs = sample(
-                    draft_output['logits'], return_probs=True, do_sample=do_sample,
-                    top_k=top_k, top_p=top_p, temperature=temperature)
+                    logits,
+                    return_probs=True,
+                    do_sample=generation_config.do_sample,
+                    top_k=generation_config.top_k,
+                    top_p=generation_config.top_p,
+                    temperature=generation_config.temperature)
                 draft_generate_ids[:, step_draft+1] = draft_output_ids
                 draft_current_input_ids = draft_output_ids
                 draft_past_key_values = draft_output['past_key_values']
@@ -254,14 +359,14 @@
                           return_dict=True,
                           use_cache=True)
             logits = output['logits']
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                            draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
-                for i in range(logits.size(1)):
-                    logits[:, i, :] = logit_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                      output['logits'][:, i, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                        draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
+            for i in range(logits.size(1)):
+                logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
+                                                   output['logits'][:, i, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
             toc = time.time()
@@ -338,8 +443,8 @@
 
         # Stop on eos and remove content after eos
         output_ids_list = output_ids[0].tolist()
-        if self.config.eos_token_id in output_ids_list:
-            idx = output_ids_list.index(self.config.eos_token_id)
+        if generation_config.eos_token_id in output_ids_list:
+            idx = output_ids_list.index(generation_config.eos_token_id)
             step -= (len(output_ids_list) - idx - 1)
             break
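
Reviewer note (not part of the patch): this change replaces the hard-coded do_sample/top_k/top_p/temperature arguments, and the Qwen-only RepetitionPenaltyLogitsProcessor special case, with a standard transformers GenerationConfig plus the LogitsProcessorList assembled by self._get_logits_processor(...), applied uniformly to the prompt, draft, and verify logits. The sketch below is a minimal illustration of that flow using only the public transformers API; the tensor shapes and all numeric values are made up for the example.

    import torch
    from transformers import GenerationConfig
    from transformers.generation.logits_process import (
        LogitsProcessorList,
        RepetitionPenaltyLogitsProcessor,
    )

    # What callers now pass through **sampling_kwargs (or as a ready-made config):
    gen_config = GenerationConfig(do_sample=True, top_k=50, top_p=0.9,
                                  temperature=0.7, repetition_penalty=1.1)

    # Stand-in for the list _get_logits_processor() builds from gen_config;
    # before this patch an equivalent processor was hard-wired for Qwen only.
    processors = LogitsProcessorList(
        [RepetitionPenaltyLogitsProcessor(penalty=gen_config.repetition_penalty)]
    )

    input_ids = torch.tensor([[11, 42, 42, 7]])  # tokens generated so far (batch of 1)
    logits = torch.randn(1, 32000)               # next-token logits; arbitrary vocab size
    logits = processors(input_ids, logits)       # penalize already-generated tokens
    next_token = torch.argmax(logits, dim=-1)    # greedy pick, for illustration only

The eos change in the last hunk follows the same direction: stopping now checks generation_config.eos_token_id instead of self.config.eos_token_id, which matters for models such as Qwen that publish their end-of-text id in generation_config.json rather than in the model config.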