Fix lookahead sample error & add update strategy (#10894)

* Fix sample error & add update strategy

* add mtl config

* fix style

* remove print
This commit is contained in:
Yina Chen 2024-04-28 17:21:00 +08:00 committed by GitHub
parent 94b4e96fa6
commit 015d07a58f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -29,6 +29,7 @@ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteria
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.utils import get_xpu_device_type
logger = logging.getLogger("ipex_llm.lookup")
@ -119,10 +120,16 @@ class PromptLookupCandidateGenerator():
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
device: str = "arc",
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
if device == "mtl":
self.max_candidates = 3
else:
self.max_candidates = 9
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens")
@ -183,25 +190,18 @@ class PromptLookupCandidateGenerator():
# so returning None
return candidate_input_ids, None
def update_candidate_strategy(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor, num_matches: int):
def update_candidate_strategy(self, candidate_num: int, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each
vocabulary when not using beam search or log softmax for each vocabulary
token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
# Currently does nothing
return
if num_matches == self.num_output_tokens:
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
elif candidate_num > num_matches:
self.num_output_tokens = max(self.num_output_tokens - 1, 1)
@torch.no_grad()
@ -217,9 +217,12 @@ def lookup_generate(self,
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
**sampling_kwargs)
device_name = get_xpu_device_type(input_ids)
candidates_generator = PromptLookupCandidateGenerator(
num_output_tokens=num_output_tokens,
max_matching_ngram_size=max_matching_ngram_size)
max_matching_ngram_size=max_matching_ngram_size,
device=device_name)
step = 0
step_verify = 0
@ -291,6 +294,7 @@ def lookup_generate(self,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
output_ids = output_ids.transpose(0, 1)
else:
output_ids = greedy(logits)
@ -303,13 +307,14 @@ def lookup_generate(self,
# Drafts start from [1, k]
# Verified output start from [0, k - 1]
# including the one generated by the base model
max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
.cumsum(-1) == 0).sum(-1).item()
max_matched = n_matches + 1
max_of_max_matched = output_ids.size(1)
# Accept number is max_matched, min is 1
self.accept_num.append(max_matched)
self.n_matched += max_matched - 1
self.n_matched += n_matches
self.n_drafted += candidate_length
# Clean up target model KV cache
@ -319,6 +324,9 @@ def lookup_generate(self,
past_key_values = _crop_past_key_values(self, past_key_values,
new_cache_size)
# Update the candidate generation strategy if needed
candidates_generator.update_candidate_strategy(candidate_length, n_matches)
input_ids = torch.cat((input_ids, output_ids), dim=-1)
step += output_ids.size(1)