Update lookahead strategy (#11021)

* update lookahead strategy

* remove lines

* fix python style check
This commit is contained in:
hxsz1997 2024-05-15 14:48:05 +08:00 committed by GitHub
parent 1d73fc8106
commit 93d40ab127
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple
import torch
import time
import copy
import random
import logging
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator():
if device == "mtl":
self.max_candidates = 3
self.min_candidates = 0
else:
self.max_candidates = 9
self.min_candidates = 0
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens")
@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator():
`torch.LongTensor` of shape `(num_candidates, candidate_length)`:
The candidate sequences to be tried.
"""
if self.num_output_tokens == 0:
return input_ids, None
input_length = input_ids.size(1)
chosen_ids = None
@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator():
# so returning None
return candidate_input_ids, None
def update_candidate_strategy(self, candidate_num: int, num_matches: int):
def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float):
"""
Updates the candidate generation strategy based on the outcomes.
@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator():
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
if num_matches == self.num_output_tokens:
if self.num_output_tokens == 0:
ran = random.random() - 0.15
if ran <= accept_rate:
self.num_output_tokens = 1
elif num_matches == self.num_output_tokens:
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
elif candidate_num > num_matches:
self.num_output_tokens = max(self.num_output_tokens - 1, 1)
ran = random.random() + 0.1 * (candidate_num - num_matches)
if ran > accept_rate:
self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates)
@torch.no_grad()
@ -228,6 +239,7 @@ def lookup_generate(self,
step_verify = 0
clear_benchmarks(self)
self.accept_rate = []
past_key_values = None
input_len = input_ids.shape[1]
@ -324,8 +336,11 @@ def lookup_generate(self,
past_key_values = _crop_past_key_values(self, past_key_values,
new_cache_size)
accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
self.accept_rate.append(accept_rate)
# Update the candidate generation strategy if needed
candidates_generator.update_candidate_strategy(candidate_length, n_matches)
candidates_generator.update_candidate_strategy(candidate_length, n_matches,
accept_rate)
input_ids = torch.cat((input_ids, output_ids), dim=-1)