Add deepmind sample into bigdl-llm speculative decoding (#10041)
* migrate deepmind sample
* update
* meet comments
* fix style
* fix style
This commit is contained in:
parent 72e67eedbb
commit 3ca03d4e97
1 changed file with 99 additions and 38 deletions
@@ -80,31 +80,43 @@ def generate(
 GenerationMixin.generate = generate


-def sample(logits, return_probs: bool=False, do_sample: bool=False, top_k: int=50,
-           top_p: float=0.7, temperature: float=0.7):
+def greedy(logits, return_probs: bool=False):
     if return_probs:
         all_probs = logits.softmax(-1)
-        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
-            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
-                                            top_k=top_k, top_p=top_p)
-            output_ids = torch.multinomial(_logits.softmax(-1),
-                                           num_samples=1).view(logits.shape[:-1])
-            probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
-        else:
-            probs, output_ids = torch.max(all_probs, dim=-1)
+        probs, output_ids = torch.max(all_probs, dim=-1)
         return output_ids, probs
     else:
-        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
-            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
-                                            top_k=top_k, top_p=top_p)
-            output_ids = torch.multinomial(_logits.softmax(-1),
-                                           num_samples=1).view(logits.shape[:-1])
-        else:
-            output_ids = torch.argmax(logits, dim=-1)
+        output_ids = torch.argmax(logits, dim=-1)
         return output_ids


+def deepmind_sample(logits, return_probs: bool=False, top_k: int=50,
+                    top_p: float=0.7, temperature: float=0.7):
+    prob_list = logits_to_probs(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    output_ids = multinomial_sample_one_no_sync(prob_list)
+    if return_probs:
+        all_probs = logits.softmax(-1)
+        probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
+        return output_ids, prob_list, probs
+    else:
+        return output_ids, prob_list
+
+
+def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
+    invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
+                      "top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
+    _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
+                                    top_k=top_k, top_p=top_p)
+    prob_list = _logits.softmax(-1)
+
+    return prob_list
+
+
+def multinomial_sample_one_no_sync(probs_sort):
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int64)
+
+
 def clear_benchmarks(self):
     self.first_token_time = 0
     self.generate_time = []
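Where the old sample() drew tokens with torch.multinomial, the new helpers draw through an exponential race: for probabilities probs and i.i.d. q ~ Exp(1), argmax(probs / q) lands on index i with probability probs[i], which (as the name multinomial_sample_one_no_sync suggests) avoids the device synchronization that torch.multinomial can trigger. A minimal standalone sketch, assuming only torch and using names local to this sketch, that mirrors the helper and checks the draw empirically:

```python
import torch

def sample_one_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # argmax(probs / q) with q ~ Exp(1) picks index i with probability probs[i]
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int64)

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 100_000
q = torch.empty(n, probs.size(0)).exponential_(1)   # vectorized version of the same draw
idx = torch.argmax(probs / q, dim=-1)
print(torch.bincount(idx, minlength=4) / n)          # empirically close to [0.1, 0.2, 0.3, 0.4]
print(sample_one_no_sync(probs))                     # a single draw, shape [1]
```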
@@ -395,9 +407,13 @@ def speculative_generate(self,
             logits = output['logits']
             logits = logits[:, -1:]
             logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
-            output_ids = sample(logits, do_sample=generation_config.do_sample,
-                                top_k=generation_config.top_k, top_p=generation_config.top_p,
-                                temperature=generation_config.temperature)
+            if generation_config.do_sample:
+                output_ids, prob_list = deepmind_sample(logits,
+                                                        top_k=generation_config.top_k,
+                                                        top_p=generation_config.top_p,
+                                                        temperature=generation_config.temperature)
+            else:
+                output_ids = greedy(logits)
             generate_ids[:, step] = output_ids
             current_input_ids = output_ids
             past_key_values = output['past_key_values']
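For the first token, sampling now goes through deepmind_sample, which returns both the token and the filtered probability distribution (prob_list) that the later verification step needs, while greedy decoding keeps a plain argmax. A simplified, self-contained sketch of that routing, assuming a dummy logits tensor and skipping the top-k/top-p filtering that the real helper applies via top_k_top_p_filtering:

```python
import torch

def pick_first_token(logits: torch.Tensor, do_sample: bool, temperature: float = 0.7):
    if do_sample:
        # simplified: the real code filters with top_k_top_p_filtering before softmax
        probs = (logits / temperature).softmax(-1)
        q = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / q, dim=-1), probs   # token id + its prob distribution
    return torch.argmax(logits, dim=-1), None           # greedy path keeps no distribution

logits = torch.randn(1, 1, 32000)                       # [batch, 1, vocab]
ids, prob_list = pick_first_token(logits, do_sample=True)
print(ids.shape, prob_list.shape)
```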
@@ -425,7 +441,11 @@ def speculative_generate(self,
             else:
                 draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
+            draft_prob_list = []
             tic = time.time()
+            random_probs = None
+            if generation_config.do_sample:
+                random_probs = torch.rand(max_step_draft, device=self.device, dtype=self.dtype)
             # Draft model auto-regressively generate k tokens
             # Early stop when prob less then th_stop_draft
             for step_draft in range(max_step_draft):
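When sampling is enabled, one Uniform(0, 1) value per potential draft step is drawn up front with torch.rand(max_step_draft, ...) and indexed inside the loop; the same tensor is reused later for the accept/reject test, so no extra RNG call is needed per step. A small illustrative sketch with an assumed max_step_draft and a stand-in stop condition:

```python
import torch

max_step_draft = 8                                 # illustrative value
random_probs = torch.rand(max_step_draft)          # one Uniform(0, 1) per potential draft step

drafted = 0
for step_draft in range(max_step_draft):
    th_random = random_probs[step_draft]           # consumed by the relaxed early-stop check
    drafted += 1
    if float(th_random) > 0.95:                    # stand-in for the real stop condition
        break

accept_randoms = random_probs[:drafted]            # the same numbers are reused for accept/reject
print(drafted, accept_randoms)
```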
@@ -455,19 +475,25 @@ def speculative_generate(self,
                 logits = draft_output['logits']
                 logits[:, -1, :] = logits_processor(temp_input_ids,
                                                     draft_output['logits'][:, -1, :])
-                draft_output_ids, draft_output_probs = sample(
-                    logits,
-                    return_probs=True,
-                    do_sample=generation_config.do_sample,
-                    top_k=generation_config.top_k,
-                    top_p=generation_config.top_p,
-                    temperature=generation_config.temperature)
+                if generation_config.do_sample:
+                    draft_output_ids, draft_probs, draft_output_probs = deepmind_sample(
+                        logits,
+                        return_probs=True,
+                        top_k=generation_config.top_k,
+                        top_p=generation_config.top_p,
+                        temperature=generation_config.temperature)
+                    draft_prob_list.append(draft_probs)
+                else:
+                    draft_output_ids, draft_output_probs = greedy(
+                        logits,
+                        return_probs=True)
                 draft_generate_ids[:, step_draft+1] = draft_output_ids
                 draft_current_input_ids = draft_output_ids
                 draft_past_key_values = draft_output['past_key_values']
                 # check if draft prob is less then th_stop_draft
                 # Draft number + step >= max output token number
-                if draft_output_probs.item() < th_stop_draft or \
+                th_random = 1 if random_probs is None else random_probs[step_draft]
+                if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \
                         step + step_draft + 2 >= max_new_tokens:
                     break
             if self.device.type == 'xpu':
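The early-stop test is relaxed when sampling: the loop used to stop as soon as the draft's own confidence fell below th_stop_draft, whereas now a low-confidence step only stops if the pre-drawn uniform exceeds 0.3, so roughly 30% of low-confidence steps keep drafting; the greedy path passes th_random = 1 and behaves as before. A small sketch of just that predicate (names are local to the sketch):

```python
def should_stop_draft(draft_prob: float, th_stop_draft: float,
                      th_random: float, out_of_budget: bool) -> bool:
    # Greedy decoding passes th_random = 1, so low confidence always stops (old behavior).
    # Sampling passes a Uniform(0, 1) draw, so a low-confidence step still continues ~30% of the time.
    return (draft_prob < th_stop_draft and th_random > 0.3) or out_of_budget

print(should_stop_draft(0.2, 0.6, th_random=1.0, out_of_budget=False))   # True: greedy path stops
print(should_stop_draft(0.2, 0.6, th_random=0.1, out_of_budget=False))   # False: sampling keeps drafting
```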
@@ -533,21 +559,56 @@ def speculative_generate(self,
             for i in range(logits.size(1)):
                 logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
                                                    logits[:, i, :])
-            output_ids = sample(logits, do_sample=generation_config.do_sample,
-                                top_k=generation_config.top_k, top_p=generation_config.top_p,
-                                temperature=generation_config.temperature)
+            if generation_config.do_sample:
+                target_probs = logits_to_probs(logits,
+                                               top_k=generation_config.top_k,
+                                               top_p=generation_config.top_p,
+                                               temperature=generation_config.temperature)
+            else:
+                output_ids = greedy(logits)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
             toc = time.time()
             self.verify_time.append(toc - tic)
             self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

             past_key_values = output['past_key_values']
+
+            if generation_config.do_sample:
+                draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
+                draft_probs = torch.stack(draft_prob_list).squeeze((1, 2))
+
+                # q: target prob, p: draft prob
+                # q >= p: always accept draft token
+                # q < p: q/p prob to accept draft token
+                p = draft_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
+                q = target_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
+                accept_draft_prob = torch.minimum(torch.ones(()), q[:drafted_n_tokens] / p)
+                rejected_locations = (random_probs[:drafted_n_tokens] > accept_draft_prob).nonzero()
+
+                if rejected_locations.shape[0] == 0:    # All draft tokens have been accepted
+                    max_matched = drafted_n_tokens + 1
+                    last_token = multinomial_sample_one_no_sync(target_probs[-1])
+                    output_ids = torch.cat([draft_tokens, last_token])
+                else:
+                    max_matched = rejected_locations[0].item()
+                    p = draft_probs[max_matched]
+                    q = target_probs[max_matched]
+                    resample_prob = q - p
+                    resample_prob = torch.where(resample_prob > 0, resample_prob, 0.0)
+                    resample_prob = resample_prob / resample_prob.sum()
+                    next_token = multinomial_sample_one_no_sync(resample_prob)
+                    output_ids = torch.cat([draft_tokens[:max_matched], next_token])
+                    max_matched += 1
+                output_ids = output_ids.unsqueeze(0)
+            else:
+                # Compare drafts with target verified outputs
+                # Drafts start from [1, k]
+                # Verified output start from [0, k - 1]
+                # including the one generated by the base model
+                max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
+                max_matched = max_matched.sum(-1).item() + 1

             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
             self.accept_num.append(max_matched)
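The do_sample verification branch above implements the speculative-sampling accept/reject rule spelled out in its comments: with p the draft distribution and q the target distribution, a draft token t is always accepted when q(t) >= p(t) and otherwise accepted with probability q(t)/p(t); at the first rejection the token is redrawn from the normalized residual max(q - p, 0), and if every draft token survives, one extra token is drawn from the last target distribution. A self-contained toy sketch of the same rule, using made-up tensors and torch.multinomial for brevity where the patch uses multinomial_sample_one_no_sync:

```python
import torch

def verify(draft_tokens, draft_probs, target_probs, random_probs):
    n = draft_tokens.size(0)
    p = draft_probs[torch.arange(n), draft_tokens]    # draft prob of each drafted token
    q = target_probs[torch.arange(n), draft_tokens]   # target prob of the same token
    accept = torch.minimum(torch.ones(()), q / p)     # accept with prob min(1, q/p)
    rejected = (random_probs[:n] > accept).nonzero()

    if rejected.shape[0] == 0:
        # every draft token accepted: draw one bonus token from the last target distribution
        bonus = torch.multinomial(target_probs[-1], 1)
        return torch.cat([draft_tokens, bonus])
    first = rejected[0].item()
    # resample the rejected position from the residual max(q - p, 0), renormalized
    residual = (target_probs[first] - draft_probs[first]).clamp(min=0)
    residual = residual / residual.sum()
    fix = torch.multinomial(residual, 1)
    return torch.cat([draft_tokens[:first], fix])

torch.manual_seed(0)
draft_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
target_probs = torch.tensor([[0.6, 0.3, 0.1], [0.3, 0.3, 0.4]])
draft_tokens = torch.tensor([0, 1])
print(verify(draft_tokens, draft_probs, target_probs, torch.rand(2)))
```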