Fix lookahead sample error & add update strategy (#10894)
* Fix sample error & add update strategy
* add mtl config
* fix style
* remove print
parent 94b4e96fa6
commit 015d07a58f
1 changed file with 24 additions and 16 deletions
@@ -29,6 +29,7 @@ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteria
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.utils import get_xpu_device_type
 
 logger = logging.getLogger("ipex_llm.lookup")
 
@@ -119,10 +120,16 @@ class PromptLookupCandidateGenerator():
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        device: str = "arc",
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+
+        if device == "mtl":
+            self.max_candidates = 3
+        else:
+            self.max_candidates = 9
 
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
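For context on the mtl config above: the new device argument only selects an upper bound on how many draft tokens the generator may propose per step. A minimal standalone sketch of that mapping (the helper below is illustrative, not part of the library); the cap matters because the update strategy further down can only grow num_output_tokens up to this value.

# Illustrative helper mirroring the __init__ branch above: MTL iGPUs get a
# smaller candidate cap (3) than other XPU devices (9).
def candidate_cap(device: str) -> int:
    return 3 if device == "mtl" else 9

print(candidate_cap("mtl"))  # 3
print(candidate_cap("arc"))  # 9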
@@ -183,25 +190,18 @@ class PromptLookupCandidateGenerator():
             # so returning None
             return candidate_input_ids, None
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor,
-                                  scores: torch.FloatTensor, num_matches: int):
+    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
         """
         Updates the candidate generation strategy based on the outcomes.
 
         Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary.
-                [What are input IDs?](../glossary#input-ids)
-            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
-                config.vocab_size)`):
-                Prediction scores of a language modeling head. These can be logits for each
-                vocabulary when not using beam search or log softmax for each vocabulary
-                token when using beam search
             num_matches (`int`):
                 The number of matches between the candidate sequences and the model predictions.
         """
-        # Currently does nothing
-        return
+        if num_matches == self.num_output_tokens:
+            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
+        elif candidate_num > num_matches:
+            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
 
 
 @torch.no_grad()
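The rewritten update_candidate_strategy replaces the old no-op: it lengthens the draft by one token whenever every drafted token was accepted, shortens it by one whenever some were rejected, and clamps the result to [1, max_candidates]. A minimal self-contained simulation of that rule (illustrative names; it mirrors the diff rather than calling the library):

class DraftLengthController:
    """Mirrors the accept/reject feedback rule from update_candidate_strategy."""

    def __init__(self, num_output_tokens: int = 10, max_candidates: int = 9):
        self.num_output_tokens = num_output_tokens
        self.max_candidates = max_candidates

    def update(self, candidate_num: int, num_matches: int):
        if num_matches == self.num_output_tokens:
            # Every drafted token was accepted: try a longer draft next step.
            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
        elif candidate_num > num_matches:
            # Some drafted tokens were rejected: back off by one, never below 1.
            self.num_output_tokens = max(self.num_output_tokens - 1, 1)


ctrl = DraftLengthController(num_output_tokens=3, max_candidates=9)
ctrl.update(candidate_num=3, num_matches=3)   # full match -> grows to 4
ctrl.update(candidate_num=4, num_matches=1)   # partial match -> shrinks to 3
print(ctrl.num_output_tokens)                 # 3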
@@ -217,9 +217,12 @@ def lookup_generate(self,
     model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                           **sampling_kwargs)
 
+    device_name = get_xpu_device_type(input_ids)
+
     candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
-        max_matching_ngram_size=max_matching_ngram_size)
+        max_matching_ngram_size=max_matching_ngram_size,
+        device=device_name)
 
     step = 0
     step_verify = 0
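lookup_generate now detects the XPU type from the input tensor and forwards it, so MTL machines automatically get the smaller candidate cap. For experimentation the device string can also be passed explicitly; a hedged usage sketch (the import path is assumed from this file's logger name and may differ in the installed package):

from ipex_llm.transformers.lookup import PromptLookupCandidateGenerator  # path assumed

gen = PromptLookupCandidateGenerator(
    num_output_tokens=10,
    max_matching_ngram_size=2,
    device="mtl")          # or the string returned by get_xpu_device_type(input_ids)
print(gen.max_candidates)  # 3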
@@ -291,6 +294,7 @@ def lookup_generate(self,
                                              top_k=generation_config.top_k,
                                              top_p=generation_config.top_p,
                                              temperature=generation_config.temperature)
+                output_ids = output_ids.transpose(0, 1)
             else:
                 output_ids = greedy(logits)
 
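The single added transpose is the sample-error fix from the commit title: judging from the fix, the sampling branch returns the drafted ids laid out as (num_tokens, batch), while the greedy branch and the verification code below expect (batch, num_tokens). A toy torch illustration of the layout being corrected (shapes inferred from the fix, not from deepmind_sample's documentation):

import torch

# Pretend sampling produced 4 draft tokens for a batch of 1, but token-major.
sampled = torch.tensor([[11], [12], [13], [14]])   # shape (4, 1)

output_ids = sampled.transpose(0, 1)               # back to batch-major
print(output_ids.shape)                            # torch.Size([1, 4])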
@@ -303,13 +307,14 @@ def lookup_generate(self,
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
             # including the one generated by the base model
-            max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
-            max_matched = max_matched.sum(-1).item() + 1
+            n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
+                         .cumsum(-1) == 0).sum(-1).item()
+            max_matched = n_matches + 1
 
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
             self.accept_num.append(max_matched)
-            self.n_matched += max_matched - 1
+            self.n_matched += n_matches
             self.n_drafted += candidate_length
 
             # Clean up target model KV cache
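The refactor makes the match count reusable: (output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0 stays True only over the leading run where the drafted tokens agree with the model, so summing it gives the length of the accepted prefix, and max_matched adds one for the extra token the base model produces itself. A small worked example with made-up token ids:

import torch

# Toy ids: the first two drafted tokens agree with the model, the third does not.
output_ids = torch.tensor([[5, 7, 9, 4]])        # model outputs for positions [0, k]
verify_input_ids = torch.tensor([[3, 5, 7, 2]])  # previous token + k drafted tokens

n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
             .cumsum(-1) == 0).sum(-1).item()
max_matched = n_matches + 1                      # plus the token the base model adds itself
print(n_matches, max_matched)                    # 2 3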
@@ -319,6 +324,9 @@ def lookup_generate(self,
                 past_key_values = _crop_past_key_values(self, past_key_values,
                                                         new_cache_size)
 
+            # Update the candidate generation strategy if needed
+            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
+
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
 
             step += output_ids.size(1)
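Putting the pieces together, every decode step now feeds its own outcome into the next: verify the draft, crop the KV cache to the accepted prefix, then let the generator adjust its draft length. A toy trace of that feedback loop (no model or KV cache involved; the acceptance numbers are made up, and the non-MTL cap of 9 is assumed):

# Each step drafts `num_output_tokens` tokens, a made-up fraction of them is
# "accepted", and the draft length is updated the way the diff above does.
num_output_tokens, max_candidates = 3, 9
accept_ratio = [1.0, 1.0, 0.3, 1.0]             # stand-in for per-step verification

for r in accept_ratio:
    candidate_length = num_output_tokens
    n_matches = int(candidate_length * r)       # accepted drafts this step
    if n_matches == num_output_tokens:
        num_output_tokens = min(num_output_tokens + 1, max_candidates)
    elif candidate_length > n_matches:
        num_output_tokens = max(num_output_tokens - 1, 1)
    print(candidate_length, n_matches, "->", num_output_tokens)
# prints: 3 3 -> 4, then 4 4 -> 5, then 5 1 -> 4, then 4 4 -> 5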