Add deepmind sample into bigdl-llm speculative decoding (#10041)

* migrate deepmind sample
* update
* meet comments
* fix style
* fix style
This commit is contained in:
parent 72e67eedbb
commit 3ca03d4e97

1 changed file with 99 additions and 38 deletions
@@ -80,31 +80,43 @@ def generate(
 GenerationMixin.generate = generate


-def sample(logits, return_probs: bool=False, do_sample: bool=False, top_k: int=50,
-           top_p: float=0.7, temperature: float=0.7):
+def greedy(logits, return_probs: bool=False):
     if return_probs:
         all_probs = logits.softmax(-1)
-        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
-            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
-                                            top_k=top_k, top_p=top_p)
-            output_ids = torch.multinomial(_logits.softmax(-1),
-                                           num_samples=1).view(logits.shape[:-1])
-            probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
-        else:
-            probs, output_ids = torch.max(all_probs, dim=-1)
+        probs, output_ids = torch.max(all_probs, dim=-1)
         return output_ids, probs
     else:
-        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
-            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
-                                            top_k=top_k, top_p=top_p)
-            output_ids = torch.multinomial(_logits.softmax(-1),
-                                           num_samples=1).view(logits.shape[:-1])
-        else:
-            output_ids = torch.argmax(logits, dim=-1)
+        output_ids = torch.argmax(logits, dim=-1)
         return output_ids


+def deepmind_sample(logits, return_probs: bool=False, top_k: int=50,
+                    top_p: float=0.7, temperature: float=0.7):
+    prob_list = logits_to_probs(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    output_ids = multinomial_sample_one_no_sync(prob_list)
+    if return_probs:
+        all_probs = logits.softmax(-1)
+        probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
+        return output_ids, prob_list, probs
+    else:
+        return output_ids, prob_list
+
+
+def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
+    invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
+                      "top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
+    _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
+                                    top_k=top_k, top_p=top_p)
+    prob_list = _logits.softmax(-1)
+
+    return prob_list
+
+
+def multinomial_sample_one_no_sync(probs_sort):
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int64)
+
+
 def clear_benchmarks(self):
     self.first_token_time = 0
     self.generate_time = []
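For context: multinomial_sample_one_no_sync relies on the Gumbel-max trick. Since -log of an Exp(1) draw is Gumbel-distributed, dividing the probabilities by i.i.d. exponential noise and taking the argmax samples from the same categorical distribution as torch.multinomial, without forcing a host/device synchronization. deepmind_sample composes logits_to_probs (temperature scaling plus top-k/top-p filtering) with this trick and, when return_probs=True, also returns the unfiltered probability of each chosen token, which the draft loop below uses as its early-stop confidence. A minimal standalone sketch of the trick, with toy values that are not part of this patch:

import torch

# Hypothetical toy distribution over an 8-token vocabulary.
probs = torch.tensor([0.30, 0.25, 0.15, 0.10, 0.08, 0.06, 0.04, 0.02])

def sample_one_no_sync(p):
    # Same recipe as multinomial_sample_one_no_sync: -log(Exp(1)) is Gumbel(0, 1),
    # so argmax(p / Exp(1)) == argmax(log p + Gumbel) is distributed according to p.
    noise = torch.empty_like(p).exponential_(1)
    return torch.argmax(p / noise, dim=-1, keepdim=True).to(dtype=torch.int64)

# Empirical check that the trick matches the target distribution.
counts = torch.zeros_like(probs)
for _ in range(20000):
    counts[sample_one_no_sync(probs)] += 1
print(counts / counts.sum())   # approximately equal to probs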
@@ -395,9 +407,13 @@ def speculative_generate(self,
             logits = output['logits']
             logits = logits[:, -1:]
             logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
-            output_ids = sample(logits, do_sample=generation_config.do_sample,
-                                top_k=generation_config.top_k, top_p=generation_config.top_p,
-                                temperature=generation_config.temperature)
+            if generation_config.do_sample:
+                output_ids, prob_list = deepmind_sample(logits,
+                                                        top_k=generation_config.top_k,
+                                                        top_p=generation_config.top_p,
+                                                        temperature=generation_config.temperature)
+            else:
+                output_ids = greedy(logits)
             generate_ids[:, step] = output_ids
             current_input_ids = output_ids
             past_key_values = output['past_key_values']
@@ -425,7 +441,11 @@ def speculative_generate(self,
             else:
                 draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
+            draft_prob_list = []
             tic = time.time()
+            random_probs = None
+            if generation_config.do_sample:
+                random_probs = torch.rand(max_step_draft, device=self.device, dtype=self.dtype)
             # Draft model auto-regressively generate k tokens
             # Early stop when prob less then th_stop_draft
             for step_draft in range(max_step_draft):
@@ -455,19 +475,25 @@ def speculative_generate(self,
                 logits = draft_output['logits']
                 logits[:, -1, :] = logits_processor(temp_input_ids,
                                                     draft_output['logits'][:, -1, :])
-                draft_output_ids, draft_output_probs = sample(
-                    logits,
-                    return_probs=True,
-                    do_sample=generation_config.do_sample,
-                    top_k=generation_config.top_k,
-                    top_p=generation_config.top_p,
-                    temperature=generation_config.temperature)
+                if generation_config.do_sample:
+                    draft_output_ids, draft_probs, draft_output_probs = deepmind_sample(
+                        logits,
+                        return_probs=True,
+                        top_k=generation_config.top_k,
+                        top_p=generation_config.top_p,
+                        temperature=generation_config.temperature)
+                    draft_prob_list.append(draft_probs)
+                else:
+                    draft_output_ids, draft_output_probs = greedy(
+                        logits,
+                        return_probs=True)
                 draft_generate_ids[:, step_draft+1] = draft_output_ids
                 draft_current_input_ids = draft_output_ids
                 draft_past_key_values = draft_output['past_key_values']
                 # check if draft prob is less then th_stop_draft
                 # Draft number + step >= max output token number
-                if draft_output_probs.item() < th_stop_draft or \
+                th_random = 1 if random_probs is None else random_probs[step_draft]
+                if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \
                         step + step_draft + 2 >= max_new_tokens:
                     break
             if self.device.type == 'xpu':
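For context: the early-stop test is now randomized when sampling is enabled. th_random is 1 on the greedy path, so a low-confidence draft token still always stops drafting; on the sampling path the pre-drawn random_probs[step_draft] appears to let roughly 30% of low-confidence steps keep drafting, since the rejection-sampling verifier further below may still accept them. A toy evaluation of the condition, with hypothetical numbers that are not from this patch:

# Hypothetical values only, to illustrate the stop condition.
th_stop_draft = 0.8
draft_confidence = 0.55     # stands in for draft_output_probs.item()
th_random = 0.42            # random_probs[step_draft] when do_sample, else 1
step, step_draft, max_new_tokens = 10, 2, 32

stop = (draft_confidence < th_stop_draft and th_random > 0.3) or \
    step + step_draft + 2 >= max_new_tokens
print(stop)   # True: the low-confidence branch fires because th_random > 0.3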
@@ -533,21 +559,56 @@ def speculative_generate(self,
             for i in range(logits.size(1)):
                 logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
                                                    logits[:, i, :])
-            output_ids = sample(logits, do_sample=generation_config.do_sample,
-                                top_k=generation_config.top_k, top_p=generation_config.top_p,
-                                temperature=generation_config.temperature)
+            if generation_config.do_sample:
+                target_probs = logits_to_probs(logits,
+                                               top_k=generation_config.top_k,
+                                               top_p=generation_config.top_p,
+                                               temperature=generation_config.temperature)
+            else:
+                output_ids = greedy(logits)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
             toc = time.time()
             self.verify_time.append(toc - tic)
             self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

-            # Compare drafts with target verified outputs
-            # Drafts start from [1, k]
-            # Verified output start from [0, k - 1]
-            # including the one generated by the base model
-            max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
-            max_matched = max_matched.sum(-1).item() + 1
             past_key_values = output['past_key_values']

+            if generation_config.do_sample:
+                draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
+                draft_probs = torch.stack(draft_prob_list).squeeze((1, 2))
+
+                # q: target prob, p: draft prob
+                # q >= p: always accept draft token
+                # q < p: q/p prob to accept draft token
+                p = draft_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
+                q = target_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
+                accept_draft_prob = torch.minimum(torch.ones(()), q[:drafted_n_tokens] / p)
+                rejected_locations = (random_probs[:drafted_n_tokens] > accept_draft_prob).nonzero()
+
+                if rejected_locations.shape[0] == 0:    # All draft tokens have been accepted
+                    max_matched = drafted_n_tokens + 1
+                    last_token = multinomial_sample_one_no_sync(target_probs[-1])
+                    output_ids = torch.cat([draft_tokens, last_token])
+                else:
+                    max_matched = rejected_locations[0].item()
+                    p = draft_probs[max_matched]
+                    q = target_probs[max_matched]
+                    resample_prob = q - p
+                    resample_prob = torch.where(resample_prob > 0, resample_prob, 0.0)
+                    resample_prob = resample_prob / resample_prob.sum()
+                    next_token = multinomial_sample_one_no_sync(resample_prob)
+                    output_ids = torch.cat([draft_tokens[:max_matched], next_token])
+                    max_matched += 1
+                output_ids = output_ids.unsqueeze(0)
+            else:
+                # Compare drafts with target verified outputs
+                # Drafts start from [1, k]
+                # Verified output start from [0, k - 1]
+                # including the one generated by the base model
+                max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
+                max_matched = max_matched.sum(-1).item() + 1
+
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
             self.accept_num.append(max_matched)
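For context: the do_sample branch implements the standard speculative-sampling acceptance rule. Each draft token is accepted with probability min(1, q/p), where p is the draft model's probability for that token and q the target model's; at the first rejection the token is resampled from the normalized residual max(q - p, 0), and if every draft is accepted one bonus token is drawn from the last target distribution. This accept/resample rule is what keeps the output distribution identical to sampling from the target model alone. A self-contained sketch with hypothetical shapes and random toy distributions, not taken from this patch:

import torch

vocab, n_draft = 6, 3
draft_tokens = torch.tensor([2, 0, 5])                      # tokens proposed by the draft model
draft_probs = torch.rand(n_draft, vocab).softmax(-1)        # p: draft distribution at each draft step
target_probs = torch.rand(n_draft + 1, vocab).softmax(-1)   # q: target distribution, plus one extra row

p = draft_probs[torch.arange(n_draft), draft_tokens]
q = target_probs[torch.arange(n_draft), draft_tokens]
accept_prob = torch.minimum(torch.ones(()), q / p)          # accept each draft with prob min(1, q/p)
rejected = (torch.rand(n_draft) > accept_prob).nonzero()

if rejected.shape[0] == 0:
    # Every draft token accepted: sample one bonus token from the last target distribution.
    out = torch.cat([draft_tokens, torch.multinomial(target_probs[-1], 1)])
else:
    k = rejected[0].item()
    # First rejection at position k: resample it from the residual max(q - p, 0), renormalized.
    residual = (target_probs[k] - draft_probs[k]).clamp(min=0)
    residual = residual / residual.sum()
    out = torch.cat([draft_tokens[:k], torch.multinomial(residual, 1)])
print(out)   # the accepted prefix plus the corrected or bonus token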