Update lookahead strategy (#11021)
* update lookahead strategy
* remove lines
* fix python style check
commit 93d40ab127
parent 1d73fc8106
1 changed file with 19 additions and 4 deletions
@@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
 import time
 import copy
+import random
 import logging
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
@@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator():
         if device == "mtl":
             self.max_candidates = 3
+            self.min_candidates = 0
         else:
             self.max_candidates = 9
+            self.min_candidates = 0
 
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
@@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator():
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
                 The candidate sequences to be tried.
         """
+        if self.num_output_tokens == 0:
+            return input_ids, None
         input_length = input_ids.size(1)
 
         chosen_ids = None
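Note (reader's gloss, not part of the commit): together with the new `min_candidates = 0` above, this early return lets `num_output_tokens = 0` act as an off switch, so a step with drafting disabled skips n-gram lookup entirely instead of being forced to propose at least one token. A minimal standalone sketch of that control flow; the `SimpleNamespace` harness and the placeholder draft are invented for illustration:

    import torch
    from types import SimpleNamespace

    def get_candidates_sketch(gen, input_ids):
        # num_output_tokens == 0 disables drafting: hand the prompt back
        # unchanged with no candidate tensor, so the caller decodes one
        # token the ordinary way this step.
        if gen.num_output_tokens == 0:
            return input_ids, None
        # Placeholder draft: the real generator searches the prompt for a
        # matching n-gram and copies the tokens that followed it.
        draft = input_ids[:, -gen.num_output_tokens:]
        return torch.cat((input_ids, draft), dim=-1), draft

    gen = SimpleNamespace(num_output_tokens=0)
    print(get_candidates_sketch(gen, torch.tensor([[1, 2, 3]])))  # prompt unchanged, None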
@@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator():
             # so returning None
             return candidate_input_ids, None
 
-    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
+    def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float):
         """
         Updates the candidate generation strategy based on the outcomes.
 
@@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator():
             num_matches (`int`):
                 The number of matches between the candidate sequences and the model predictions.
         """
-        if num_matches == self.num_output_tokens:
+        if self.num_output_tokens == 0:
+            ran = random.random() - 0.15
+            if ran <= accept_rate:
+                self.num_output_tokens = 1
+        elif num_matches == self.num_output_tokens:
             self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
         elif candidate_num > num_matches:
-            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
+            ran = random.random() + 0.1 * (candidate_num - num_matches)
+            if ran > accept_rate:
+                self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates)
 
 
     @torch.no_grad()
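Note (illustrative, not from the diff): the new rule reads as a small stochastic controller around the draft length. The standalone re-implementation below mirrors the logic above, but the harness, the starting length, and the synthetic match counts are all made up; it only shows how `num_output_tokens` can now decay to 0 (pausing drafting) and be probabilistically re-enabled once the observed accept rate is high enough:

    import random

    class StrategyDemo:
        # Mirrors update_candidate_strategy after this commit; everything
        # around it is a synthetic harness.
        def __init__(self, max_candidates=9, min_candidates=0):
            self.max_candidates = max_candidates
            self.min_candidates = min_candidates
            self.num_output_tokens = 3  # arbitrary starting draft length

        def update(self, candidate_num, num_matches, accept_rate):
            if self.num_output_tokens == 0:
                # Drafting is paused: re-enable with probability that grows
                # with the accept rate (the 0.15 bias favors re-enabling).
                if random.random() - 0.15 <= accept_rate:
                    self.num_output_tokens = 1
            elif num_matches == self.num_output_tokens:
                # Every drafted token was accepted: draft one more next time.
                self.num_output_tokens = min(self.num_output_tokens + 1,
                                             self.max_candidates)
            elif candidate_num > num_matches:
                # Some drafts were rejected: shrink only probabilistically,
                # and more readily the larger the miss was.
                if random.random() + 0.1 * (candidate_num - num_matches) > accept_rate:
                    self.num_output_tokens = max(self.num_output_tokens - 1,
                                                 self.min_candidates)

    demo = StrategyDemo()
    for _ in range(20):
        drafted = demo.num_output_tokens
        matched = random.randint(0, drafted)  # synthetic verification outcome
        demo.update(drafted, matched, accept_rate=0.6)
    print("final draft length:", demo.num_output_tokens)

Relative to the old rule, which clamped the shrink at 1, the floor is now `self.min_candidates` (0 on both device branches), so a persistently low accept rate can pause speculation entirely instead of paying for rejected drafts on every step.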
@@ -228,6 +239,7 @@ def lookup_generate(self,
     step_verify = 0
 
     clear_benchmarks(self)
+    self.accept_rate = []
 
     past_key_values = None
     input_len = input_ids.shape[1]
@@ -324,8 +336,11 @@ def lookup_generate(self,
                 past_key_values = _crop_past_key_values(self, past_key_values,
                                                         new_cache_size)
 
+            accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
+            self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
+            candidates_generator.update_candidate_strategy(candidate_length, n_matches,
+                                                           accept_rate)
 
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
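Note (illustrative): `accept_rate` is the running ratio of accepted to drafted tokens, defaulting optimistically to 1 before anything has been drafted; `self.n_matched` and `self.n_drafted` are maintained elsewhere in `lookup_generate` and presumably reset alongside `clear_benchmarks`. A hypothetical sketch of just that bookkeeping; the counter names follow the diff, the driver values are invented:

    # Running totals as in the diff: n_drafted tokens proposed so far,
    # n_matched of them accepted by the verify pass.
    n_drafted, n_matched = 0, 0
    accept_rate_history = []

    def record_verify_step(candidate_length, n_matches):
        global n_drafted, n_matched
        n_drafted += candidate_length
        n_matched += n_matches
        # Before any drafting, default to 1 so the strategy starts optimistic.
        rate = n_matched / n_drafted if n_drafted > 0 else 1
        accept_rate_history.append(rate)
        return rate

    print(record_verify_step(5, 3))  # 3/5 = 0.6
    print(record_verify_step(4, 4))  # 7/9 ≈ 0.778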