Support prompt lookup in ipex-llm (#10768)

* lookup init

* add lookup

* fix style

* remove redundant code

* change param name

* fix style
Yina Chen 2024-04-16 16:52:38 +08:00 committed by GitHub
parent d30b22a81b
commit 899d392e2f
3 changed files with 449 additions and 82 deletions
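
For context, a minimal usage sketch of the feature this commit adds. The model id and the exact ipex-llm loading call are assumptions and not part of this diff; only the `lookahead` and `max_matching_ngram_size` keyword arguments come from the patched `generate` below.

# Hypothetical usage sketch (assumed model id and loading API; not taken from this commit).
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

inputs = tokenizer("A long prompt with repeated phrases ...", return_tensors="pt")
with torch.inference_mode():
    # `lookahead` routes generation through lookup_generate(); candidate tokens are
    # copied from the prompt itself instead of coming from a separate draft model.
    output = model.generate(inputs.input_ids,
                            max_new_tokens=64,
                            lookahead=2,
                            max_matching_ngram_size=2)
print(tokenizer.decode(output[0], skip_special_tokens=True))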


@@ -0,0 +1,319 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /candidate_generator.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
#
from typing import Callable, List, Optional, Tuple
import torch
import time
import copy
import logging
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError
logger = logging.getLogger("ipex_llm.lookup")
# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
lookahead = kwargs.pop("lookahead", None)
if lookahead:
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()
if self.device.type == "cpu" and _enable_ipex:
logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
"fallback to original generate.")
kwargs.pop("max_matching_ngram_size")
else:
# Do prompt lookup generation
return self.lookup_generate(inputs=inputs,
num_output_tokens=lookahead,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
**kwargs)
return original_generate(self,
inputs=inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs)
GenerationMixin.generate = generate
# This class is copied from https://github.com/huggingface/transformers/blob/main/src
# /transformers/generation/candidate_generator.py
class PromptLookupCandidateGenerator():
"""
    `CandidateGenerator` class to be used for prompt lookup generation. This class
    generates candidates by looking up likely continuations in the provided prompt
    (input_ids) itself.
Read the following blog post for more information:
https://github.com/apoorvumang/prompt-lookup-decoding
Args:
max_matching_ngram_size (`int`):
The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
The number of tokens to be output as candidate tokens.
"""
def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens")
    def get_candidates(self,
                       input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
                                                             Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
[What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`:
The candidate sequences to be tried.
"""
input_length = input_ids.size(1)
chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
# Convert ngram to a tensor for comparison
ngram_tensor = input_ids[0, -ngram_size:]
# Find where the windows match the ngram
matches = (windows == ngram_tensor).all(dim=2)
# Get the indices of matches
match_indices = matches.nonzero(as_tuple=True)[1]
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
match_found = True
break
if match_found:
break
if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match, return the input sequence unchanged,
            # which reverts back to autoregressive decoding
            return input_ids, None

        # Now we need to extend input_ids with chosen_ids
chosen_ids = chosen_ids.unsqueeze(0)
candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
# assisted_generation expects logits as well, but we don't have those here,
# so returning None
return candidate_input_ids, None
def update_candidate_strategy(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each
vocabulary when not using beam search or log softmax for each vocabulary
token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
# Currently does nothing
return
@torch.no_grad()
def lookup_generate(self,
inputs: Optional[torch.Tensor] = None,
max_new_tokens: int = 10,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
generation_config: Optional[GenerationConfig] = None,
attention_mask=None,
**sampling_kwargs):
input_ids, generation_config, logits_processor, stopping_criteria, \
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
**sampling_kwargs)
candidates_generator = PromptLookupCandidateGenerator(
num_output_tokens=num_output_tokens,
max_matching_ngram_size=max_matching_ngram_size)
step = 0
step_verify = 0
clear_benchmarks(self)
past_key_values = None
input_len = input_ids.shape[1]
while True:
if step >= max_new_tokens:
break
if step == 0:
# first token use full model
tic = time.time()
output = self(input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
return_dict=True,
use_cache=True)
logits = output['logits']
logits = logits[:, -1:]
logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
if generation_config.do_sample:
output_ids, prob_list = deepmind_sample(logits,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
else:
output_ids = greedy(logits)
input_ids = torch.cat((input_ids, output_ids), dim=-1)
past_key_values = output['past_key_values']
step += 1
if self.device.type == 'xpu':
torch.xpu.synchronize()
toc = time.time()
self.first_token_time = toc - tic
e2e_tic = time.time()
else:
cur_len = input_ids.shape[-1]
toc = time.time()
candidate_input_ids, _ = candidates_generator.get_candidates(input_ids=input_ids)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
verify_input_ids = candidate_input_ids[:, -candidate_length - 1:]
self.draft_num.append(candidate_length)
tic = time.time()
self.draft_time.append(tic - toc)
output = _non_cpu_ipex_verify(self, verify_input_ids, past_key_values,
attention_mask, return_dict=True, use_cache=True)
if isinstance(output, dict):
logits = output['logits']
past_key_values = output['past_key_values']
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i],
logits[:, i, :])
if generation_config.do_sample:
output_ids, prob_list = deepmind_sample(logits,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
else:
output_ids = greedy(logits)
if self.device.type == 'xpu':
torch.xpu.synchronize()
toc = time.time()
self.verify_time.append(toc - tic)
            # Compare drafts with the target model's verified outputs:
            # drafts sit at positions [1, k] of verify_input_ids, the corresponding
            # verified outputs at positions [0, k - 1] of output_ids, which also
            # includes the extra token generated by the base model itself
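            # e.g. with drafts [A, B, C] and model outputs [A, B, X, Y]: the cumsum of
            # mismatches stays 0 only for the first two positions, so max_matched = 2 + 1
            # (two accepted draft tokens plus the model's own next token X)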
max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
max_of_max_matched = output_ids.size(1)
            # Accepted token count is max_matched; the minimum is 1 (the base model's own token)
self.accept_num.append(max_matched)
self.n_matched += max_matched - 1
self.n_drafted += candidate_length
# Clean up target model KV cache
if max_of_max_matched != max_matched:
output_ids = output_ids[:, :max_matched]
new_cache_size = max_of_max_matched - max_matched
past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size)
input_ids = torch.cat((input_ids, output_ids), dim=-1)
step += output_ids.size(1)
step_verify += 1
# Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist()
if generation_config.eos_token_id in output_ids_list:
idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1)
break
step = min(step, max_new_tokens)
e2e_toc = time.time()
self.n_token_generated = step
self.e2e_time_without_first = e2e_toc - e2e_tic
return input_ids[:, : input_len + step]
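
To make the candidate-lookup step above concrete, here is a small standalone sketch of the same unfold-based n-gram matching on a toy tensor. It mirrors the logic of get_candidates rather than importing it, so it runs with torch alone; the toy token ids and sizes are illustrative assumptions.

# Standalone illustration of the prompt-lookup matching used in get_candidates().
import torch

input_ids = torch.tensor([[5, 6, 7, 8, 5, 6]])    # (batch_size=1, seq_len=6)
ngram_size, num_output_tokens = 2, 3
input_length = input_ids.size(1)

# Sliding windows of size ngram_size, compared against the prompt's trailing ngram.
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)   # shape (1, 5, 2)
ngram_tensor = input_ids[0, -ngram_size:]                          # tensor([5, 6])
matches = (windows == ngram_tensor).all(dim=2)
match_indices = matches.nonzero(as_tuple=True)[1]                  # tensor([0, 4])

# Take the continuation that follows the earliest match, capped at the prompt length.
idx = match_indices[0]
start_idx = idx + ngram_size
end_idx = min(start_idx + num_output_tokens, input_length)
chosen_ids = input_ids[0, start_idx:end_idx]
print(chosen_ids)   # tensor([7, 8, 5]) -- candidate tokens drafted from the prompt itself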


@@ -330,7 +330,8 @@ class _BaseAutoModelClass:
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
             if speculative:
-                from .speculative import speculative_generate, clear_benchmarks
+                from .speculative import speculative_generate, clear_benchmarks,\
+                    _crop_past_key_values
                 # load a sym_int4 model as draft model
                 draft_model = cls.load_convert('sym_int4', optimize_model, *args, **kwargs)
                 model.draft_model = draft_model
@@ -338,6 +339,12 @@ class _BaseAutoModelClass:
                 # add speculative_generate to pretrained model dynamically
                 model.clear_benchmarks = types.MethodType(clear_benchmarks, model)
                 model.speculative_generate = types.MethodType(speculative_generate, model)
+                model._crop_past_key_values = types.MethodType(_crop_past_key_values, model)
+
+            # add lookup_generate to pretrained model
+            from .lookup import lookup_generate
+            import types
+            model.lookup_generate = types.MethodType(lookup_generate, model)
         else:
             # load default
             model = cls.HF_Model.from_pretrained(*args, **kwargs)
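
The model.py change relies on types.MethodType to attach lookup_generate (and _crop_past_key_values) to an already-constructed model instance. Below is a minimal sketch of that binding pattern with a toy class, independent of ipex-llm; the class and function names are illustrative only.

# Minimal sketch of the MethodType binding pattern used above (toy class, not ipex-llm).
import types

def lookup_generate(self, prompt):
    # `self` is the bound instance, exactly as if this were defined on the class.
    return f"{self.name} would run prompt-lookup decoding on: {prompt!r}"

class ToyModel:
    def __init__(self, name):
        self.name = name

model = ToyModel("demo")
model.lookup_generate = types.MethodType(lookup_generate, model)   # per-instance method
print(model.lookup_generate("hello"))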


@@ -439,26 +439,49 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
     return past_key_values, not enough_kv_room
 
 
-@torch.no_grad()
-def speculative_generate(self,
-                         inputs: Optional[torch.Tensor] = None,
-                         draft_model=None,
-                         max_new_tokens=10,
-                         max_step_draft=8,
-                         th_stop_draft=0.8,
-                         auto_th_stop_draft=True,
-                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
-                         hf_adjust=False,
-                         min_step_draft=3,
-                         generation_config: Optional[GenerationConfig] = None,
-                         attention_mask=None,
-                         **sampling_kwargs):
-    invalidInputError(draft_model is not None,
-                      "Draft model should be provided.")
-    # min_step_draft >= 1. Since the max_step_draft may adjust,
-    # min_step_draft can > max_step_draft
-    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
+def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
+    if _enable_ipex:
+        cur_len = past_key_values[0][0].size(1)
+        delta = new_cache_size
+        tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
+                          dtype=torch.long).contiguous()
+        past_key_values = [[tmp, key_cache, value_cache, beam_idx]
+                           for _, key_cache, value_cache, beam_idx in past_key_values]
+    else:
+        if self.config.model_type in ["qwen"]:
+            past_key_values = [
+                (k[:, :-(new_cache_size), :],
+                 v[:, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type == "chatglm":
+            # for chatglm, cache shape is [sl, bs, nh, hn]
+            past_key_values = [
+                (k[:-(new_cache_size), :, :, :],
+                 v[:-(new_cache_size), :, :, :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type in ["baichuan", "gptj"]:
+            past_key_values = [
+                (k[:, :, :-(new_cache_size), :],
+                 v[:, :, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        elif self.config.model_type == "gpt_bigcode":
+            past_key_values = [
+                kv[:, :-(new_cache_size)]
+                for kv in past_key_values
+            ]
+        else:
+            past_key_values = [
+                (k[:, :, :-(new_cache_size)],
+                 v[:, :, :-(new_cache_size)])
+                for k, v in past_key_values
+            ]
+    return past_key_values
+
+
+def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
     if generation_config is None:
         generation_config = self.generation_config
@@ -494,10 +517,27 @@ def speculative_generate(self,
     inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
         inputs, generation_config.bos_token_id, model_kwargs
     )
+    batch_size = inputs_tensor.shape[0]
 
     # 4. Define other model kwargs
-    # Removed not used
+    # model_kwargs["output_attentions"] = generation_config.output_attentions
+    # model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+    # # decoder-only models with inputs_embeds forwarding must use caching
+    # # (otherwise we can't detect whether we are generating the first new token or not,
+    # # and we only want to use the embeddings for the first new token)
+    # if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+    #     model_kwargs["use_cache"] = True
+    # else:
+    #     model_kwargs["use_cache"] = generation_config.use_cache
+    # accepts_attention_mask = "attention_mask" in set(
+    #     inspect.signature(self.forward).parameters.keys())
+    # requires_attention_mask = "encoder_outputs" not in model_kwargs
+    # if model_kwargs.get("attention_mask", None) is None and \
+    #         requires_attention_mask and accepts_attention_mask:
+    #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+    #         inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+    #     )
 
     # decoder-only models should use left-padding for generation
     if not self.config.is_encoder_decoder:
@@ -543,6 +583,61 @@ def speculative_generate(self,
         **model_kwargs,
     )
 
+    return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
+
+
+def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
+                         return_dict=True, use_cache=True):
+    forward_args = {
+        "input_ids": verify_input_ids,
+        "past_key_values": past_key_values,
+        "return_dict": return_dict,
+        "use_cache": use_cache,
+    }
+    if cur_attention_mask:
+        forward_args["attention_mask"] = cur_attention_mask
+
+    if self.config.model_type == "chatglm":
+        past_key_value_len = past_key_values[0][0].shape[0]
+        position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
+                                    device=verify_input_ids.device)
+        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
+        forward_args["position_ids"] = position_ids
+    elif self.config.model_type == "gptj":
+        past_length = past_key_values[0][0].size(2)
+        input_len = verify_input_ids.shape[1]
+        position_ids = torch.arange(past_length, input_len + past_length,
+                                    dtype=torch.long, device=verify_input_ids.device)
+        position_ids = position_ids.unsqueeze(0).view(-1, input_len)
+        forward_args["position_ids"] = position_ids
+
+    return self(**forward_args)
+
+
+@torch.no_grad()
+def speculative_generate(self,
+                         inputs: Optional[torch.Tensor] = None,
+                         draft_model=None,
+                         max_new_tokens=10,
+                         max_step_draft=8,
+                         th_stop_draft=0.8,
+                         auto_th_stop_draft=True,
+                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
+                         hf_adjust=False,
+                         min_step_draft=3,
+                         generation_config: Optional[GenerationConfig] = None,
+                         attention_mask=None,
+                         **sampling_kwargs):
+    invalidInputError(draft_model is not None,
+                      "Draft model should be provided.")
+    # min_step_draft >= 1. Since the max_step_draft may adjust,
+    # min_step_draft can > max_step_draft
+    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
+
+    input_ids, generation_config, logits_processor, stopping_criteria, \
+        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                              **sampling_kwargs)
+
     step = 0
     step_draft = 0
     step_verify = 0
@@ -851,27 +946,8 @@ def speculative_generate(self,
                     logits = output[0]
                     past_key_values = output[1]
                 else:
-                    forward_args = {
-                        "input_ids": drafted_input_ids,
-                        "past_key_values": past_key_values,
-                        "attention_mask": cur_attention_mask,
-                        "return_dict": True,
-                        "use_cache": True,
-                    }
-                    if self.config.model_type == "chatglm":
-                        past_key_value_len = past_key_values[0][0].shape[0]
-                        position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
-                                                    device=drafted_input_ids.device)
-                        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
-                        forward_args["position_ids"] = position_ids
-                    elif self.config.model_type == "gptj":
-                        past_length = past_key_values[0][0].size(2)
-                        input_len = drafted_input_ids.shape[1]
-                        position_ids = torch.arange(past_length, input_len + past_length,
-                                                    dtype=torch.long, device=drafted_input_ids.device)
-                        position_ids = position_ids.unsqueeze(0).view(-1, input_len)
-                        forward_args["position_ids"] = position_ids
-                    output = self(**forward_args)
+                    output = _non_cpu_ipex_verify(self, drafted_input_ids, past_key_values,
+                                                  cur_attention_mask, return_dict=True, use_cache=True)
                 if isinstance(output, dict):
                     logits = output['logits']
                     past_key_values = output['past_key_values']
@@ -939,45 +1015,10 @@ def speculative_generate(self,
                 # Clean up target model KV cache
                 if max_of_max_matched != max_matched:
                     output_ids = output_ids[:, :max_matched]
-                    if _enable_ipex:
-                        cur_len = past_key_values[0][0].size(1)
-                        delta = max_of_max_matched - max_matched
-                        tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
-                                          dtype=torch.long,
-                                          ).contiguous()
-                        past_key_values = [[tmp, key_cache, value_cache, beam_idx]
-                                           for _, key_cache, value_cache, beam_idx in past_key_values]
-                    else:
-                        if self.config.model_type in ["qwen"]:
-                            past_key_values = [
-                                (k[:, :-(max_of_max_matched - max_matched), :],
-                                 v[:, :-(max_of_max_matched - max_matched), :])
-                                for k, v in past_key_values
-                            ]
-                        elif self.config.model_type == "chatglm":
-                            # for chatglm, cache shape is [sl, bs, nh, hn]
-                            past_key_values = [
-                                (k[:-(max_of_max_matched - max_matched), :, :, :],
-                                 v[:-(max_of_max_matched - max_matched), :, :, :])
-                                for k, v in past_key_values
-                            ]
-                        elif self.config.model_type in ["baichuan", "gptj"]:
-                            past_key_values = [
-                                (k[:, :, :-(max_of_max_matched - max_matched), :],
-                                 v[:, :, :-(max_of_max_matched - max_matched), :])
-                                for k, v in past_key_values
-                            ]
-                        elif self.config.model_type == "gpt_bigcode":
-                            past_key_values = [
-                                kv[:, :-(max_of_max_matched - max_matched)]
-                                for kv in past_key_values
-                            ]
-                        else:
-                            past_key_values = [
-                                (k[:, :, :-(max_of_max_matched - max_matched)],
-                                 v[:, :, :-(max_of_max_matched - max_matched)])
-                                for k, v in past_key_values
-                            ]
+                    new_cache_size = max_of_max_matched - max_matched
+                    past_key_values = self._crop_past_key_values(past_key_values,
+                                                                 new_cache_size,
+                                                                 _enable_ipex)
 
                 # Each iter assign new_matched kv_cache to past_key_values1
                 if self.device.type == 'cpu' and (not _enable_ipex):
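
Finally, a small sketch of what the default branch of _crop_past_key_values does to a standard (batch, num_heads, seq_len, head_dim) cache: slicing off the last new_cache_size positions so rejected draft tokens are dropped from the target model's KV cache. The shapes below are toy values, not tied to any particular model.

# Toy demonstration of the default KV-cache cropping branch.
import torch

batch, num_heads, seq_len, head_dim = 1, 8, 20, 64
new_cache_size = 3   # number of rejected draft positions to drop

past_key_values = [(torch.randn(batch, num_heads, seq_len, head_dim),
                    torch.randn(batch, num_heads, seq_len, head_dim))
                   for _ in range(2)]          # pretend the model has 2 decoder layers

cropped = [(k[:, :, :-new_cache_size], v[:, :, :-new_cache_size])
           for k, v in past_key_values]
print(cropped[0][0].shape)   # torch.Size([1, 8, 17, 64])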