Add logits processor & qwen eos stop in speculative decoding (#9963)

* add logits processor & qwen eos

* fix style

* fix

* fix

* fix style

* fix style

* support transformers 4.31

* fix style

* fix style

---------

Co-authored-by: rnwang04 <ruonan1.wang@intel.com>
This commit is contained in:
Yina Chen 2024-01-23 15:57:28 +08:00 committed by GitHub
parent 60b35db1f1
commit 36c665667d

View file

@ -14,7 +14,9 @@
# limitations under the License. # limitations under the License.
# #
# Some parts of this file is adapted from # Some parts of this file is adapted from
# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py # https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
# #
import torch import torch
@ -33,6 +35,8 @@ from bigdl.llm.utils.common import invalidInputError
from transformers import GenerationMixin from transformers import GenerationMixin
original_generate = GenerationMixin.generate original_generate = GenerationMixin.generate
logger = logging.getLogger("bigdl.llm.speculative")
@torch.no_grad() @torch.no_grad()
def generate( def generate(
@ -57,7 +61,7 @@ def generate(
value = kwargs.pop(var, None) value = kwargs.pop(var, None)
if value is not None: if value is not None:
new_speculative_kwargs[var] = value new_speculative_kwargs[var] = value
return self.speculative_generate(input_ids=inputs, return self.speculative_generate(inputs=inputs,
draft_model=self.draft_model, draft_model=self.draft_model,
**new_speculative_kwargs) **new_speculative_kwargs)
else: else:
@ -113,20 +117,123 @@ def clear_benchmarks(self):
@torch.no_grad() @torch.no_grad()
def speculative_generate(self, def speculative_generate(self,
input_ids: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None,
draft_model=None, draft_model=None,
max_new_tokens=10, max_new_tokens=10,
max_step_draft=8, max_step_draft=8,
th_stop_draft=0.8, th_stop_draft=0.8,
auto_th_stop_draft=True, auto_th_stop_draft=True,
auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9], auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
do_sample=False, hf_adjust=False,
top_k=0, generation_config: Optional[GenerationConfig] = None,
top_p=0.85, **sampling_kwargs):
temperature=0.2,
hf_adjust=False):
invalidInputError(draft_model is not None, invalidInputError(draft_model is not None,
"Draft model should be provided.") "Draft model should be provided.")
if generation_config is None:
# legacy: users may modify the model configuration to control generation.
# To trigger this legacy behavior, two conditions must be met
# 1) the generation config must have been created from the
# model config (`_from_model_config` field);
# 2) the generation config must have seen no modification
# since its creation (the hash is the same).
if self.generation_config._from_model_config \
and self.generation_config._original_object_hash == hash(
self.generation_config):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control "
"generation. This is a deprecated strategy to control generation "
"and will be removed soon, in a future version. Please use and "
"modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies"
"#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
# All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**sampling_kwargs)
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, "
"you may observe unexpected behavior. Please pass your input's "
"`attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:"
f"{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 2. Set generation parameters if not already defined
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
# 4. Define other model kwargs
# Removed not used
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work,
# because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding "
"was detected! For correct generation results, please set "
"`padding_side='left'` when initializing the tokenizer."
)
else:
invalidInputError(False, "encoder-decoder models are not supported now.")
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# if streamer is not None:
# streamer.put(input_ids.cpu())
input_ids_length = input_ids.shape[-1]
# Here we use sample generation mode
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
step = 0 step = 0
step_draft = 0 step_draft = 0
step_verify = 0 step_verify = 0
@ -144,10 +251,6 @@ def speculative_generate(self,
self.clear_benchmarks() self.clear_benchmarks()
if self.config.model_type == "qwen":
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
logit_processor = RepetitionPenaltyLogitsProcessor(
penalty=self.generation_config.repetition_penalty)
# Example: # Example:
# Target model forward for the first token # Target model forward for the first token
# Step 1. target_model(prompt) -> a # Step 1. target_model(prompt) -> a
@ -172,11 +275,10 @@ def speculative_generate(self,
use_cache=True) use_cache=True)
logits = output['logits'] logits = output['logits']
logits = logits[:, -1:] logits = logits[:, -1:]
if self.config.model_type == "qwen": logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step]), dim=-1) output_ids = sample(logits, do_sample=generation_config.do_sample,
logits[:, -1, :] = logit_processor(temp_input_ids, logits[:, -1, :]) top_k=generation_config.top_k, top_p=generation_config.top_p,
output_ids = sample(logits, do_sample=do_sample, top_k=top_k, temperature=generation_config.temperature)
top_p=top_p, temperature=temperature)
generate_ids[:, step] = output_ids generate_ids[:, step] = output_ids
current_input_ids = output_ids current_input_ids = output_ids
past_key_values = output['past_key_values'] past_key_values = output['past_key_values']
@ -208,15 +310,18 @@ def speculative_generate(self,
past_key_values=draft_past_key_values, past_key_values=draft_past_key_values,
return_dict=True, return_dict=True,
use_cache=True) use_cache=True)
if self.config.model_type == "qwen":
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step], temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft+1]), dim=-1) draft_generate_ids[:, 1:step_draft+1]), dim=-1)
draft_output['logits'][:, -1, :] = logit_processor( logits = draft_output['logits']
temp_input_ids, logits[:, -1, :] = logits_processor(temp_input_ids,
draft_output['logits'][:, -1, :]) draft_output['logits'][:, -1, :])
draft_output_ids, draft_output_probs = sample( draft_output_ids, draft_output_probs = sample(
draft_output['logits'], return_probs=True, do_sample=do_sample, logits,
top_k=top_k, top_p=top_p, temperature=temperature) return_probs=True,
do_sample=generation_config.do_sample,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
draft_generate_ids[:, step_draft+1] = draft_output_ids draft_generate_ids[:, step_draft+1] = draft_output_ids
draft_current_input_ids = draft_output_ids draft_current_input_ids = draft_output_ids
draft_past_key_values = draft_output['past_key_values'] draft_past_key_values = draft_output['past_key_values']
@ -254,14 +359,14 @@ def speculative_generate(self,
return_dict=True, return_dict=True,
use_cache=True) use_cache=True)
logits = output['logits'] logits = output['logits']
if self.config.model_type == "qwen":
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step], temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft + 2]), dim=-1) draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
for i in range(logits.size(1)): for i in range(logits.size(1)):
logits[:, i, :] = logit_processor(temp_input_ids[:, :input_ids.size(1)+step+i], logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
output['logits'][:, i, :]) output['logits'][:, i, :])
output_ids = sample(logits, do_sample=do_sample, top_k=top_k, output_ids = sample(logits, do_sample=generation_config.do_sample,
top_p=top_p, temperature=temperature) top_k=generation_config.top_k, top_p=generation_config.top_p,
temperature=generation_config.temperature)
if self.device.type == 'xpu': if self.device.type == 'xpu':
torch.xpu.synchronize() torch.xpu.synchronize()
toc = time.time() toc = time.time()
@ -338,8 +443,8 @@ def speculative_generate(self,
# Stop on eos and remove content after eos # Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist() output_ids_list = output_ids[0].tolist()
if self.config.eos_token_id in output_ids_list: if generation_config.eos_token_id in output_ids_list:
idx = output_ids_list.index(self.config.eos_token_id) idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1) step -= (len(output_ids_list) - idx - 1)
break break