Add logits processor & qwen eos stop in speculative decoding (#9963)
* add logits processor & qwen eos
* fix style
* fix
* fix
* fix style
* fix style
* support transformers 4.31
* fix style
* fix style

Co-authored-by: rnwang04 <ruonan1.wang@intel.com>
This commit is contained in:
parent 60b35db1f1
commit 36c665667d
1 changed file with 140 additions and 35 deletions
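In short, the commit replaces the ad-hoc Qwen-only repetition penalty with the standard transformers logits-processor pipeline and stops generation on `generation_config.eos_token_id`. A minimal standalone sketch of that pattern (not BigDL-LLM code; the token id and penalty value are illustrative assumptions):

```python
# Standalone sketch of the pattern adopted here: apply a LogitsProcessorList to
# the last-position logits, pick the next token, and stop once eos is produced.
import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

eos_token_id = 2          # assumption: use the real tokenizer's eos id in practice
vocab_size = 32000
processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.1)])

input_ids = torch.tensor([[1, 5, 7, 9]])   # running context (prompt + generated ids)
logits = torch.randn(1, vocab_size)        # stand-in for the model's last-position logits

logits = processors(input_ids, logits)     # e.g. penalize already-seen tokens
next_token = int(torch.argmax(logits, dim=-1))  # greedy here; sampling follows the same flow
if next_token == eos_token_id:
    print("stop: eos reached")
```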
@@ -14,7 +14,9 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
+# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /utils.py
 #
 
 import torch
@@ -33,6 +35,8 @@ from bigdl.llm.utils.common import invalidInputError
 from transformers import GenerationMixin
 original_generate = GenerationMixin.generate
 
+logger = logging.getLogger("bigdl.llm.speculative")
+
 
 @torch.no_grad()
 def generate(
@@ -57,7 +61,7 @@ def generate(
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
-        return self.speculative_generate(input_ids=inputs,
+        return self.speculative_generate(inputs=inputs,
                                          draft_model=self.draft_model,
                                          **new_speculative_kwargs)
     else:
@@ -113,20 +117,123 @@ def clear_benchmarks(self):
 
 @torch.no_grad()
 def speculative_generate(self,
-                         input_ids: Optional[torch.Tensor] = None,
+                         inputs: Optional[torch.Tensor] = None,
                          draft_model=None,
                          max_new_tokens=10,
                          max_step_draft=8,
                          th_stop_draft=0.8,
                          auto_th_stop_draft=True,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
-                         do_sample=False,
-                         top_k=0,
-                         top_p=0.85,
-                         temperature=0.2,
-                         hf_adjust=False):
+                         hf_adjust=False,
+                         generation_config: Optional[GenerationConfig] = None,
+                         **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
 
+    if generation_config is None:
+        # legacy: users may modify the model configuration to control generation.
+        # To trigger this legacy behavior, two conditions must be met
+        # 1) the generation config must have been created from the
+        # model config (`_from_model_config` field);
+        # 2) the generation config must have seen no modification
+        # since its creation (the hash is the same).
+        if self.generation_config._from_model_config \
+                and self.generation_config._original_object_hash == hash(
+                    self.generation_config):
+            new_generation_config = GenerationConfig.from_model_config(self.config)
+            if new_generation_config != self.generation_config:
+                warnings.warn(
+                    "You have modified the pretrained model configuration to control "
+                    "generation. This is a deprecated strategy to control generation "
+                    "and will be removed soon, in a future version. Please use and "
+                    "modify the model generation configuration (see"
+                    " https://huggingface.co/docs/transformers/generation_strategies"
+                    "#default-text-generation-configuration )"
+                )
+                self.generation_config = new_generation_config
+        generation_config = self.generation_config
+
+    generation_config = copy.deepcopy(generation_config)
+    # All unused kwargs must be model kwargs
+    model_kwargs = generation_config.update(**sampling_kwargs)
+    generation_config.validate()
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if model_kwargs.get("attention_mask", None) is None:
+            logger.warning(
+                "The attention mask and the pad token id were not set. As a consequence, "
+                "you may observe unexpected behavior. Please pass your input's "
+                "`attention_mask` to obtain reliable results."
+            )
+        eos_token_id = generation_config.eos_token_id
+        if isinstance(eos_token_id, list):
+            eos_token_id = eos_token_id[0]
+        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:"
+                       f"{eos_token_id} for open-end generation.")
+        generation_config.pad_token_id = eos_token_id
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = LogitsProcessorList()
+    stopping_criteria = StoppingCriteriaList()
+
+    # 3. Define model inputs
+    # inputs_tensor has to be defined
+    # model_input_name is defined if model-specific keyword input is passed
+    # otherwise model_input_name is None
+    # all model-specific keyword inputs are removed from `model_kwargs`
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    # 4. Define other model kwargs
+    # Removed not used
+
+    # decoder-only models should use left-padding for generation
+    if not self.config.is_encoder_decoder:
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: If using, `inputs_embeds` this check does not work,
+        # because we want to be more hands-off.
+        if (
+            generation_config.pad_token_id is not None
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding "
+                "was detected! For correct generation results, please set "
+                "`padding_side='left'` when initializing the tokenizer."
+            )
+    else:
+        invalidInputError(False, "encoder-decoder models are not supported now.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    # if streamer is not None:
+    #     streamer.put(input_ids.cpu())
+
+    input_ids_length = input_ids.shape[-1]
+
+    # Here we use sample generation mode
+    # 8. prepare distribution pre_processing samplers
+    logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=None,
+        logits_processor=logits_processor,
+    )
+
+    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+    input_ids, model_kwargs = self._expand_inputs_for_generation(
+        input_ids=input_ids,
+        expand_size=generation_config.num_return_sequences,
+        is_encoder_decoder=self.config.is_encoder_decoder,
+        **model_kwargs,
+    )
+
     step = 0
     step_draft = 0
     step_verify = 0
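The block above mirrors the argument handling in transformers' `GenerationMixin.generate`: loose sampling kwargs are folded into a `GenerationConfig`, leftovers become model kwargs, and a `LogitsProcessorList` is built from the config. A hedged illustration using only public transformers APIs (the diff itself relies on the private `self._get_logits_processor` helper; the values below are arbitrary examples):

```python
# Rough illustration of the new argument flow, using public APIs only.
from transformers import (GenerationConfig, LogitsProcessorList,
                          RepetitionPenaltyLogitsProcessor)

gen_config = GenerationConfig(max_new_tokens=32)
# update() consumes known generation fields and hands back everything else,
# which is how the diff separates sampling settings from model kwargs.
model_kwargs = gen_config.update(do_sample=True, top_p=0.85, temperature=0.2,
                                 repetition_penalty=1.05, attention_mask=None)
print(gen_config.top_p, gen_config.temperature)  # 0.85 0.2
print(model_kwargs)                              # {'attention_mask': None}

# _get_logits_processor would then produce something roughly equivalent to:
processors = LogitsProcessorList()
if gen_config.repetition_penalty is not None and gen_config.repetition_penalty != 1.0:
    processors.append(RepetitionPenaltyLogitsProcessor(penalty=gen_config.repetition_penalty))
```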
@@ -144,10 +251,6 @@ def speculative_generate(self,
 
     self.clear_benchmarks()
 
-    if self.config.model_type == "qwen":
-        from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
-        logit_processor = RepetitionPenaltyLogitsProcessor(
-            penalty=self.generation_config.repetition_penalty)
     # Example:
     # Target model forward for the first token
     # Step 1. target_model(prompt) -> a
@@ -172,11 +275,10 @@ def speculative_generate(self,
                           use_cache=True)
             logits = output['logits']
             logits = logits[:, -1:]
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step]), dim=-1)
-                logits[:, -1, :] = logit_processor(temp_input_ids, logits[:, -1, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             generate_ids[:, step] = output_ids
             current_input_ids = output_ids
             past_key_values = output['past_key_values']
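After the processor runs, the `sample(...)` helper picks the next token from the processed logits, now driven by `generation_config` rather than bare `do_sample`/`top_k`/`top_p`/`temperature` arguments. A hypothetical, simplified stand-in for that helper (the real BigDL-LLM `sample` differs, e.g. it can also return probabilities for draft verification):

```python
# Hypothetical, simplified stand-in for the `sample(...)` helper referenced in
# the diff; it only shows how the generation_config fields are consumed.
import torch
from transformers import TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

def sample_sketch(logits, do_sample=False, top_k=50, top_p=1.0, temperature=1.0):
    scores = logits[:, -1, :]                    # (batch, vocab) last-position scores
    if not do_sample:
        return torch.argmax(scores, dim=-1, keepdim=True)
    dummy_ids = torch.zeros((scores.size(0), 1), dtype=torch.long)  # warpers ignore the ids
    scores = TemperatureLogitsWarper(float(temperature))(dummy_ids, scores)
    if top_k:
        scores = TopKLogitsWarper(int(top_k))(dummy_ids, scores)
    if top_p is not None and top_p < 1.0:
        scores = TopPLogitsWarper(float(top_p))(dummy_ids, scores)
    probs = torch.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# e.g. sample_sketch(torch.randn(1, 1, 32000), do_sample=True, top_p=0.85, temperature=0.2)
```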
@@ -208,15 +310,18 @@ def speculative_generate(self,
                     past_key_values=draft_past_key_values,
                     return_dict=True,
                     use_cache=True)
-                if self.config.model_type == "qwen":
-                    temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                                draft_generate_ids[:, 1:step_draft+1]), dim=-1)
-                    draft_output['logits'][:, -1, :] = logit_processor(
-                        temp_input_ids,
-                        draft_output['logits'][:, -1, :])
+                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                            draft_generate_ids[:, 1:step_draft+1]), dim=-1)
+                logits = draft_output['logits']
+                logits[:, -1, :] = logits_processor(temp_input_ids,
+                                                    draft_output['logits'][:, -1, :])
                 draft_output_ids, draft_output_probs = sample(
-                    draft_output['logits'], return_probs=True, do_sample=do_sample,
-                    top_k=top_k, top_p=top_p, temperature=temperature)
+                    logits,
+                    return_probs=True,
+                    do_sample=generation_config.do_sample,
+                    top_k=generation_config.top_k,
+                    top_p=generation_config.top_p,
+                    temperature=generation_config.temperature)
                 draft_generate_ids[:, step_draft+1] = draft_output_ids
                 draft_current_input_ids = draft_output_ids
                 draft_past_key_values = draft_output['past_key_values']
@@ -254,14 +359,14 @@ def speculative_generate(self,
                           return_dict=True,
                           use_cache=True)
             logits = output['logits']
-            if self.config.model_type == "qwen":
-                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
-                                            draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
-                for i in range(logits.size(1)):
-                    logits[:, i, :] = logit_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                      output['logits'][:, i, :])
-            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
-                                top_p=top_p, temperature=temperature)
+            temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
+                                        draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
+            for i in range(logits.size(1)):
+                logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
+                                                   output['logits'][:, i, :])
+            output_ids = sample(logits, do_sample=generation_config.do_sample,
+                                top_k=generation_config.top_k, top_p=generation_config.top_p,
+                                temperature=generation_config.temperature)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
             toc = time.time()
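During verification the target model scores several drafted positions at once, so the processor has to be applied per position with only the ids that precede that position, which is what the `for i in range(logits.size(1))` loop above does. A small self-contained sketch of that idea (shapes and ids are illustrative):

```python
# Per-position application of a logits processor over a block of drafted
# tokens; mirrors the verify-step loop in the diff, with made-up shapes/ids.
import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.1)])

context_ids = torch.tensor([[1, 5, 7, 9, 11, 13]])   # prompt + accepted + drafted ids
prefix_len, n_draft, vocab_size = 4, 2, 32000
logits = torch.randn(1, n_draft, vocab_size)          # target-model scores for the drafted positions

for i in range(logits.size(1)):
    visible = context_ids[:, :prefix_len + i]          # ids that precede position i
    logits[:, i, :] = processors(visible, logits[:, i, :])
```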
@@ -338,8 +443,8 @@ def speculative_generate(self,
 
             # Stop on eos and remove content after eos
             output_ids_list = output_ids[0].tolist()
-            if self.config.eos_token_id in output_ids_list:
-                idx = output_ids_list.index(self.config.eos_token_id)
+            if generation_config.eos_token_id in output_ids_list:
+                idx = output_ids_list.index(generation_config.eos_token_id)
                 step -= (len(output_ids_list) - idx - 1)
                 break
 
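Finally, the eos check now reads `generation_config.eos_token_id` instead of `self.config.eos_token_id`, which is what makes Qwen stop correctly when its model config does not carry the id. A tiny sketch of the trimming logic with made-up values (151643 is assumed to be Qwen's `<|endoftext|>` id):

```python
# Illustrative eos stop-and-trim, matching the hunk above; ids are examples.
eos_token_id = 151643                      # assumption: Qwen <|endoftext|>
output_ids_list = [8, 23, 151643, 77, 5]   # tokens produced in this verify step

if eos_token_id in output_ids_list:
    idx = output_ids_list.index(eos_token_id)
    kept = output_ids_list[:idx + 1]       # keep eos, drop the tokens after it
    print(kept)                            # [8, 23, 151643]
```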