Support prompt lookup in ipex-llm (#10768)
* lookup init
* add lookup
* fix style
* remove redundant code
* change param name
* fix style
parent d30b22a81b
commit 899d392e2f

3 changed files with 449 additions and 82 deletions
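The new generate wrapper in lookup.py below routes any call that passes a lookahead value to lookup_generate; without it, the call falls through to the original transformers generate. A minimal usage sketch, not part of this diff (the model id, prompt, and loading arguments are illustrative placeholders):

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Model id and generation settings are placeholders; loading through ipex_llm's
# from_pretrained is what attaches lookup_generate to the returned model.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

input_ids = tokenizer("Summarize the following document: ...", return_tensors="pt").input_ids
with torch.inference_mode():
    # lookahead > 0 triggers the prompt-lookup path; max_matching_ngram_size defaults to 2.
    output = model.generate(input_ids,
                            lookahead=8,
                            max_matching_ngram_size=2,
                            max_new_tokens=128,
                            do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))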
				
			
		
							
								
								
									
319  python/llm/src/ipex_llm/transformers/lookup.py  (new file)
@@ -0,0 +1,319 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /candidate_generator.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
#

from typing import Callable, List, Optional, Tuple
import torch
import time
import copy
import logging
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError

logger = logging.getLogger("ipex_llm.lookup")

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    lookahead = kwargs.pop("lookahead", None)
    if lookahead:
        from ipex_llm.transformers.convert import get_enable_ipex
        _enable_ipex = get_enable_ipex()

        if self.device.type == "cpu" and _enable_ipex:

            logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                           "fallback to original generate.")
            kwargs.pop("max_matching_ngram_size")
        else:
            # Do prompt lookup generation
            return self.lookup_generate(inputs=inputs,
                                        num_output_tokens=lookahead,
                                        generation_config=generation_config,
                                        logits_processor=logits_processor,
                                        stopping_criteria=stopping_criteria,
                                        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                        **kwargs)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)

GenerationMixin.generate = generate


# This class is copied from https://github.com/huggingface/transformers/blob/main/src
# /transformers/generation/candidate_generator.py
class PromptLookupCandidateGenerator():
    """
    `CandidateGenerator` class to be used for prompt lookup generation.
    This class generates candidates
    by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information:
    https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = None,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2

        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                          "Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self,
                       input_ids: torch.LongTensor)-> Tuple[torch.LongTensor,
                                                            Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
            The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match return the input sequence unchanged,
            # reverts back to autoregressive decoding
            return input_ids, None

        # Now need extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here,
        # so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor,
                                  scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
                config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each
                vocabulary when not using beam search or log softmax for each vocabulary
                token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return


@torch.no_grad()
def lookup_generate(self,
                    inputs: Optional[torch.Tensor] = None,
                    max_new_tokens: int = 10,
                    num_output_tokens: int = 10,
                    max_matching_ngram_size: int = None,
                    generation_config: Optional[GenerationConfig] = None,
                    attention_mask=None,
                    **sampling_kwargs):
    input_ids, generation_config, logits_processor, stopping_criteria, \
        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                              **sampling_kwargs)

    candidates_generator = PromptLookupCandidateGenerator(
        num_output_tokens=num_output_tokens,
        max_matching_ngram_size=max_matching_ngram_size)

    step = 0
    step_verify = 0

    clear_benchmarks(self)

    past_key_values = None
    input_len = input_ids.shape[1]

    while True:
        if step >= max_new_tokens:
            break

        if step == 0:
            # first token use full model
            tic = time.time()
            output = self(input_ids=input_ids,
                          past_key_values=past_key_values,
                          attention_mask=attention_mask,
                          return_dict=True,
                          use_cache=True)
            logits = output['logits']
            logits = logits[:, -1:]
            logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
            if generation_config.do_sample:
                output_ids, prob_list = deepmind_sample(logits,
                                                        top_k=generation_config.top_k,
                                                        top_p=generation_config.top_p,
                                                        temperature=generation_config.temperature)
            else:
                output_ids = greedy(logits)
            input_ids = torch.cat((input_ids, output_ids), dim=-1)
            past_key_values = output['past_key_values']
            step += 1
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.first_token_time = toc - tic
            e2e_tic = time.time()
        else:
            cur_len = input_ids.shape[-1]
            toc = time.time()
            candidate_input_ids, _ = candidates_generator.get_candidates(input_ids=input_ids)
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
            verify_input_ids = candidate_input_ids[:, -candidate_length - 1:]
            self.draft_num.append(candidate_length)
            tic = time.time()
            self.draft_time.append(tic - toc)
            output = _non_cpu_ipex_verify(self, verify_input_ids, past_key_values,
                                          attention_mask, return_dict=True, use_cache=True)
            if isinstance(output, dict):
                logits = output['logits']
                past_key_values = output['past_key_values']

            if len(logits_processor) > 0:
                for i in range(candidate_length + 1):
                    logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i],
                                                       logits[:, i, :])

            if generation_config.do_sample:
                output_ids, prob_list = deepmind_sample(logits,
                                                        top_k=generation_config.top_k,
                                                        top_p=generation_config.top_p,
                                                        temperature=generation_config.temperature)
            else:
                output_ids = greedy(logits)

            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.verify_time.append(toc - tic)

            # Compare drafts with target verified outputs
            # Drafts start from [1, k]
            # Verified output start from [0, k - 1]
            # including the one generated by the base model
            max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
            max_matched = max_matched.sum(-1).item() + 1

            max_of_max_matched = output_ids.size(1)
            # Accept number is max_matched, min is 1
            self.accept_num.append(max_matched)
            self.n_matched += max_matched - 1
            self.n_drafted += candidate_length

            # Clean up target model KV cache
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                new_cache_size = max_of_max_matched - max_matched
                past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size)

            input_ids = torch.cat((input_ids, output_ids), dim=-1)

            step += output_ids.size(1)
            step_verify += 1

        # Stop on eos and remove content after eos
        output_ids_list = output_ids[0].tolist()
        if generation_config.eos_token_id in output_ids_list:
            idx = output_ids_list.index(generation_config.eos_token_id)
            step -= (len(output_ids_list) - idx - 1)
            break

    step = min(step, max_new_tokens)
    e2e_toc = time.time()
    self.n_token_generated = step
    self.e2e_time_without_first = e2e_toc - e2e_tic

    return input_ids[:, : input_len + step]
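The candidate lookup in PromptLookupCandidateGenerator.get_candidates above can be replayed in isolation. A small self-contained sketch with made-up token ids (values are illustrative, not from a real tokenizer):

import torch

# Toy prompt whose tail 2-gram [8, 9] also occurs earlier in the sequence.
input_ids = torch.tensor([[5, 8, 9, 3, 7, 8, 9, 3, 2, 8, 9]])    # (batch=1, seq_len=11)
ngram_size, num_output_tokens = 2, 3

windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # sliding 2-grams
ngram_tensor = input_ids[0, -ngram_size:]                         # tail 2-gram: tensor([8, 9])
matches = (windows == ngram_tensor).all(dim=2)
match_indices = matches.nonzero(as_tuple=True)[1]                 # tensor([1, 5, 9])

for idx in match_indices:
    start = int(idx) + ngram_size
    end = min(start + num_output_tokens, input_ids.size(1))
    if start < end:
        print(input_ids[0, start:end].tolist())  # [3, 7, 8] -> proposed draft continuation
        break

The drafted tokens are then checked in a single forward pass of the target model, which is what the verify branch of lookup_generate above does.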

@@ -330,7 +330,8 @@ class _BaseAutoModelClass:
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

            if speculative:
                from .speculative import speculative_generate, clear_benchmarks
                from .speculative import speculative_generate, clear_benchmarks,\
                    _crop_past_key_values
                # load a sym_int4 model as draft model
                draft_model = cls.load_convert('sym_int4', optimize_model, *args, **kwargs)
                model.draft_model = draft_model
@@ -338,6 +339,12 @@ class _BaseAutoModelClass:
                # add speculative_generate to pretrained model dynamically
                model.clear_benchmarks = types.MethodType(clear_benchmarks, model)
                model.speculative_generate = types.MethodType(speculative_generate, model)
                model._crop_past_key_values = types.MethodType(_crop_past_key_values, model)

            # add lookup_generate to pretrained model
            from .lookup import lookup_generate
            import types
            model.lookup_generate = types.MethodType(lookup_generate, model)
        else:
            # load default
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
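Because lookup_generate is bound onto the loaded model via types.MethodType above, it can also be called directly instead of going through generate(lookahead=...). A hedged sketch, reusing the model and tokenizer from the earlier example (parameter values are illustrative):

input_ids = tokenizer("Question: ...\nAnswer:", return_tensors="pt").input_ids
output = model.lookup_generate(inputs=input_ids,
                               max_new_tokens=128,
                               num_output_tokens=8,        # draft length proposed per lookup
                               max_matching_ngram_size=2)  # longest prompt n-gram to match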

@@ -439,26 +439,49 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
    return past_key_values, not enough_kv_room


@torch.no_grad()
def speculative_generate(self,
                         inputs: Optional[torch.Tensor] = None,
                         draft_model=None,
                         max_new_tokens=10,
                         max_step_draft=8,
                         th_stop_draft=0.8,
                         auto_th_stop_draft=True,
                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                         hf_adjust=False,
                         min_step_draft=3,
                         generation_config: Optional[GenerationConfig] = None,
                         attention_mask=None,
                         **sampling_kwargs):
    invalidInputError(draft_model is not None,
                      "Draft model should be provided.")
    # min_step_draft >= 1. Since the max_step_draft may adjust,
    # min_step_draft can > max_step_draft
    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
    if _enable_ipex:
        cur_len = past_key_values[0][0].size(1)
        delta = new_cache_size
        tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
                          dtype=torch.long).contiguous()
        past_key_values = [[tmp, key_cache, value_cache, beam_idx]
                           for _, key_cache, value_cache, beam_idx in past_key_values]
    else:
        if self.config.model_type in ["qwen"]:
            past_key_values = [
                (k[:, :-(new_cache_size), :],
                    v[:, :-(new_cache_size), :])
                for k, v in past_key_values
            ]
        elif self.config.model_type == "chatglm":
            # for chatglm, cache shape is [sl, bs, nh, hn]
            past_key_values = [
                (k[:-(new_cache_size), :, :, :],
                    v[:-(new_cache_size), :, :, :])
                for k, v in past_key_values
            ]
        elif self.config.model_type in ["baichuan", "gptj"]:
            past_key_values = [
                (k[:, :, :-(new_cache_size), :],
                    v[:, :, :-(new_cache_size), :])
                for k, v in past_key_values
            ]
        elif self.config.model_type == "gpt_bigcode":
            past_key_values = [
                kv[:, :-(new_cache_size)]
                for kv in past_key_values
            ]
        else:
            past_key_values = [
                (k[:, :, :-(new_cache_size)],
                    v[:, :, :-(new_cache_size)])
                for k, v in past_key_values
            ]
    return past_key_values


def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
    if generation_config is None:
        generation_config = self.generation_config
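For the default branch of the new _crop_past_key_values above (e.g. llama-style caches of shape (batch, num_heads, seq_len, head_dim)), cropping simply drops the last new_cache_size positions along the sequence dimension. A toy shape check (shapes are illustrative):

import torch

# One layer of a llama-style KV cache: (batch, num_heads, seq_len, head_dim).
k = torch.zeros(1, 32, 20, 128)
v = torch.zeros(1, 32, 20, 128)
past_key_values = [(k, v)]

new_cache_size = 3  # number of rejected draft positions to discard
cropped = [(k[:, :, :-(new_cache_size)], v[:, :, :-(new_cache_size)])
           for k, v in past_key_values]
print(cropped[0][0].shape)  # torch.Size([1, 32, 17, 128])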

@@ -494,10 +517,27 @@ def speculative_generate(self,
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    # 4. Define other model kwargs
    # Removed not used
    # model_kwargs["output_attentions"] = generation_config.output_attentions
    # model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    # # decoder-only models with inputs_embeds forwarding must use caching
    # # (otherwise we can't detect whether we are generating the first new token or not,
    # # and we only want to use the embeddings for the first new token)
    # if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
    #     model_kwargs["use_cache"] = True
    # else:
    #     model_kwargs["use_cache"] = generation_config.use_cache

    # accepts_attention_mask = "attention_mask" in set(
    #     inspect.signature(self.forward).parameters.keys())
    # requires_attention_mask = "encoder_outputs" not in model_kwargs

    # if model_kwargs.get("attention_mask", None) is None and \
    #         requires_attention_mask and accepts_attention_mask:
    #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
    #         inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
    #     )

    # decoder-only models should use left-padding for generation
    if not self.config.is_encoder_decoder:
@@ -543,6 +583,61 @@ def speculative_generate(self,
        **model_kwargs,
    )

    return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs


def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
                         return_dict=True, use_cache=True):
    forward_args = {
        "input_ids": verify_input_ids,
        "past_key_values": past_key_values,
        "return_dict": return_dict,
        "use_cache": use_cache,
    }
    if cur_attention_mask:
        forward_args["attention_mask"] = cur_attention_mask

    if self.config.model_type == "chatglm":
        past_key_value_len = past_key_values[0][0].shape[0]
        position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
                                    device=verify_input_ids.device)
        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
        forward_args["position_ids"] = position_ids
    elif self.config.model_type == "gptj":
        past_length = past_key_values[0][0].size(2)
        input_len = verify_input_ids.shape[1]
        position_ids = torch.arange(past_length, input_len + past_length,
                                    dtype=torch.long, device=verify_input_ids.device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_len)
        forward_args["position_ids"] = position_ids

    return self(**forward_args)


@torch.no_grad()
def speculative_generate(self,
                         inputs: Optional[torch.Tensor] = None,
                         draft_model=None,
                         max_new_tokens=10,
                         max_step_draft=8,
                         th_stop_draft=0.8,
                         auto_th_stop_draft=True,
                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                         hf_adjust=False,
                         min_step_draft=3,
                         generation_config: Optional[GenerationConfig] = None,
                         attention_mask=None,
                         **sampling_kwargs):
    invalidInputError(draft_model is not None,
                      "Draft model should be provided.")
    # min_step_draft >= 1. Since the max_step_draft may adjust,
    # min_step_draft can > max_step_draft
    min_step_draft = min_step_draft if min_step_draft >= 1 else 1

    input_ids, generation_config, logits_processor, stopping_criteria, \
        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                              **sampling_kwargs)

    step = 0
    step_draft = 0
    step_verify = 0
@@ -851,27 +946,8 @@ def speculative_generate(self,
                logits = output[0]
                past_key_values = output[1]
            else:
                forward_args = {
                    "input_ids": drafted_input_ids,
                    "past_key_values": past_key_values,
                    "attention_mask": cur_attention_mask,
                    "return_dict": True,
                    "use_cache": True,
                }
                if self.config.model_type == "chatglm":
                    past_key_value_len = past_key_values[0][0].shape[0]
                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
                                                device=drafted_input_ids.device)
                    position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                    forward_args["position_ids"] = position_ids
                elif self.config.model_type == "gptj":
                    past_length = past_key_values[0][0].size(2)
                    input_len = drafted_input_ids.shape[1]
                    position_ids = torch.arange(past_length, input_len + past_length,
                                                dtype=torch.long, device=drafted_input_ids.device)
                    position_ids = position_ids.unsqueeze(0).view(-1, input_len)
                    forward_args["position_ids"] = position_ids
                output = self(**forward_args)
                output = _non_cpu_ipex_verify(self, drafted_input_ids, past_key_values,
                                              cur_attention_mask, return_dict=True, use_cache=True)
            if isinstance(output, dict):
                logits = output['logits']
                past_key_values = output['past_key_values']
@@ -939,45 +1015,10 @@ def speculative_generate(self,
            # Clean up target model KV cache
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                if _enable_ipex:
                    cur_len = past_key_values[0][0].size(1)
                    delta = max_of_max_matched - max_matched
                    tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
                                      dtype=torch.long,
                                      ).contiguous()
                    past_key_values = [[tmp, key_cache, value_cache, beam_idx]
                                       for _, key_cache, value_cache, beam_idx in past_key_values]
                else:
                    if self.config.model_type in ["qwen"]:
                        past_key_values = [
                            (k[:, :-(max_of_max_matched - max_matched), :],
                             v[:, :-(max_of_max_matched - max_matched), :])
                            for k, v in past_key_values
                        ]
                    elif self.config.model_type == "chatglm":
                        # for chatglm, cache shape is [sl, bs, nh, hn]
                        past_key_values = [
                            (k[:-(max_of_max_matched - max_matched), :, :, :],
                             v[:-(max_of_max_matched - max_matched), :, :, :])
                            for k, v in past_key_values
                        ]
                    elif self.config.model_type in ["baichuan", "gptj"]:
                        past_key_values = [
                            (k[:, :, :-(max_of_max_matched - max_matched), :],
                             v[:, :, :-(max_of_max_matched - max_matched), :])
                            for k, v in past_key_values
                        ]
                    elif self.config.model_type == "gpt_bigcode":
                        past_key_values = [
                            kv[:, :-(max_of_max_matched - max_matched)]
                            for kv in past_key_values
                        ]
                    else:
                        past_key_values = [
                            (k[:, :, :-(max_of_max_matched - max_matched)],
                             v[:, :, :-(max_of_max_matched - max_matched)])
                            for k, v in past_key_values
                        ]
                new_cache_size = max_of_max_matched - max_matched
                past_key_values = self._crop_past_key_values(past_key_values,
                                                             new_cache_size,
                                                             _enable_ipex)

            # Each iter assign new_matched kv_cache to past_key_values1
            if self.device.type == 'cpu' and (not _enable_ipex):
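The acceptance count that decides how much cache to crop here follows the same comparison shown in lookup_generate earlier on this page. A toy replay with illustrative values:

import torch

verify_input_ids = torch.tensor([[11, 3, 7, 8, 2]])  # last committed token + 4 drafted tokens
output_ids = torch.tensor([[3, 7, 9, 5, 4]])         # target model's prediction at each position

max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
print(max_matched)  # 3 -> two drafted tokens accepted plus one token from the verify pass

new_cache_size = output_ids.size(1) - max_matched    # 2 stale positions to crop from the KV cache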