Fix lookahead sample error & add update strategy (#10894)
* Fix sample error & add update strategy
* add mtl config
* fix style
* remove print
This commit is contained in:
parent 94b4e96fa6
commit 015d07a58f
1 changed file with 24 additions and 16 deletions
@@ -29,6 +29,7 @@ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteria
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.utils import get_xpu_device_type
 
 logger = logging.getLogger("ipex_llm.lookup")
 
@@ -119,10 +120,16 @@ class PromptLookupCandidateGenerator():
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        device: str = "arc",
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+
+        if device == "mtl":
+            self.max_candidates = 3
+        else:
+            self.max_candidates = 9
 
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
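Note on the new `device` knob: "mtl" presumably selects Intel Meteor Lake iGPUs, which get a smaller draft cap than Arc because long drafts cost more to verify there than they save. A standalone toy sketch of the same cap, outside the patched class (the class name below is illustrative, not part of ipex_llm):

class CandidateCap:
    # Toy copy of the device-dependent cap added above:
    # 3 draft tokens on "mtl", 9 on everything else.
    def __init__(self, device: str = "arc"):
        self.max_candidates = 3 if device == "mtl" else 9

assert CandidateCap("mtl").max_candidates == 3
assert CandidateCap("arc").max_candidates == 9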
@@ -183,25 +190,18 @@ class PromptLookupCandidateGenerator():
         # so returning None
         return candidate_input_ids, None
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor,
-                                  scores: torch.FloatTensor, num_matches: int):
+    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
         """
         Updates the candidate generation strategy based on the outcomes.
 
         Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary.
-                [What are input IDs?](../glossary#input-ids)
-            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
-                config.vocab_size)`):
-                Prediction scores of a language modeling head. These can be logits for each
-                vocabulary when not using beam search or log softmax for each vocabulary
-                token when using beam search
             num_matches (`int`):
                 The number of matches between the candidate sequences and the model predictions.
         """
-        # Currently does nothing
-        return
+        if num_matches == self.num_output_tokens:
+            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
+        elif candidate_num > num_matches:
+            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
 
 
     @torch.no_grad()
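The rewritten `update_candidate_strategy` swaps the old no-op for an additive increase/decrease rule: grow the draft window by one token when the whole draft was accepted, shrink it by one when part of it was rejected, clamped to [1, max_candidates]. A self-contained toy version of the same rule (the class and method names are illustrative, not the patched API):

class DraftLengthController:
    # Toy copy of the rule added in this hunk.
    def __init__(self, num_output_tokens: int = 10, max_candidates: int = 9):
        self.num_output_tokens = num_output_tokens
        self.max_candidates = max_candidates

    def update(self, candidate_num: int, num_matches: int) -> None:
        if num_matches == self.num_output_tokens:
            # Whole draft accepted: try one more token next round.
            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
        elif candidate_num > num_matches:
            # Part of the draft rejected: back off, keeping at least one.
            self.num_output_tokens = max(self.num_output_tokens - 1, 1)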
@@ -217,9 +217,12 @@ def lookup_generate(self,
     model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                           **sampling_kwargs)
 
+    device_name = get_xpu_device_type(input_ids)
+
     candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
-        max_matching_ngram_size=max_matching_ngram_size)
+        max_matching_ngram_size=max_matching_ngram_size,
+        device=device_name)
 
     step = 0
     step_verify = 0
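With this wiring, callers never pass a device flag by hand; the type is read once per generate call from `input_ids`. A hedged sketch of what the new code amounts to ("mtl" stands in for whatever `get_xpu_device_type` actually returns on that hardware):

device_name = "mtl"  # i.e. get_xpu_device_type(input_ids) on a Meteor Lake iGPU
candidates_generator = PromptLookupCandidateGenerator(
    num_output_tokens=10,
    max_matching_ngram_size=2,
    device=device_name)  # -> max_candidates == 3 on "mtl"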
@@ -291,6 +294,7 @@ def lookup_generate(self,
                 top_k=generation_config.top_k,
                 top_p=generation_config.top_p,
                 temperature=generation_config.temperature)
+            output_ids = output_ids.transpose(0, 1)
         else:
             output_ids = greedy(logits)
 
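The added `transpose(0, 1)` is the "sample error" of the commit title: on the sampling path `deepmind_sample` evidently returns ids laid out as (candidate_length, batch), while the verification below slices (batch, candidate_length). A toy illustration of the mismatch (the shapes are assumptions inferred from the surrounding slicing, not a confirmed API):

import torch

sampled = torch.tensor([[11], [12], [13]])  # assumed (candidate_length=3, batch=1)
output_ids = sampled.transpose(0, 1)        # (batch=1, candidate_length=3)
assert output_ids.shape == (1, 3)
# Without the transpose, output_ids[:, :-1] in the verify step would slice
# the batch axis instead of the token axis, corrupting the match count.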
@@ -303,13 +307,14 @@ def lookup_generate(self,
         # Drafts start from [1, k]
         # Verified output start from [0, k - 1]
         # including the one generated by the base model
-        max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
-        max_matched = max_matched.sum(-1).item() + 1
+        n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
+                     .cumsum(-1) == 0).sum(-1).item()
+        max_matched = n_matches + 1
 
         max_of_max_matched = output_ids.size(1)
         # Accept number is max_matched, min is 1
         self.accept_num.append(max_matched)
-        self.n_matched += max_matched - 1
+        self.n_matched += n_matches
         self.n_drafted += candidate_length
 
         # Clean up target model KV cache
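The refactor is numerically a no-op (`n_matches` equals the old `max_matched - 1`), but naming the raw match count lets the strategy update below reuse it. The trick: `(a != b).cumsum(-1) == 0` stays True exactly until the first mismatch, so its sum is the length of the matching prefix. A worked example:

import torch

output_ids       = torch.tensor([[5, 7, 9, 4]])  # base-model tokens for this step
verify_input_ids = torch.tensor([[0, 5, 7, 2]])  # last verified token + drafts

mismatch = output_ids[:, :-1] != verify_input_ids[:, 1:]  # [False, False, True]
n_matches = (mismatch.cumsum(-1) == 0).sum(-1).item()     # 2 drafts accepted
max_matched = n_matches + 1                               # plus the bonus token
assert (n_matches, max_matched) == (2, 3)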
@@ -319,6 +324,9 @@ def lookup_generate(self,
             past_key_values = _crop_past_key_values(self, past_key_values,
                                                     new_cache_size)
 
+            # Update the candidate generation strategy if needed
+            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
+
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
 
         step += output_ids.size(1)
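This call closes the loop: each verify step now feeds the accepted-draft count back into the generator, so the draft length adapts at runtime. A toy trace, assuming the patched `PromptLookupCandidateGenerator` is in scope and given the listed outcomes:

gen = PromptLookupCandidateGenerator(num_output_tokens=3,
                                     max_matching_ngram_size=2,
                                     device="arc")  # max_candidates == 9
# (candidate_length, n_matches) per verify step:
for candidate_len, matched in [(3, 3), (4, 4), (5, 2), (4, 4)]:
    gen.update_candidate_strategy(candidate_len, matched)
assert gen.num_output_tokens == 5  # grew to 5, backed off to 4, grew again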