Update lookahead strategy (#11021)

* update lookahead strategy

* remove lines

* fix python style check
This commit is contained in:
hxsz1997 2024-05-15 14:48:05 +08:00 committed by GitHub
parent 1d73fc8106
commit 93d40ab127
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple
import torch import torch
import time import time
import copy import copy
import random
import logging import logging
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\ from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
@@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator():
if device == "mtl": if device == "mtl":
self.max_candidates = 3 self.max_candidates = 3
self.min_candidates = 0
else: else:
self.max_candidates = 9 self.max_candidates = 9
self.min_candidates = 0
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0, invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens") "Invalid max_matching_ngram_size or num_output_tokens")
@@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator():
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
The candidate sequences to be tried. The candidate sequences to be tried.
""" """
if self.num_output_tokens == 0:
return input_ids, None
input_length = input_ids.size(1) input_length = input_ids.size(1)
chosen_ids = None chosen_ids = None
@@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator():
# so returning None # so returning None
return candidate_input_ids, None return candidate_input_ids, None
def update_candidate_strategy(self, candidate_num: int, num_matches: int): def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float):
""" """
Updates the candidate generation strategy based on the outcomes. Updates the candidate generation strategy based on the outcomes.
@@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator():
num_matches (`int`): num_matches (`int`):
The number of matches between the candidate sequences and the model predictions. The number of matches between the candidate sequences and the model predictions.
""" """
if num_matches == self.num_output_tokens: if self.num_output_tokens == 0:
ran = random.random() - 0.15
if ran <= accept_rate:
self.num_output_tokens = 1
elif num_matches == self.num_output_tokens:
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates) self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
elif candidate_num > num_matches: elif candidate_num > num_matches:
self.num_output_tokens = max(self.num_output_tokens - 1, 1) ran = random.random() + 0.1 * (candidate_num - num_matches)
if ran > accept_rate:
self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates)
@torch.no_grad() @torch.no_grad()
@@ -228,6 +239,7 @@ def lookup_generate(self,
step_verify = 0 step_verify = 0
clear_benchmarks(self) clear_benchmarks(self)
self.accept_rate = []
past_key_values = None past_key_values = None
input_len = input_ids.shape[1] input_len = input_ids.shape[1]
@@ -324,8 +336,11 @@ def lookup_generate(self,
past_key_values = _crop_past_key_values(self, past_key_values, past_key_values = _crop_past_key_values(self, past_key_values,
new_cache_size) new_cache_size)
accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
self.accept_rate.append(accept_rate)
# Update the candidate generation strategy if needed # Update the candidate generation strategy if needed
candidates_generator.update_candidate_strategy(candidate_length, n_matches) candidates_generator.update_candidate_strategy(candidate_length, n_matches,
accept_rate)
input_ids = torch.cat((input_ids, output_ids), dim=-1) input_ids = torch.cat((input_ids, output_ids), dim=-1)