Support prompt lookup in ipex-llm (#10768)
* lookup init
* add lookup
* fix style
* remove redundant code
* change param name
* fix style
parent d30b22a81b
commit 899d392e2f
3 changed files with 449 additions and 82 deletions
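
This commit patches transformers' GenerationMixin.generate so that passing a `lookahead` argument routes generation through the new `lookup_generate` path in lookup.py below. As a minimal usage sketch (not part of the patch; the checkpoint name and parameter values are illustrative assumptions), prompt lookup decoding could be driven like this:

    # Hypothetical usage sketch for the feature added in this commit.
    from ipex_llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to(model.device)
    # `lookahead` switches generate() to lookup_generate and sets the number of
    # candidate tokens drafted per lookup step; max_matching_ngram_size bounds
    # the n-gram searched for in the prompt.
    output = model.generate(input_ids,
                            max_new_tokens=32,
                            lookahead=8,
                            max_matching_ngram_size=2)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
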
python/llm/src/ipex_llm/transformers/lookup.py (new file, 319 lines added)

@@ -0,0 +1,319 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /candidate_generator.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
#

from typing import Callable, List, Optional, Tuple
import torch
import time
import copy
import logging
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError

logger = logging.getLogger("ipex_llm.lookup")

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    lookahead = kwargs.pop("lookahead", None)
    if lookahead:
        from ipex_llm.transformers.convert import get_enable_ipex
        _enable_ipex = get_enable_ipex()

        if self.device.type == "cpu" and _enable_ipex:
            logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                           "fallback to original generate.")
            kwargs.pop("max_matching_ngram_size")
        else:
            # Do prompt lookup generation
            return self.lookup_generate(inputs=inputs,
                                        num_output_tokens=lookahead,
                                        generation_config=generation_config,
                                        logits_processor=logits_processor,
                                        stopping_criteria=stopping_criteria,
                                        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                        **kwargs)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)

GenerationMixin.generate = generate


# This class is copied from https://github.com/huggingface/transformers/blob/main/src
# /transformers/generation/candidate_generator.py
class PromptLookupCandidateGenerator():
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates
    by looking up likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = None,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2

        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                          "Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self,
                       input_ids: torch.LongTensor)-> Tuple[torch.LongTensor,
                                                            Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
                The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match return the input sequence unchanged,
            # reverts back to autoregressive decoding
            return input_ids, None

        # Now need extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here,
        # so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor,
                                  scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
                config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each
                vocabulary when not using beam search or log softmax for each vocabulary
                token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return


@torch.no_grad()
def lookup_generate(self,
                    inputs: Optional[torch.Tensor] = None,
                    max_new_tokens: int = 10,
                    num_output_tokens: int = 10,
                    max_matching_ngram_size: int = None,
                    generation_config: Optional[GenerationConfig] = None,
                    attention_mask=None,
                    **sampling_kwargs):
    input_ids, generation_config, logits_processor, stopping_criteria, \
        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                              **sampling_kwargs)

    candidates_generator = PromptLookupCandidateGenerator(
        num_output_tokens=num_output_tokens,
        max_matching_ngram_size=max_matching_ngram_size)

    step = 0
    step_verify = 0

    clear_benchmarks(self)

    past_key_values = None
    input_len = input_ids.shape[1]

    while True:
        if step >= max_new_tokens:
            break

        if step == 0:
            # first token use full model
            tic = time.time()
            output = self(input_ids=input_ids,
                          past_key_values=past_key_values,
                          attention_mask=attention_mask,
                          return_dict=True,
                          use_cache=True)
            logits = output['logits']
            logits = logits[:, -1:]
            logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
            if generation_config.do_sample:
                output_ids, prob_list = deepmind_sample(logits,
                                                        top_k=generation_config.top_k,
                                                        top_p=generation_config.top_p,
                                                        temperature=generation_config.temperature)
            else:
                output_ids = greedy(logits)
            input_ids = torch.cat((input_ids, output_ids), dim=-1)
            past_key_values = output['past_key_values']
            step += 1
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.first_token_time = toc - tic
            e2e_tic = time.time()
        else:
            cur_len = input_ids.shape[-1]
            toc = time.time()
            candidate_input_ids, _ = candidates_generator.get_candidates(input_ids=input_ids)
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
            verify_input_ids = candidate_input_ids[:, -candidate_length - 1:]
            self.draft_num.append(candidate_length)
            tic = time.time()
            self.draft_time.append(tic - toc)
            output = _non_cpu_ipex_verify(self, verify_input_ids, past_key_values,
                                          attention_mask, return_dict=True, use_cache=True)
            if isinstance(output, dict):
                logits = output['logits']
                past_key_values = output['past_key_values']

            if len(logits_processor) > 0:
                for i in range(candidate_length + 1):
                    logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i],
                                                       logits[:, i, :])

            if generation_config.do_sample:
                output_ids, prob_list = deepmind_sample(logits,
                                                        top_k=generation_config.top_k,
                                                        top_p=generation_config.top_p,
                                                        temperature=generation_config.temperature)
            else:
                output_ids = greedy(logits)

            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.verify_time.append(toc - tic)

            # Compare drafts with target verified outputs
            # Drafts start from [1, k]
            # Verified output start from [0, k - 1]
            # including the one generated by the base model
            max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
            max_matched = max_matched.sum(-1).item() + 1

            max_of_max_matched = output_ids.size(1)
            # Accept number is max_matched, min is 1
            self.accept_num.append(max_matched)
            self.n_matched += max_matched - 1
            self.n_drafted += candidate_length

            # Clean up target model KV cache
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                new_cache_size = max_of_max_matched - max_matched
                past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size)

            input_ids = torch.cat((input_ids, output_ids), dim=-1)

            step += output_ids.size(1)
            step_verify += 1

        # Stop on eos and remove content after eos
        output_ids_list = output_ids[0].tolist()
        if generation_config.eos_token_id in output_ids_list:
            idx = output_ids_list.index(generation_config.eos_token_id)
            step -= (len(output_ids_list) - idx - 1)
            break

    step = min(step, max_new_tokens)
    e2e_toc = time.time()
    self.n_token_generated = step
    self.e2e_time_without_first = e2e_toc - e2e_tic

    return input_ids[:, : input_len + step]
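
The get_candidates method in lookup.py above looks for earlier occurrences of the prompt's trailing n-gram and proposes the tokens that followed them as draft tokens. A standalone toy illustration of that window-matching step (token ids are made up, not from the patch):

    import torch

    # Toy prompt: the trailing bigram (7, 8) also occurs earlier at positions 2-3,
    # so the tokens after that earlier occurrence (9, 4, 5) become the draft.
    input_ids = torch.tensor([[1, 2, 7, 8, 9, 4, 5, 7, 8]])
    ngram_size, num_output_tokens = 2, 3

    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # shape (1, 8, 2)
    ngram = input_ids[0, -ngram_size:]                                # tensor([7, 8])
    matches = (windows == ngram).all(dim=2)
    match_indices = matches.nonzero(as_tuple=True)[1]                 # tensor([2, 7])

    # get_candidates takes the first match whose continuation is non-empty (start < end).
    start = (match_indices[0] + ngram_size).item()
    end = min(start + num_output_tokens, input_ids.size(1))
    print(input_ids[0, start:end])                                    # tensor([9, 4, 5])
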
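The acceptance check in lookup_generate above shifts the model's verified outputs against the drafted tokens and keeps the longest agreeing prefix plus one extra token from the model; the cumsum-of-mismatches trick counts how many leading positions agree. A small worked example with made-up token ids:

    import torch

    # verify_input_ids = [last accepted token, d1, d2, d3, d4] (4 drafted tokens)
    # output_ids       = model's own prediction at each of those positions
    verify_input_ids = torch.tensor([[11, 21, 22, 23, 24]])
    output_ids = torch.tensor([[21, 22, 99, 37, 41]])  # agrees on d1, d2, then diverges

    # Position i of output_ids[:, :-1] predicts verify_input_ids[:, i + 1]
    mismatch = (output_ids[:, :-1] != verify_input_ids[:, 1:])    # [False, False, True, True]
    max_matched = (mismatch.cumsum(-1) == 0).sum(-1).item() + 1   # 2 accepted drafts + 1 bonus token
    print(max_matched)  # 3 -> output_ids[:, :3] = [21, 22, 99] is appended to input_ids
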
@@ -330,7 +330,8 @@ class _BaseAutoModelClass:
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
             if speculative:
-                from .speculative import speculative_generate, clear_benchmarks
+                from .speculative import speculative_generate, clear_benchmarks,\
+                    _crop_past_key_values
                 # load a sym_int4 model as draft model
                 draft_model = cls.load_convert('sym_int4', optimize_model, *args, **kwargs)
                 model.draft_model = draft_model
@@ -338,6 +339,12 @@ class _BaseAutoModelClass:
                 # add speculative_generate to pretrained model dynamically
                 model.clear_benchmarks = types.MethodType(clear_benchmarks, model)
                 model.speculative_generate = types.MethodType(speculative_generate, model)
+                model._crop_past_key_values = types.MethodType(_crop_past_key_values, model)
+
+            # add lookup_generate to pretrained model
+            from .lookup import lookup_generate
+            import types
+            model.lookup_generate = types.MethodType(lookup_generate, model)
         else:
             # load default
             model = cls.HF_Model.from_pretrained(*args, **kwargs)
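
The from_pretrained change above attaches lookup_generate to the returned model instance with types.MethodType, so the patched generate can later call self.lookup_generate(...). A tiny self-contained sketch of that binding pattern (toy class and function, not ipex-llm code):

    import types

    def lookup_generate(self, prompt):
        # stand-in for the real implementation in ipex_llm/transformers/lookup.py
        return f"{self.name} drafting from: {prompt}"

    class ToyModel:
        def __init__(self, name):
            self.name = name

    model = ToyModel("demo")
    # Bind the free function as an instance method, as the change above does
    model.lookup_generate = types.MethodType(lookup_generate, model)
    print(model.lookup_generate("Once upon a time"))  # demo drafting from: Once upon a time
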
python/llm/src/ipex_llm/transformers/speculative.py

@@ -439,26 +439,49 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
     return past_key_values, not enough_kv_room
 
 
-@torch.no_grad()
-def speculative_generate(self,
-                         inputs: Optional[torch.Tensor] = None,
-                         draft_model=None,
-                         max_new_tokens=10,
-                         max_step_draft=8,
-                         th_stop_draft=0.8,
-                         auto_th_stop_draft=True,
-                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
-                         hf_adjust=False,
-                         min_step_draft=3,
-                         generation_config: Optional[GenerationConfig] = None,
-                         attention_mask=None,
-                         **sampling_kwargs):
-    invalidInputError(draft_model is not None,
-                      "Draft model should be provided.")
-    # min_step_draft >= 1. Since the max_step_draft may adjust,
-    # min_step_draft can > max_step_draft
-    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
+def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
+    if _enable_ipex:
+        cur_len = past_key_values[0][0].size(1)
+        delta = new_cache_size
+        tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
+                          dtype=torch.long).contiguous()
+        past_key_values = [[tmp, key_cache, value_cache, beam_idx]
+                           for _, key_cache, value_cache, beam_idx in past_key_values]
+    else:
+        if self.config.model_type in ["qwen"]:
+            past_key_values = [
+                (k[:, :-(new_cache_size), :],
+                 v[:, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type == "chatglm":
+            # for chatglm, cache shape is [sl, bs, nh, hn]
+            past_key_values = [
+                (k[:-(new_cache_size), :, :, :],
+                 v[:-(new_cache_size), :, :, :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type in ["baichuan", "gptj"]:
+            past_key_values = [
+                (k[:, :, :-(new_cache_size), :],
+                 v[:, :, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type == "gpt_bigcode":
+            past_key_values = [
+                kv[:, :-(new_cache_size)]
+                for kv in past_key_values
+            ]
+        else:
+            past_key_values = [
+                (k[:, :, :-(new_cache_size)],
+                 v[:, :, :-(new_cache_size)])
+                for k, v in past_key_values
+            ]
+    return past_key_values
+
+
+def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
     if generation_config is None:
         generation_config = self.generation_config
@@ -494,10 +517,27 @@ def speculative_generate(self,
     inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
         inputs, generation_config.bos_token_id, model_kwargs
     )
-    batch_size = inputs_tensor.shape[0]
 
     # 4. Define other model kwargs
+    # Removed not used
+    # model_kwargs["output_attentions"] = generation_config.output_attentions
+    # model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+    # # decoder-only models with inputs_embeds forwarding must use caching
+    # # (otherwise we can't detect whether we are generating the first new token or not,
+    # # and we only want to use the embeddings for the first new token)
+    # if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+    #     model_kwargs["use_cache"] = True
+    # else:
+    #     model_kwargs["use_cache"] = generation_config.use_cache
+
+    # accepts_attention_mask = "attention_mask" in set(
+    #     inspect.signature(self.forward).parameters.keys())
+    # requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+    # if model_kwargs.get("attention_mask", None) is None and \
+    #         requires_attention_mask and accepts_attention_mask:
+    #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+    #         inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+    #     )
+
     # decoder-only models should use left-padding for generation
     if not self.config.is_encoder_decoder:
@@ -543,6 +583,61 @@ def speculative_generate(self,
         **model_kwargs,
     )
 
+    return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
+
+
+def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
+                         return_dict=True, use_cache=True):
+    forward_args = {
+        "input_ids": verify_input_ids,
+        "past_key_values": past_key_values,
+        "return_dict": return_dict,
+        "use_cache": use_cache,
+    }
+    if cur_attention_mask:
+        forward_args["attention_mask"] = cur_attention_mask
+
+    if self.config.model_type == "chatglm":
+        past_key_value_len = past_key_values[0][0].shape[0]
+        position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
+                                    device=verify_input_ids.device)
+        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
+        forward_args["position_ids"] = position_ids
+    elif self.config.model_type == "gptj":
+        past_length = past_key_values[0][0].size(2)
+        input_len = verify_input_ids.shape[1]
+        position_ids = torch.arange(past_length, input_len + past_length,
+                                    dtype=torch.long, device=verify_input_ids.device)
+        position_ids = position_ids.unsqueeze(0).view(-1, input_len)
+        forward_args["position_ids"] = position_ids
+
+    return self(**forward_args)
+
+
+@torch.no_grad()
+def speculative_generate(self,
+                         inputs: Optional[torch.Tensor] = None,
+                         draft_model=None,
+                         max_new_tokens=10,
+                         max_step_draft=8,
+                         th_stop_draft=0.8,
+                         auto_th_stop_draft=True,
+                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
+                         hf_adjust=False,
+                         min_step_draft=3,
+                         generation_config: Optional[GenerationConfig] = None,
+                         attention_mask=None,
+                         **sampling_kwargs):
+    invalidInputError(draft_model is not None,
+                      "Draft model should be provided.")
+    # min_step_draft >= 1. Since the max_step_draft may adjust,
+    # min_step_draft can > max_step_draft
+    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
+
+    input_ids, generation_config, logits_processor, stopping_criteria, \
+        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                              **sampling_kwargs)
+
     step = 0
     step_draft = 0
     step_verify = 0
@@ -851,27 +946,8 @@ def speculative_generate(self,
                 logits = output[0]
                 past_key_values = output[1]
             else:
-                forward_args = {
-                    "input_ids": drafted_input_ids,
-                    "past_key_values": past_key_values,
-                    "attention_mask": cur_attention_mask,
-                    "return_dict": True,
-                    "use_cache": True,
-                }
-                if self.config.model_type == "chatglm":
-                    past_key_value_len = past_key_values[0][0].shape[0]
-                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
-                                                device=drafted_input_ids.device)
-                    position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
-                    forward_args["position_ids"] = position_ids
-                elif self.config.model_type == "gptj":
-                    past_length = past_key_values[0][0].size(2)
-                    input_len = drafted_input_ids.shape[1]
-                    position_ids = torch.arange(past_length, input_len + past_length,
-                                                dtype=torch.long, device=drafted_input_ids.device)
-                    position_ids = position_ids.unsqueeze(0).view(-1, input_len)
-                    forward_args["position_ids"] = position_ids
-                output = self(**forward_args)
+                output = _non_cpu_ipex_verify(self, drafted_input_ids, past_key_values,
+                                              cur_attention_mask, return_dict=True, use_cache=True)
             if isinstance(output, dict):
                 logits = output['logits']
                 past_key_values = output['past_key_values']
@@ -939,45 +1015,10 @@ def speculative_generate(self,
             # Clean up target model KV cache
             if max_of_max_matched != max_matched:
                 output_ids = output_ids[:, :max_matched]
-                if _enable_ipex:
-                    cur_len = past_key_values[0][0].size(1)
-                    delta = max_of_max_matched - max_matched
-                    tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
-                                      dtype=torch.long,
-                                      ).contiguous()
-                    past_key_values = [[tmp, key_cache, value_cache, beam_idx]
-                                       for _, key_cache, value_cache, beam_idx in past_key_values]
-                else:
-                    if self.config.model_type in ["qwen"]:
-                        past_key_values = [
-                            (k[:, :-(max_of_max_matched - max_matched), :],
-                             v[:, :-(max_of_max_matched - max_matched), :])
-                            for k, v in past_key_values
-                        ]
-                    elif self.config.model_type == "chatglm":
-                        # for chatglm, cache shape is [sl, bs, nh, hn]
-                        past_key_values = [
-                            (k[:-(max_of_max_matched - max_matched), :, :, :],
-                             v[:-(max_of_max_matched - max_matched), :, :, :])
-                            for k, v in past_key_values
-                        ]
-                    elif self.config.model_type in ["baichuan", "gptj"]:
-                        past_key_values = [
-                            (k[:, :, :-(max_of_max_matched - max_matched), :],
-                             v[:, :, :-(max_of_max_matched - max_matched), :])
-                            for k, v in past_key_values
-                        ]
-                    elif self.config.model_type == "gpt_bigcode":
-                        past_key_values = [
-                            kv[:, :-(max_of_max_matched - max_matched)]
-                            for kv in past_key_values
-                        ]
-                    else:
-                        past_key_values = [
-                            (k[:, :, :-(max_of_max_matched - max_matched)],
-                             v[:, :, :-(max_of_max_matched - max_matched)])
-                            for k, v in past_key_values
-                        ]
+                new_cache_size = max_of_max_matched - max_matched
+                past_key_values = self._crop_past_key_values(past_key_values,
+                                                             new_cache_size,
+                                                             _enable_ipex)
 
             # Each iter assign new_matched kv_cache to past_key_values1
             if self.device.type == 'cpu' and (not _enable_ipex):
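
The refactor above centralizes KV-cache truncation in _crop_past_key_values: for the default layout the cache is sliced along the sequence dimension of a (batch, n_heads, seq_len, head_dim) tensor, while chatglm-style caches are sliced along their leading sequence dimension. A toy check of the default branch, with illustrative shapes:

    import torch

    # One layer of a toy KV cache in the default (batch, n_heads, seq_len, head_dim) layout.
    k = torch.randn(1, 4, 10, 8)
    v = torch.randn(1, 4, 10, 8)
    past_key_values = [(k, v)]

    new_cache_size = 3  # number of rejected draft positions to drop
    cropped = [(k[:, :, :-(new_cache_size)], v[:, :, :-(new_cache_size)])
               for k, v in past_key_values]
    print(cropped[0][0].shape)  # torch.Size([1, 4, 7, 8])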