Update lookahead strategy (#11021)
* update lookahead strategy
* remove lines
* fix python style check
This commit is contained in:
parent
1d73fc8106
commit
93d40ab127
1 changed file with 19 additions and 4 deletions
|
|
@@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
import random
|
||||||
import logging
|
import logging
|
||||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||||
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
|
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
|
||||||
|
|
@@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator():
|
||||||
|
|
||||||
if device == "mtl":
|
if device == "mtl":
|
||||||
self.max_candidates = 3
|
self.max_candidates = 3
|
||||||
|
self.min_candidates = 0
|
||||||
else:
|
else:
|
||||||
self.max_candidates = 9
|
self.max_candidates = 9
|
||||||
|
self.min_candidates = 0
|
||||||
|
|
||||||
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
|
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
|
||||||
"Invalid max_matching_ngram_size or num_output_tokens")
|
"Invalid max_matching_ngram_size or num_output_tokens")
|
||||||
|
|
@@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator():
|
||||||
`torch.LongTensor` of shape `(num_candidates, candidate_length)`:
|
`torch.LongTensor` of shape `(num_candidates, candidate_length)`:
|
||||||
The candidate sequences to be tried.
|
The candidate sequences to be tried.
|
||||||
"""
|
"""
|
||||||
|
if self.num_output_tokens == 0:
|
||||||
|
return input_ids, None
|
||||||
input_length = input_ids.size(1)
|
input_length = input_ids.size(1)
|
||||||
|
|
||||||
chosen_ids = None
|
chosen_ids = None
|
||||||
|
|
@@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator():
|
||||||
# so returning None
|
# so returning None
|
||||||
return candidate_input_ids, None
|
return candidate_input_ids, None
|
||||||
|
|
||||||
def update_candidate_strategy(self, candidate_num: int, num_matches: int):
|
def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float):
|
||||||
"""
|
"""
|
||||||
Updates the candidate generation strategy based on the outcomes.
|
Updates the candidate generation strategy based on the outcomes.
|
||||||
|
|
||||||
|
|
@@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator():
|
||||||
num_matches (`int`):
|
num_matches (`int`):
|
||||||
The number of matches between the candidate sequences and the model predictions.
|
The number of matches between the candidate sequences and the model predictions.
|
||||||
"""
|
"""
|
||||||
if num_matches == self.num_output_tokens:
|
if self.num_output_tokens == 0:
|
||||||
|
ran = random.random() - 0.15
|
||||||
|
if ran <= accept_rate:
|
||||||
|
self.num_output_tokens = 1
|
||||||
|
elif num_matches == self.num_output_tokens:
|
||||||
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
|
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
|
||||||
elif candidate_num > num_matches:
|
elif candidate_num > num_matches:
|
||||||
self.num_output_tokens = max(self.num_output_tokens - 1, 1)
|
ran = random.random() + 0.1 * (candidate_num - num_matches)
|
||||||
|
if ran > accept_rate:
|
||||||
|
self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
@@ -228,6 +239,7 @@ def lookup_generate(self,
|
||||||
step_verify = 0
|
step_verify = 0
|
||||||
|
|
||||||
clear_benchmarks(self)
|
clear_benchmarks(self)
|
||||||
|
self.accept_rate = []
|
||||||
|
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
input_len = input_ids.shape[1]
|
input_len = input_ids.shape[1]
|
||||||
|
|
@@ -324,8 +336,11 @@ def lookup_generate(self,
|
||||||
past_key_values = _crop_past_key_values(self, past_key_values,
|
past_key_values = _crop_past_key_values(self, past_key_values,
|
||||||
new_cache_size)
|
new_cache_size)
|
||||||
|
|
||||||
|
accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
|
||||||
|
self.accept_rate.append(accept_rate)
|
||||||
# Update the candidate generation strategy if needed
|
# Update the candidate generation strategy if needed
|
||||||
candidates_generator.update_candidate_strategy(candidate_length, n_matches)
|
candidates_generator.update_candidate_strategy(candidate_length, n_matches,
|
||||||
|
accept_rate)
|
||||||
|
|
||||||
input_ids = torch.cat((input_ids, output_ids), dim=-1)
|
input_ids = torch.cat((input_ids, output_ids), dim=-1)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue