Optimize speculative decoding PVC memory usage (#10329)
* optimize memory
* update
* update
* update
* support other models
* update
* fix style
parent cc796848ea
commit 9ea499ca68
1 changed file with 60 additions and 2 deletions
@@ -24,8 +24,8 @@ import time
import os
import copy
import logging
import warnings
import inspect
import transformers
from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
@@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
                    delta_past_value.to(torch.float32)


def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
                               model_type="llama"):
    from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
        extend_kv_cache
    enough_kv_room = True
    if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
                          "gptj", "opt"]:
        return past_key_values, False
    cache_k = past_key_values[0][0]
    if model_type == "chatglm":
        cache_k = cache_k.permute(1, 2, 0, 3)
    elif model_type == "qwen":
        cache_k = cache_k.transpose(1, 2)

    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
                                                  seq_len=max_step_draft)
    bsz, num_heads, current_seq_len, head_dim = cache_k.shape
    device = past_key_values[0][0].device
    if not enough_kv_room:
        past_key_values = list(past_key_values)
        for i in range(len(past_key_values)):
            cache_k = past_key_values[i][0]
            cache_v = past_key_values[i][1]
            if model_type == "chatglm":
                cache_k = cache_k.permute(1, 2, 0, 3)
                cache_v = cache_v.permute(1, 2, 0, 3)
            elif model_type == "qwen":
                cache_k = cache_k.transpose(1, 2)
                cache_v = cache_v.transpose(1, 2)
            new_cache_k, new_cache_v = extend_kv_cache(
                bsz,
                num_heads,  # Support GQA
                head_dim,
                cache_k.size(2),
                current_seq_len + max_step_draft + kv_alloc_block_len,
                dtype=cache_v.dtype,
                device=device)
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            if model_type == "chatglm":
                past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
                                      new_cache_v.permute(2, 0, 1, 3))
            elif model_type == "qwen":
                past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
            else:
                past_key_values[i] = (new_cache_k, new_cache_v)
    return past_key_values, not enough_kv_room


@torch.no_grad()
def speculative_generate(self,
                         inputs: Optional[torch.Tensor] = None,
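The permute/transpose calls above normalize each model family's KV layout to the (batch, num_heads, seq_len, head_dim) shape before the room check and extension, then convert back afterwards. A minimal sketch of that layout mapping; the per-model shapes here are inferred from those calls, not taken from the model sources:

import torch

bsz, num_heads, seq_len, head_dim = 1, 32, 128, 64

# llama/mistral-style cache layout: (batch, num_heads, seq_len, head_dim)
llama_k = torch.zeros(bsz, num_heads, seq_len, head_dim)

# chatglm-style layout assumed to be (seq_len, batch, num_heads, head_dim);
# permute(1, 2, 0, 3) maps it onto the llama-style layout,
# permute(2, 0, 1, 3) maps the extended cache back.
chatglm_k = torch.zeros(seq_len, bsz, num_heads, head_dim)
assert chatglm_k.permute(1, 2, 0, 3).shape == llama_k.shape

# qwen-style layout assumed to be (batch, seq_len, num_heads, head_dim);
# transpose(1, 2) swaps seq_len and num_heads in both directions.
qwen_k = torch.zeros(bsz, seq_len, num_heads, head_dim)
assert qwen_k.transpose(1, 2).shape == llama_k.shape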
@@ -504,6 +553,9 @@ def speculative_generate(self,

    self.clear_benchmarks()

    if self.device.type == 'xpu':
        torch.xpu.empty_cache()

    # Example:
    # Target model forward for the first token
    # Step 1. target_model(prompt) -> a
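The comment block starting here walks through the draft-then-verify flow that the rest of speculative_generate implements. For orientation only, a generic outline of that loop; draft_next_token and verify_tokens are hypothetical callables, not BigDL APIs:

def speculative_loop(draft_next_token, verify_tokens, prompt_ids,
                     max_new_tokens, max_step_draft):
    # Generic speculative decoding sketch, not BigDL's implementation.
    # draft_next_token(ids): next token proposed by the cheap draft model.
    # verify_tokens(ids, draft): prefix of `draft` accepted by the target model
    # (plus one corrected token when a draft token is rejected).
    output = list(prompt_ids)
    while len(output) - len(prompt_ids) < max_new_tokens:
        # Draft phase: propose up to max_step_draft tokens autoregressively.
        draft = []
        for _ in range(max_step_draft):
            draft.append(draft_next_token(output + draft))
        # Verify phase: one target-model forward scores all drafted tokens,
        # so a single expensive pass can accept several tokens at once.
        output.extend(verify_tokens(output, draft))
    return output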
@@ -562,6 +614,10 @@ def speculative_generate(self,
                                                       past_key_values_storage,  _enable_ipex)
                original_draft_past_key_values = draft_past_key_values
            else:
                past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
                                                                        max_step_draft,
                                                                        max_new_tokens - step + 40,
                                                                        self.config.model_type)
                draft_past_key_values = past_key_values
            draft_generate_ids[:, 0] = current_input_ids
            draft_prob_list = []
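At this call site the kv_alloc_block_len argument is max_new_tokens - step + 40, so a single extension reserves room for every token that can still be generated plus a small pad, rather than growing the cache in many small increments. A worked example with made-up numbers:

current_seq_len = 1024   # tokens already held in the KV cache
max_step_draft = 8       # draft tokens speculated per step
max_new_tokens = 512     # generation budget for this request
step = 100               # tokens generated so far

kv_alloc_block_len = max_new_tokens - step + 40                    # 452
new_len = current_seq_len + max_step_draft + kv_alloc_block_len    # 1484
print(new_len)  # enough room for all remaining tokens in one allocation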
@@ -742,6 +798,8 @@ def speculative_generate(self,
                output_ids = greedy(logits)
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
                if extend_kv:
                    torch.xpu.empty_cache()
            toc = time.time()
            self.verify_time.append(toc - tic)
            self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
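Calling torch.xpu.empty_cache() only on steps where the cache was actually extended returns the blocks that backed the old, smaller KV buffers to the allocator instead of keeping them reserved for the rest of the generation. A hedged way to observe the effect, assuming the torch.xpu memory statistics mirror their torch.cuda counterparts (an assumption, not something this diff shows):

import torch

def log_xpu_memory(tag):
    # Assumes IPEX exposes CUDA-like memory stats on torch.xpu; guarded so the
    # snippet is a no-op on machines without an XPU device.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        allocated = torch.xpu.memory_allocated() / 1024 ** 2
        reserved = torch.xpu.memory_reserved() / 1024 ** 2
        print(f"[{tag}] allocated={allocated:.1f} MiB reserved={reserved:.1f} MiB")

log_xpu_memory("before verify step")
# ... run one draft + verify iteration here ...
log_xpu_memory("after empty_cache")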