From 93d40ab1272a270c90250dcad7e6f401fda2e3bc Mon Sep 17 00:00:00 2001 From: hxsz1997 <45651968+hxsz1997@users.noreply.github.com> Date: Wed, 15 May 2024 14:48:05 +0800 Subject: [PATCH] Update lookahead strategy (#11021) * update lookahead strategy * remove lines * fix python style check --- .../llm/src/ipex_llm/transformers/lookup.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index d5423848..1eaaf83a 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple import torch import time import copy +import random import logging from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\ @@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator(): if device == "mtl": self.max_candidates = 3 + self.min_candidates = 0 else: self.max_candidates = 9 + self.min_candidates = 0 invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0, "Invalid max_matching_ngram_size or num_output_tokens") @@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator(): `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. """ + if self.num_output_tokens == 0: + return input_ids, None input_length = input_ids.size(1) chosen_ids = None @@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator(): # so returning None return candidate_input_ids, None - def update_candidate_strategy(self, candidate_num: int, num_matches: int): + def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float): """ Updates the candidate generation strategy based on the outcomes. @@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator(): num_matches (`int`): The number of matches between the candidate sequences and the model predictions. """ - if num_matches == self.num_output_tokens: + if self.num_output_tokens == 0: + ran = random.random() - 0.15 + if ran <= accept_rate: + self.num_output_tokens = 1 + elif num_matches == self.num_output_tokens: self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates) elif candidate_num > num_matches: - self.num_output_tokens = max(self.num_output_tokens - 1, 1) + ran = random.random() + 0.1 * (candidate_num - num_matches) + if ran > accept_rate: + self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates) @torch.no_grad() @@ -228,6 +239,7 @@ def lookup_generate(self, step_verify = 0 clear_benchmarks(self) + self.accept_rate = [] past_key_values = None input_len = input_ids.shape[1] @@ -324,8 +336,11 @@ def lookup_generate(self, past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size) + accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1 + self.accept_rate.append(accept_rate) # Update the candidate generation strategy if needed - candidates_generator.update_candidate_strategy(candidate_length, n_matches) + candidates_generator.update_candidate_strategy(candidate_length, n_matches, + accept_rate) input_ids = torch.cat((input_ids, output_ids), dim=-1)