#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /candidate_generator.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
#

from typing import Callable, List, Optional, Tuple
import torch
import time
import copy
import logging

from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.utils import get_xpu_device_type

logger = logging.getLogger("ipex_llm.lookup")

# patch GenerationMixin.generate from transformers
from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    lookahead = kwargs.pop("lookahead", None)
    if lookahead:
        from ipex_llm.transformers.convert import get_enable_ipex
        _enable_ipex = get_enable_ipex()

        if self.device.type == "cpu" and _enable_ipex:
            logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                           "fallback to original generate.")
            # drop the lookup-only argument before falling back (it may be absent)
            kwargs.pop("max_matching_ngram_size", None)
        else:
            # Do prompt lookup generation.
            # If lookahead is provided, we will use lookup_generate instead of
            # spec_generate; remove the spec_generate-only kwargs and warn the user.
            spec_params = []
            for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
                        'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
                        'th_batch_num']:
                value = kwargs.pop(var, None)
                if value is not None:
                    spec_params.append(var)
            if len(spec_params) > 0:
                logger.warning("Since you called generate with the lookahead parameter, "
                               f"speculative decoding parameters {spec_params} are "
                               "removed from the generation.")
            return self.lookup_generate(inputs=inputs,
                                        num_output_tokens=lookahead,
                                        generation_config=generation_config,
                                        logits_processor=logits_processor,
                                        stopping_criteria=stopping_criteria,
                                        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                        **kwargs)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)


GenerationMixin.generate = generate
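

# A minimal usage sketch of the patched ``generate`` above. It is illustrative only and is not
# called anywhere in this module; the model path, prompt, and device ("xpu") are placeholder
# assumptions, not requirements of the patch itself.
def _example_generate_with_lookup(model_path: str, prompt: str) -> str:
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                                 trust_remote_code=True).to('xpu')
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('xpu')
    # Passing ``lookahead`` routes generation through ``lookup_generate`` below;
    # the remaining keyword arguments are forwarded unchanged.
    output = model.generate(input_ids, lookahead=2, max_new_tokens=128, do_sample=False)
    return tokenizer.decode(output[0], skip_special_tokens=True)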


# This class is copied from https://github.com/huggingface/transformers/blob/main/src
# /transformers/generation/candidate_generator.py
class PromptLookupCandidateGenerator():
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates
    candidates by looking up likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information:
    https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt.
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: Optional[int] = None,
        device: str = "arc",
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
        if device == "mtl":
            self.max_candidates = 3
        else:
            self.max_candidates = 9

        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                          "Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self,
                       input_ids: torch.LongTensor
                       ) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
                The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break

            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # No match was found: return the input sequence unchanged,
            # which falls back to autoregressive decoding
            return input_ids, None

        # Now we need to extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here,
        # so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
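            candidate_num (`int`):
                The number of candidate tokens that were proposed in this round.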
""" if num_matches == self.num_output_tokens: self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates) elif candidate_num > num_matches: self.num_output_tokens = max(self.num_output_tokens - 1, 1) @torch.no_grad() def lookup_generate(self, inputs: Optional[torch.Tensor] = None, max_new_tokens: int = 10, num_output_tokens: int = 10, max_matching_ngram_size: int = None, generation_config: Optional[GenerationConfig] = None, attention_mask=None, **sampling_kwargs): input_ids, generation_config, logits_processor, stopping_criteria, \ model_kwargs = _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs) device_name = get_xpu_device_type(input_ids) candidates_generator = PromptLookupCandidateGenerator( num_output_tokens=num_output_tokens, max_matching_ngram_size=max_matching_ngram_size, device=device_name) step = 0 step_verify = 0 clear_benchmarks(self) past_key_values = None input_len = input_ids.shape[1] while True: if step >= max_new_tokens: break if step == 0: # first token use full model tic = time.time() output = self(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, return_dict=True, use_cache=True) logits = output['logits'] logits = logits[:, -1:] logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :]) if generation_config.do_sample: output_ids, prob_list = deepmind_sample(logits, top_k=generation_config.top_k, top_p=generation_config.top_p, temperature=generation_config.temperature) else: output_ids = greedy(logits) input_ids = torch.cat((input_ids, output_ids), dim=-1) past_key_values = output['past_key_values'] step += 1 if self.device.type == 'xpu': torch.xpu.synchronize() toc = time.time() self.first_token_time = toc - tic e2e_tic = time.time() else: cur_len = input_ids.shape[-1] toc = time.time() candidate_input_ids, _ = candidates_generator.get_candidates(input_ids=input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] verify_input_ids = candidate_input_ids[:, -candidate_length - 1:] self.draft_num.append(candidate_length) tic = time.time() self.draft_time.append(tic - toc) if attention_mask is None: cur_attention_mask = None else: appended_len = verify_input_ids.size(1) + step - 1 ones_to_append = torch.ones(attention_mask.size(0), appended_len, device=self.device) cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1) output = _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask, return_dict=True, use_cache=True) if isinstance(output, dict): logits = output['logits'] past_key_values = output['past_key_values'] if len(logits_processor) > 0: for i in range(candidate_length + 1): logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], logits[:, i, :]) if generation_config.do_sample: output_ids, prob_list = deepmind_sample(logits, top_k=generation_config.top_k, top_p=generation_config.top_p, temperature=generation_config.temperature) output_ids = output_ids.transpose(0, 1) else: output_ids = greedy(logits) if self.device.type == 'xpu': torch.xpu.synchronize() toc = time.time() self.verify_time.append(toc - tic) # Compare drafts with target verified outputs # Drafts start from [1, k] # Verified output start from [0, k - 1] # including the one generated by the base model n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:]) .cumsum(-1) == 0).sum(-1).item() max_matched = n_matches + 1 max_of_max_matched = output_ids.size(1) # Accept number is max_matched, min is 1 self.accept_num.append(max_matched) 
            self.n_matched += n_matches
            self.n_drafted += candidate_length

            # Clean up target model KV cache
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                new_cache_size = max_of_max_matched - max_matched
                past_key_values = _crop_past_key_values(self, past_key_values,
                                                        new_cache_size)

            # Update the candidate generation strategy if needed
            candidates_generator.update_candidate_strategy(candidate_length, n_matches)

            input_ids = torch.cat((input_ids, output_ids), dim=-1)

            step += output_ids.size(1)
            step_verify += 1

        # Stop on eos and remove content after eos
        output_ids_list = output_ids[0].tolist()
        if generation_config.eos_token_id in output_ids_list:
            idx = output_ids_list.index(generation_config.eos_token_id)
            step -= (len(output_ids_list) - idx - 1)
            break

    step = min(step, max_new_tokens)
    e2e_toc = time.time()
    self.n_token_generated = step
    self.e2e_time_without_first = e2e_toc - e2e_tic

    return input_ids[:, : input_len + step]
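

# Illustrative sketch of the n-gram lookup that PromptLookupCandidateGenerator.get_candidates
# performs (standalone and unused by the rest of the module; the token values are arbitrary):
# the trailing 2-gram [3, 4] also occurs earlier in the prompt, so the tokens that followed
# that earlier occurrence are proposed as draft candidates for the verification step.
def _ngram_lookup_demo() -> torch.Tensor:
    input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 2, 3, 4]])
    generator = PromptLookupCandidateGenerator(num_output_tokens=3,
                                               max_matching_ngram_size=2)
    candidate_input_ids, _ = generator.get_candidates(input_ids)
    # The earlier [3, 4] at positions 2-3 is followed by [5, 6, 7], so those tokens are
    # appended: tensor([[1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7]])
    return candidate_input_ids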