Add logits processor & qwen eos stop in speculative decoding (#9963)
* add logits processor & qwen eos
* fix style
* fix
* fix
* fix style
* fix style
* support transformers 4.31
* fix style
* fix style

---------

Co-authored-by: rnwang04 <ruonan1.wang@intel.com>
This commit is contained in:

parent 60b35db1f1
commit 36c665667d

1 changed file with 140 additions and 35 deletions
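Review note (not part of the diff): the rework below drops the explicit `do_sample`/`top_k`/`top_p`/`temperature` arguments of `speculative_generate` in favour of a `GenerationConfig` plus free-form `**sampling_kwargs`, folded in via `generation_config.update(**sampling_kwargs)`. A minimal sketch of that folding pattern, assuming only the Hugging Face `GenerationConfig` API (the helper name below is hypothetical, not code from this patch):

# Sketch only: mirrors how the patch routes sampling kwargs into a GenerationConfig.
import copy
from transformers import GenerationConfig

def fold_sampling_kwargs(generation_config=None, **sampling_kwargs):
    # The patch falls back to self.generation_config; here we fall back to defaults.
    if generation_config is None:
        generation_config = GenerationConfig()
    generation_config = copy.deepcopy(generation_config)
    # update() consumes recognised generation options and returns the rest,
    # which the patch then treats as model kwargs.
    model_kwargs = generation_config.update(**sampling_kwargs)
    return generation_config, model_kwargs

cfg, extra = fold_sampling_kwargs(do_sample=True, top_k=50, top_p=0.9,
                                  temperature=0.7, attention_mask=None)
print(cfg.do_sample, cfg.top_k, cfg.top_p, cfg.temperature)  # True 50 0.9 0.7
print(extra)  # {'attention_mask': None} -- left over as a model kwarg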
				
			
@@ -14,7 +14,9 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
+# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /utils.py
 #
 
 import torch
@@ -33,6 +35,8 @@ from bigdl.llm.utils.common import invalidInputError
 from transformers import GenerationMixin
 original_generate = GenerationMixin.generate
 
+logger = logging.getLogger("bigdl.llm.speculative")
+
 
 @torch.no_grad()
 def generate(
@@ -57,7 +61,7 @@ def generate(
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
-        return self.speculative_generate(input_ids=inputs,
+        return self.speculative_generate(inputs=inputs,
                                          draft_model=self.draft_model,
                                          **new_speculative_kwargs)
     else:
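Review note (not part of the diff): the call-site fix above (`input_ids=inputs` → `inputs=inputs`) matches the renamed parameter of `speculative_generate` further down. The surrounding wrapper pops recognised speculative-decoding options out of `**kwargs` before forwarding; a minimal sketch of that routing pattern (the variable list here is illustrative, not the exact one in the source):

# Sketch only: pop recognised speculative-decoding options into their own dict,
# as the generate() wrapper does before calling speculative_generate().
def split_speculative_kwargs(kwargs, speculative_vars):
    new_speculative_kwargs = {}
    for var in speculative_vars:
        value = kwargs.pop(var, None)
        if value is not None:
            new_speculative_kwargs[var] = value
    return new_speculative_kwargs

kwargs = {"max_new_tokens": 32, "th_stop_draft": 0.8, "do_sample": True}
spec = split_speculative_kwargs(kwargs, ("max_new_tokens", "th_stop_draft", "max_step_draft"))
# spec == {"max_new_tokens": 32, "th_stop_draft": 0.8}; "do_sample" stays in kwargs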
@@ -113,20 +117,123 @@ def clear_benchmarks(self):
 
 @torch.no_grad()
 def speculative_generate(self,
-                         input_ids: Optional[torch.Tensor] = None,
+                         inputs: Optional[torch.Tensor] = None,
                          draft_model=None,
                          max_new_tokens=10,
                          max_step_draft=8,
                          th_stop_draft=0.8,
                          auto_th_stop_draft=True,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
-                         do_sample=False,
-                         top_k=0,
-                         top_p=0.85,
-                         temperature=0.2,
-                         hf_adjust=False):
+                         hf_adjust=False,
+                         generation_config: Optional[GenerationConfig] = None,
+                         **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
 
+    if generation_config is None:
+        # legacy: users may modify the model configuration to control generation.
+        # To trigger this legacy behavior, two conditions must be met
+        # 1) the generation config must have been created from the
+        # model config (`_from_model_config` field);
+        # 2) the generation config must have seen no modification
+        # since its creation (the hash is the same).
+        if self.generation_config._from_model_config \
+            and self.generation_config._original_object_hash == hash(
+                self.generation_config):
+            new_generation_config = GenerationConfig.from_model_config(self.config)
+            if new_generation_config != self.generation_config:
+                warnings.warn(
+                    "You have modified the pretrained model configuration to control "
+                    "generation. This is a deprecated strategy to control generation "
+                    "and will be removed soon, in a future version. Please use and "
+                    "modify the model generation configuration (see"
+                    " https://huggingface.co/docs/transformers/generation_strategies"
+                    "#default-text-generation-configuration )"
+                )
+                self.generation_config = new_generation_config
+        generation_config = self.generation_config
+
+    generation_config = copy.deepcopy(generation_config)
+    # All unused kwargs must be model kwargs
+    model_kwargs = generation_config.update(**sampling_kwargs)
+    generation_config.validate()
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if model_kwargs.get("attention_mask", None) is None:
+            logger.warning(
+                "The attention mask and the pad token id were not set. As a consequence, "
+                "you may observe unexpected behavior. Please pass your input's "
+                "`attention_mask` to obtain reliable results."
+            )
+        eos_token_id = generation_config.eos_token_id
+        if isinstance(eos_token_id, list):
+            eos_token_id = eos_token_id[0]
+        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:"
+                       f"{eos_token_id} for open-end generation.")
+        generation_config.pad_token_id = eos_token_id
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = LogitsProcessorList()
+    stopping_criteria = StoppingCriteriaList()
+
+    # 3. Define model inputs
+    # inputs_tensor has to be defined
+    # model_input_name is defined if model-specific keyword input is passed
+    # otherwise model_input_name is None
+    # all model-specific keyword inputs are removed from `model_kwargs`
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    # 4. Define other model kwargs
+    # Removed not used
+
+    # decoder-only models should use left-padding for generation
+    if not self.config.is_encoder_decoder:
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: If using, `inputs_embeds` this check does not work,
+        # because we want to be more hands-off.
+        if (
+            generation_config.pad_token_id is not None
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding "
+                "was detected! For correct generation results, please set "
+                "`padding_side='left'` when initializing the tokenizer."
+            )
+    else:
+        invalidInputError(False, "encoder-decoder models are not supported now.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    # if streamer is not None:
+    #     streamer.put(input_ids.cpu())
+
+    input_ids_length = input_ids.shape[-1]
+
+    # Here we use sample generation mode
+    # 8. prepare distribution pre_processing samplers
+    logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=None,
+        logits_processor=logits_processor,
+    )
+
+    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+    input_ids, model_kwargs = self._expand_inputs_for_generation(
+        input_ids=input_ids,
+        expand_size=generation_config.num_return_sequences,
+        is_encoder_decoder=self.config.is_encoder_decoder,
+        **model_kwargs,
+    )
+
     step = 0
     step_draft = 0
     step_verify = 0
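Review note (not part of the diff): `_get_logits_processor` above returns a `LogitsProcessorList` built from the `GenerationConfig`, and the loops below call it as `logits_processor(ids, scores)`. A standalone sketch of that call pattern with a single repetition-penalty processor, assuming only the Hugging Face API (the penalty value is made up):

# Sketch only: how a LogitsProcessorList built from generation settings is applied.
import torch
from transformers import LogitsProcessorList
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

processors = LogitsProcessorList()
# Penalise tokens already present in the running sequence -- the role the old
# qwen-only RepetitionPenaltyLogitsProcessor branch played before this patch.
processors.append(RepetitionPenaltyLogitsProcessor(penalty=1.1))

input_ids = torch.tensor([[11, 12, 12, 13]])   # prompt + tokens generated so far
scores = torch.randn(1, 32000)                 # next-token logits, shape (batch, vocab)
scores = processors(input_ids, scores)         # same shape of call as logits_processor(...) above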
@@ -144,10 +251,6 @@ def speculative_generate(self,
 
     self.clear_benchmarks()
 
-    if self.config.model_type == "qwen":
-        from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
-        logit_processor = RepetitionPenaltyLogitsProcessor(
-            penalty=self.generation_config.repetition_penalty)
     # Example:
     # Target model forward for the first token
     # Step 1. target_model(prompt) -> a
@@ -172,11 +275,10 @@ def speculative_generate(self,
                           use_cache=True)
             logits = output['logits']
             logits = logits[:, -1:]
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step]), dim=-1)
-                logits[:, -1, :] = logit_processor(temp_input_ids, logits[:, -1, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             generate_ids[:, step] = output_ids
             current_input_ids = output_ids
             past_key_values = output['past_key_values']
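Review note (not part of the diff): the first-token path above now reads `do_sample`/`top_k`/`top_p`/`temperature` from the `GenerationConfig` rather than from function arguments. The `sample(...)` helper itself is not part of this diff (it is adapted from the self-speculative-decoding repository cited in the file header); as a rough illustration of the keyword surface it is called with, a simplified single-token sampler might look like this (illustrative sketch, not the real helper, and it ignores `return_probs`):

# Sketch only: greedy / temperature / top-k / top-p draw for the last position.
import torch

def sample_sketch(logits, do_sample=False, top_k=0, top_p=1.0, temperature=1.0):
    # logits: (batch, seq, vocab) -> next-token ids of shape (batch,)
    logits = logits[:, -1, :]
    if not do_sample:
        return torch.argmax(logits, dim=-1)
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        # Mask tokens whose preceding cumulative probability already exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        remove = sorted_probs.cumsum(dim=-1) - sorted_probs > top_p
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)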
@@ -208,15 +310,18 @@ def speculative_generate(self,
                                                past_key_values=draft_past_key_values,
                                                return_dict=True,
                                                use_cache=True)
-                if self.config.model_type == "qwen":
-                    temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                                draft_generate_ids[:, 1:step_draft+1]), dim=-1)
-                    draft_output['logits'][:, -1, :] = logit_processor(
-                        temp_input_ids,
-                        draft_output['logits'][:, -1, :])
+                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                            draft_generate_ids[:, 1:step_draft+1]), dim=-1)
+                logits = draft_output['logits']
+                logits[:, -1, :] = logits_processor(temp_input_ids,
+                                                    draft_output['logits'][:, -1, :])
                 draft_output_ids, draft_output_probs = sample(
-                    draft_output['logits'], return_probs=True, do_sample=do_sample,
-                    top_k=top_k, top_p=top_p, temperature=temperature)
+                    logits,
+                    return_probs=True,
+                    do_sample=generation_config.do_sample,
+                    top_k=generation_config.top_k,
+                    top_p=generation_config.top_p,
+                    temperature=generation_config.temperature)
                 draft_generate_ids[:, step_draft+1] = draft_output_ids
                 draft_current_input_ids = draft_output_ids
                 draft_past_key_values = draft_output['past_key_values']
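Review note (not part of the diff): in the draft loop the processor is now applied unconditionally, and the ids it sees are the prompt, the tokens accepted so far, and the tokens the draft model has proposed in the current round, concatenated along the sequence dimension. A small standalone sketch of that bookkeeping (tensor names follow the diff; the sizes and counters are made up for illustration):

# Sketch only: the context handed to logits_processor in the draft loop is the
# concatenation of prompt ids, accepted ids and this round's draft ids.
import torch

batch, prompt_len, max_new_tokens, max_step_draft = 1, 8, 16, 4
input_ids = torch.randint(0, 100, (batch, prompt_len))                          # prompt
generate_ids = torch.zeros(batch, max_new_tokens, dtype=torch.long)             # accepted tokens
draft_generate_ids = torch.zeros(batch, max_step_draft + 1, dtype=torch.long)   # slot 0 holds the seed token

step, step_draft = 3, 2   # 3 tokens accepted, 2 draft tokens proposed so far this round
temp_input_ids = torch.cat((input_ids,
                            generate_ids[:, :step],
                            draft_generate_ids[:, 1:step_draft + 1]), dim=-1)
assert temp_input_ids.shape == (batch, prompt_len + step + step_draft)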
@@ -254,14 +359,14 @@ def speculative_generate(self,
                               return_dict=True,
                               use_cache=True)
             logits = output['logits']
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                            draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
-                for i in range(logits.size(1)):
-                    logits[:, i, :] = logit_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                      output['logits'][:, i, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                        draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
+            for i in range(logits.size(1)):
+                logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
+                                                   output['logits'][:, i, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
             toc = time.time()
@@ -338,8 +443,8 @@ def speculative_generate(self,
 
         # Stop on eos and remove content after eos
         output_ids_list = output_ids[0].tolist()
-        if self.config.eos_token_id in output_ids_list:
-            idx = output_ids_list.index(self.config.eos_token_id)
+        if generation_config.eos_token_id in output_ids_list:
+            idx = output_ids_list.index(generation_config.eos_token_id)
             step -= (len(output_ids_list) - idx - 1)
             break
 
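Review note (not part of the diff): the stop check now reads `generation_config.eos_token_id` instead of `self.config.eos_token_id`, which is apparently what makes the Qwen eos stop in this commit work when the eos id is only set in the generation config. A standalone sketch of the same trim-after-eos logic:

# Sketch only: drop everything after the first eos token, as the loop above does
# by rewinding `step` to the eos position and breaking.
def trim_after_eos(output_ids_list, eos_token_id):
    if eos_token_id in output_ids_list:
        idx = output_ids_list.index(eos_token_id)
        return output_ids_list[:idx + 1]   # keep the eos token itself
    return output_ids_list

assert trim_after_eos([5, 9, 2, 7, 7], eos_token_id=2) == [5, 9, 2]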