diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
new file mode 100644
index 00000000..f26b26e9
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -0,0 +1,319 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /candidate_generator.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
+# /utils.py
+#
+
+from typing import Callable, List, Optional, Tuple
+import torch
+import time
+import copy
+import logging
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
+    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
+from ipex_llm.utils.common import invalidInputError
+
+logger = logging.getLogger("ipex_llm.lookup")
+
+# patch GenerationMixin.generate to dispatch to prompt lookup when requested
+from transformers import GenerationMixin
+original_generate = GenerationMixin.generate
+query_group_size = 16
+
+
+@torch.no_grad()
+def generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    synced_gpus: Optional[bool] = None,
+    assistant_model: Optional["PreTrainedModel"] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    lookahead = kwargs.pop("lookahead", None)
+    if lookahead:
+        from ipex_llm.transformers.convert import get_enable_ipex
+        _enable_ipex = get_enable_ipex()
+
+        if self.device.type == "cpu" and _enable_ipex:
+            logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
+                           "falling back to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
+        else:
+            # Do prompt lookup generation
+            return self.lookup_generate(inputs=inputs,
+                                        num_output_tokens=lookahead,
+                                        generation_config=generation_config,
+                                        logits_processor=logits_processor,
+                                        stopping_criteria=stopping_criteria,
+                                        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                        **kwargs)
+
+    return original_generate(self,
+                             inputs=inputs,
+                             generation_config=generation_config,
+                             logits_processor=logits_processor,
+                             stopping_criteria=stopping_criteria,
+                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                             synced_gpus=synced_gpus,
+                             assistant_model=assistant_model,
+                             streamer=streamer,
+                             **kwargs)
+
+GenerationMixin.generate = generate
+
+
+# This class is copied from https://github.com/huggingface/transformers/blob/main/src
+# /transformers/generation/candidate_generator.py
+class PromptLookupCandidateGenerator():
+    """
+    `CandidateGenerator` class to be used for prompt lookup generation.
+    This class generates candidates by looking up likely continuations in the
+    provided prompt (input_ids) itself.
+    Read the following blog post for more information:
+    https://github.com/apoorvumang/prompt-lookup-decoding
+
+    Args:
+        max_matching_ngram_size (`int`):
+            The maximum ngram size to be considered for matching in the prompt
+        num_output_tokens (`int`):
+            The number of tokens to be output as candidate tokens.
+    """
+
+    def __init__(
+        self,
+        num_output_tokens: int = 10,
+        max_matching_ngram_size: int = None,
+    ):
+        self.num_output_tokens = num_output_tokens
+        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+
+        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
+                          "Invalid max_matching_ngram_size or num_output_tokens")
+
+    def get_candidates(self,
+                       input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
+                                                             Optional[torch.FloatTensor]]:
+        """
+        Fetches the candidates to be tried for the current input.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+                [What are input IDs?](../glossary#input-ids)
+
+        Return:
+            `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
+                The candidate sequences to be tried.
+        """
+        input_length = input_ids.size(1)
+
+        chosen_ids = None
+        match_found = False
+        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+
+            # Convert ngram to a tensor for comparison
+            ngram_tensor = input_ids[0, -ngram_size:]
+
+            # Find where the windows match the ngram
+            matches = (windows == ngram_tensor).all(dim=2)
+
+            # Get the indices of matches
+            match_indices = matches.nonzero(as_tuple=True)[1]
+
+            # Iterate through match indices to find a valid continuation
+            for idx in match_indices:
+                start_idx = idx + ngram_size
+                end_idx = start_idx + self.num_output_tokens
+                end_idx = min(end_idx, input_length)
+
+                if start_idx < end_idx:
+                    chosen_ids = input_ids[0, start_idx:end_idx]
+                    match_found = True
+                    break
+            if match_found:
+                break
+
+        if chosen_ids is None or len(chosen_ids) == 0:
+            # In case we didn't find a match, return the input sequence unchanged,
+            # which reverts back to autoregressive decoding
+            return input_ids, None
+
+        # Now we need to extend input_ids with chosen_ids
+        chosen_ids = chosen_ids.unsqueeze(0)
+        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
+        # assisted_generation expects logits as well, but we don't have those here,
+        # so returning None
+        return candidate_input_ids, None
+
+    def update_candidate_strategy(self, input_ids: torch.LongTensor,
+                                  scores: torch.FloatTensor, num_matches: int):
+        """
+        Updates the candidate generation strategy based on the outcomes.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+                [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
+                config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits for
+                each vocabulary token when not using beam search, or log softmax for
+                each vocabulary token when using beam search.
+            num_matches (`int`):
+                The number of matches between the candidate sequences and the model predictions.
+ """ + # Currently does nothing + return + + +@torch.no_grad() +def lookup_generate(self, + inputs: Optional[torch.Tensor] = None, + max_new_tokens: int = 10, + num_output_tokens: int = 10, + max_matching_ngram_size: int = None, + generation_config: Optional[GenerationConfig] = None, + attention_mask=None, + **sampling_kwargs): + input_ids, generation_config, logits_processor, stopping_criteria, \ + model_kwargs = _prepare_generate_args(self, inputs, generation_config, + **sampling_kwargs) + + candidates_generator = PromptLookupCandidateGenerator( + num_output_tokens=num_output_tokens, + max_matching_ngram_size=max_matching_ngram_size) + + step = 0 + step_verify = 0 + + clear_benchmarks(self) + + past_key_values = None + input_len = input_ids.shape[1] + + while True: + if step >= max_new_tokens: + break + + if step == 0: + # first token use full model + tic = time.time() + output = self(input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + return_dict=True, + use_cache=True) + logits = output['logits'] + logits = logits[:, -1:] + logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :]) + if generation_config.do_sample: + output_ids, prob_list = deepmind_sample(logits, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + temperature=generation_config.temperature) + else: + output_ids = greedy(logits) + input_ids = torch.cat((input_ids, output_ids), dim=-1) + past_key_values = output['past_key_values'] + step += 1 + if self.device.type == 'xpu': + torch.xpu.synchronize() + toc = time.time() + self.first_token_time = toc - tic + e2e_tic = time.time() + else: + cur_len = input_ids.shape[-1] + toc = time.time() + candidate_input_ids, _ = candidates_generator.get_candidates(input_ids=input_ids) + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + verify_input_ids = candidate_input_ids[:, -candidate_length - 1:] + self.draft_num.append(candidate_length) + tic = time.time() + self.draft_time.append(tic - toc) + output = _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, + attention_mask, return_dict=True, use_cache=True) + if isinstance(output, dict): + logits = output['logits'] + past_key_values = output['past_key_values'] + + if len(logits_processor) > 0: + for i in range(candidate_length + 1): + logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], + logits[:, i, :]) + + if generation_config.do_sample: + output_ids, prob_list = deepmind_sample(logits, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + temperature=generation_config.temperature) + else: + output_ids = greedy(logits) + + if self.device.type == 'xpu': + torch.xpu.synchronize() + toc = time.time() + self.verify_time.append(toc - tic) + + # Compare drafts with target verified outputs + # Drafts start from [1, k] + # Verified output start from [0, k - 1] + # including the one generated by the base model + max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0) + max_matched = max_matched.sum(-1).item() + 1 + + max_of_max_matched = output_ids.size(1) + # Accept number is max_matched, min is 1 + self.accept_num.append(max_matched) + self.n_matched += max_matched - 1 + self.n_drafted += candidate_length + + # Clean up target model KV cache + if max_of_max_matched != max_matched: + output_ids = output_ids[:, :max_matched] + new_cache_size = max_of_max_matched - max_matched + past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size) + + input_ids = 
torch.cat((input_ids, output_ids), dim=-1) + + step += output_ids.size(1) + step_verify += 1 + + # Stop on eos and remove content after eos + output_ids_list = output_ids[0].tolist() + if generation_config.eos_token_id in output_ids_list: + idx = output_ids_list.index(generation_config.eos_token_id) + step -= (len(output_ids_list) - idx - 1) + break + + step = min(step, max_new_tokens) + e2e_toc = time.time() + self.n_token_generated = step + self.e2e_time_without_first = e2e_toc - e2e_tic + + return input_ids[:, : input_len + step] diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index 44b2e0ad..13d7d701 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -330,7 +330,8 @@ class _BaseAutoModelClass: model = cls.load_convert(q_k, optimize_model, *args, **kwargs) if speculative: - from .speculative import speculative_generate, clear_benchmarks + from .speculative import speculative_generate, clear_benchmarks,\ + _crop_past_key_values # load a sym_int4 model as draft model draft_model = cls.load_convert('sym_int4', optimize_model, *args, **kwargs) model.draft_model = draft_model @@ -338,6 +339,12 @@ class _BaseAutoModelClass: # add speculative_generate to pretrained model dynamically model.clear_benchmarks = types.MethodType(clear_benchmarks, model) model.speculative_generate = types.MethodType(speculative_generate, model) + model._crop_past_key_values = types.MethodType(_crop_past_key_values, model) + + # add lookup_generate to pretrained model + from .lookup import lookup_generate + import types + model.lookup_generate = types.MethodType(lookup_generate, model) else: # load default model = cls.HF_Model.from_pretrained(*args, **kwargs) diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index cc9f627e..3d75e5aa 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -439,26 +439,49 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l return past_key_values, not enough_kv_room -@torch.no_grad() -def speculative_generate(self, - inputs: Optional[torch.Tensor] = None, - draft_model=None, - max_new_tokens=10, - max_step_draft=8, - th_stop_draft=0.8, - auto_th_stop_draft=True, - auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9], - hf_adjust=False, - min_step_draft=3, - generation_config: Optional[GenerationConfig] = None, - attention_mask=None, - **sampling_kwargs): - invalidInputError(draft_model is not None, - "Draft model should be provided.") - # min_step_draft >= 1. 
Since the max_step_draft may adjust, - # min_step_draft can > max_step_draft - min_step_draft = min_step_draft if min_step_draft >= 1 else 1 +def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False): + if _enable_ipex: + cur_len = past_key_values[0][0].size(1) + delta = new_cache_size + tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1, + dtype=torch.long).contiguous() + past_key_values = [[tmp, key_cache, value_cache, beam_idx] + for _, key_cache, value_cache, beam_idx in past_key_values] + else: + if self.config.model_type in ["qwen"]: + past_key_values = [ + (k[:, :-(new_cache_size), :], + v[:, :-(new_cache_size), :]) + for k, v in past_key_values + ] + elif self.config.model_type == "chatglm": + # for chatglm, cache shape is [sl, bs, nh, hn] + past_key_values = [ + (k[:-(new_cache_size), :, :, :], + v[:-(new_cache_size), :, :, :]) + for k, v in past_key_values + ] + elif self.config.model_type in ["baichuan", "gptj"]: + past_key_values = [ + (k[:, :, :-(new_cache_size), :], + v[:, :, :-(new_cache_size), :]) + for k, v in past_key_values + ] + elif self.config.model_type == "gpt_bigcode": + past_key_values = [ + kv[:, :-(new_cache_size)] + for kv in past_key_values + ] + else: + past_key_values = [ + (k[:, :, :-(new_cache_size)], + v[:, :, :-(new_cache_size)]) + for k, v in past_key_values + ] + return past_key_values + +def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs): if generation_config is None: generation_config = self.generation_config @@ -494,10 +517,27 @@ def speculative_generate(self, inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) - batch_size = inputs_tensor.shape[0] # 4. Define other model kwargs - # Removed not used + # model_kwargs["output_attentions"] = generation_config.output_attentions + # model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + # # decoder-only models with inputs_embeds forwarding must use caching + # # (otherwise we can't detect whether we are generating the first new token or not, + # # and we only want to use the embeddings for the first new token) + # if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + # model_kwargs["use_cache"] = True + # else: + # model_kwargs["use_cache"] = generation_config.use_cache + + # accepts_attention_mask = "attention_mask" in set( + # inspect.signature(self.forward).parameters.keys()) + # requires_attention_mask = "encoder_outputs" not in model_kwargs + + # if model_kwargs.get("attention_mask", None) is None and \ + # requires_attention_mask and accepts_attention_mask: + # model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + # inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + # ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: @@ -543,6 +583,61 @@ def speculative_generate(self, **model_kwargs, ) + return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs + + +def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None, + return_dict=True, use_cache=True): + forward_args = { + "input_ids": verify_input_ids, + "past_key_values": past_key_values, + "return_dict": return_dict, + "use_cache": use_cache, + } + if cur_attention_mask: + forward_args["attention_mask"] = cur_attention_mask + + if self.config.model_type == "chatglm": + past_key_value_len = 
past_key_values[0][0].shape[0] + position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long, + device=verify_input_ids.device) + position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len + forward_args["position_ids"] = position_ids + elif self.config.model_type == "gptj": + past_length = past_key_values[0][0].size(2) + input_len = verify_input_ids.shape[1] + position_ids = torch.arange(past_length, input_len + past_length, + dtype=torch.long, device=verify_input_ids.device) + position_ids = position_ids.unsqueeze(0).view(-1, input_len) + forward_args["position_ids"] = position_ids + + return self(**forward_args) + + +@torch.no_grad() +def speculative_generate(self, + inputs: Optional[torch.Tensor] = None, + draft_model=None, + max_new_tokens=10, + max_step_draft=8, + th_stop_draft=0.8, + auto_th_stop_draft=True, + auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9], + hf_adjust=False, + min_step_draft=3, + generation_config: Optional[GenerationConfig] = None, + attention_mask=None, + **sampling_kwargs): + invalidInputError(draft_model is not None, + "Draft model should be provided.") + # min_step_draft >= 1. Since the max_step_draft may adjust, + # min_step_draft can > max_step_draft + min_step_draft = min_step_draft if min_step_draft >= 1 else 1 + + input_ids, generation_config, logits_processor, stopping_criteria, \ + model_kwargs = _prepare_generate_args(self, inputs, generation_config, + **sampling_kwargs) + step = 0 step_draft = 0 step_verify = 0 @@ -851,27 +946,8 @@ def speculative_generate(self, logits = output[0] past_key_values = output[1] else: - forward_args = { - "input_ids": drafted_input_ids, - "past_key_values": past_key_values, - "attention_mask": cur_attention_mask, - "return_dict": True, - "use_cache": True, - } - if self.config.model_type == "chatglm": - past_key_value_len = past_key_values[0][0].shape[0] - position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long, - device=drafted_input_ids.device) - position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len - forward_args["position_ids"] = position_ids - elif self.config.model_type == "gptj": - past_length = past_key_values[0][0].size(2) - input_len = drafted_input_ids.shape[1] - position_ids = torch.arange(past_length, input_len + past_length, - dtype=torch.long, device=drafted_input_ids.device) - position_ids = position_ids.unsqueeze(0).view(-1, input_len) - forward_args["position_ids"] = position_ids - output = self(**forward_args) + output = _non_cpu_ipex_verify(self, drafted_input_ids, past_key_values, + cur_attention_mask, return_dict=True, use_cache=True) if isinstance(output, dict): logits = output['logits'] past_key_values = output['past_key_values'] @@ -939,45 +1015,10 @@ def speculative_generate(self, # Clean up target model KV cache if max_of_max_matched != max_matched: output_ids = output_ids[:, :max_matched] - if _enable_ipex: - cur_len = past_key_values[0][0].size(1) - delta = max_of_max_matched - max_matched - tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1, - dtype=torch.long, - ).contiguous() - past_key_values = [[tmp, key_cache, value_cache, beam_idx] - for _, key_cache, value_cache, beam_idx in past_key_values] - else: - if self.config.model_type in ["qwen"]: - past_key_values = [ - (k[:, :-(max_of_max_matched - max_matched), :], - v[:, :-(max_of_max_matched - max_matched), :]) - for k, v in past_key_values - ] - elif self.config.model_type == "chatglm": - # for chatglm, cache shape is [sl, bs, nh, hn] - past_key_values = [ - 
(k[:-(max_of_max_matched - max_matched), :, :, :], - v[:-(max_of_max_matched - max_matched), :, :, :]) - for k, v in past_key_values - ] - elif self.config.model_type in ["baichuan", "gptj"]: - past_key_values = [ - (k[:, :, :-(max_of_max_matched - max_matched), :], - v[:, :, :-(max_of_max_matched - max_matched), :]) - for k, v in past_key_values - ] - elif self.config.model_type == "gpt_bigcode": - past_key_values = [ - kv[:, :-(max_of_max_matched - max_matched)] - for kv in past_key_values - ] - else: - past_key_values = [ - (k[:, :, :-(max_of_max_matched - max_matched)], - v[:, :, :-(max_of_max_matched - max_matched)]) - for k, v in past_key_values - ] + new_cache_size = max_of_max_matched - max_matched + past_key_values = self._crop_past_key_values(past_key_values, + new_cache_size, + _enable_ipex) # Each iter assign new_matched kv_cache to past_key_values1 if self.device.type == 'cpu' and (not _enable_ipex):
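
Reviewer note (not part of the diff): the two tensor tricks this patch relies on are easy to sanity-check in isolation — the n-gram prompt lookup in `PromptLookupCandidateGenerator.get_candidates` and the cumsum-based acceptance count in `lookup_generate`. A minimal sketch with made-up toy token ids:

```python
import torch

# --- prompt lookup (mirrors get_candidates) ---
# The trailing ngram [5, 6] also occurs earlier in the prompt, so the tokens that
# followed it there ([7, 8, 9]) become the drafted continuation.
input_ids = torch.tensor([[1, 2, 5, 6, 7, 8, 9, 3, 5, 6]])
max_matching_ngram_size, num_output_tokens = 2, 3
input_length = input_ids.size(1)

chosen_ids = None
for ngram_size in range(min(max_matching_ngram_size, input_length - 1), 0, -1):
    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # all sliding ngrams
    ngram_tensor = input_ids[0, -ngram_size:]                         # trailing ngram
    matches = (windows == ngram_tensor).all(dim=2)
    for idx in matches.nonzero(as_tuple=True)[1]:
        start_idx = idx + ngram_size
        end_idx = min(start_idx + num_output_tokens, input_length)
        if start_idx < end_idx:
            chosen_ids = input_ids[0, start_idx:end_idx]
            break
    if chosen_ids is not None:
        break
print(chosen_ids)  # tensor([7, 8, 9])

# --- acceptance count (mirrors the verification step in lookup_generate) ---
# verify_input_ids = last accepted token + drafted tokens; output_ids = what the target
# model produced at those positions (here it agrees with the first two drafts only).
verify_input_ids = torch.tensor([[4, 7, 8, 9]])
output_ids = torch.tensor([[7, 8, 5, 2]])
max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
print(max_matched)                  # 3 -> 2 accepted drafts + 1 token from the target model
print(output_ids[:, :max_matched])  # tensor([[7, 8, 5]]) is what gets appended to input_ids
```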
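
A second note on the `_crop_past_key_values` refactor: for the default cache layout `(batch, num_heads, seq_len, head_dim)` the helper simply slices the last `new_cache_size` positions off every key/value tensor, and the `model_type` branches only cover layouts that keep `seq_len` on a different dim (e.g. chatglm's `[sl, bs, nh, hn]`). A toy shape check, with made-up dimensions:

```python
import torch

# 2 layers of a fake cache in the default layout (batch=1, num_heads=4, seq_len=16, head_dim=8)
past_key_values = [(torch.randn(1, 4, 16, 8), torch.randn(1, 4, 16, 8)) for _ in range(2)]
new_cache_size = 3  # number of rejected draft positions to drop

cropped = [(k[:, :, :-new_cache_size], v[:, :, :-new_cache_size])
           for k, v in past_key_values]
print(cropped[0][0].shape)  # torch.Size([1, 4, 13, 8]) -- 16 - 3 positions kept
```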
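
Usage-wise, the entry point added here is the `lookahead` kwarg on the patched `GenerationMixin.generate`, which is forwarded to `lookup_generate` as `num_output_tokens`; `max_matching_ngram_size` is an optional tuning knob (default 2). A rough, hypothetical sketch of how a caller might drive it — the checkpoint name is a placeholder, and note that in this diff `lookup_generate` is only bound onto the model in the `speculative=True` branch of `model.py`, so that flag (or a follow-up change relaxing it) is assumed below:

```python
# Hypothetical usage sketch, not part of this PR; checkpoint path is a placeholder.
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
# speculative=True is what currently attaches lookup_generate (see the model.py hunk above)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             speculative=True)

inputs = tokenizer("A long prompt whose phrases tend to repeat ...", return_tensors="pt")
# lookahead = number of drafted tokens per verification step (becomes num_output_tokens);
# max_matching_ngram_size bounds the trailing ngram used to search the prompt.
output = model.generate(inputs.input_ids,
                        max_new_tokens=128,
                        lookahead=8,
                        max_matching_ngram_size=3)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```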