[LLM] Avoid KV Cache OOM when seq len is larger than 1 (#10006)
* Avoid OOM during multi-round streaming chat with the KV cache.
* For llama-like KV caches, i.e., [bs, n_head, seq_len, head_dim], use is_enough_kv_cache_room_4_31.
* Other models compare the KV cache size with kv_len directly.
parent e5ae6f2c13
commit 9e18ea187f
15 changed files with 52 additions and 30 deletions
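Why the old check misses this case: as the stride-based conditions suggest, the KV cache is pre-allocated with spare token capacity and only a view of the used portion is passed around, so `stride(1)` of the cached key tensor still encodes the full per-head capacity while `size(2)` is only the used length. The old per-model condition, `stride(1) <= size(2) * size(3)`, only fires when the buffer is already completely full, which implicitly assumes a single incoming token; when a whole prompt arrives at once (the next turn of a multi-round streaming chat), many tokens can be appended into a cache with almost no room left. A minimal, self-contained sketch of the problem and of the patched condition; the tensor sizes and variable names below are illustrative, only the two comparisons are taken from this change:

import torch

# Simulate a pre-allocated KV cache: capacity for 256 tokens per head, 250 already used.
bs, n_head, head_dim, capacity, cached = 1, 2, 128, 256, 250
buffer_k = torch.empty(bs, n_head, capacity, head_dim)
cache_k = buffer_k.narrow(2, 0, cached)    # stride(1) still equals capacity * head_dim

incoming = 32                              # next user turn arrives as 32 prompt tokens at once

old_needs_extend = cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)
new_needs_extend = cache_k.stride()[1] < (cache_k.size(2) + incoming) * cache_k.size(3)

print(old_needs_extend)   # False -> old code would append 32 tokens into 6 free slots
print(new_needs_extend)   # True  -> patched code extends the cache before appending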
				
			
@@ -42,7 +42,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 
-from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error
@@ -72,7 +73,9 @@ def aquila_attention_forward(
         .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -89,7 +92,7 @@ def aquila_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,  # Support GQA
@@ -26,7 +26,8 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
@@ -55,7 +56,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -76,7 +79,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -163,7 +166,9 @@ def baichuan_attention_forward_13b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
 
     # if past_key_value is not None:
@@ -174,7 +179,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -24,7 +24,8 @@ import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -104,7 +105,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -125,7 +128,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -216,7 +219,9 @@ def baichuan_attention_forward_13b(
     )
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
 
     # if past_key_value is not None:
@@ -227,7 +232,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new
@@ -121,7 +121,7 @@ def bloom_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
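For models whose cached keys are not already held as a [bs, n_head, seq_len, head_dim] tuple entry (Bloom here, and Falcon, GPT-J, ChatGLM, MPT and Qwen-VL below), the patch inlines the equivalent test against the total KV length, which in these forwards counts both the cached and the incoming tokens. A small sanity check, with made-up sizes, that the inlined form agrees with the helper used for llama-like caches:

import torch

def is_enough_kv_cache_room(past_key_value, seq_len=1):
    # Mirrors is_enough_kv_cache_room_4_31 as patched in this commit.
    return past_key_value is not None and \
        past_key_value[0].stride()[1] >= \
        (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)

batch_size, num_heads, head_dim, capacity, cached, incoming = 1, 2, 64, 128, 100, 40
cache_k = torch.empty(batch_size, num_heads, capacity, head_dim).narrow(2, 0, cached)

kv_length = cached + incoming                  # total length after this step
inlined_extend = cache_k.stride()[1] < kv_length * cache_k.size(3)
helper_extend = not is_enough_kv_cache_room((cache_k, cache_k), seq_len=incoming)
assert inlined_extend == helper_extend         # both decide: extend before appending
print(inlined_extend)                          # True with these numbers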
@@ -66,7 +66,7 @@ def attention_fn(
         cache_k = cache_k.permute(1, 2, 0, 3)
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -419,7 +419,7 @@ def chatglm2_attention_forward_8eb45c(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
 
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
                 # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring
@@ -150,7 +150,7 @@ def chatglm2_32k_attention_forward(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
 
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -97,7 +97,7 @@ def rw_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -276,7 +276,7 @@ def rw_attention_forward_40b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -450,7 +450,7 @@ def falcon_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -142,8 +142,7 @@ def gptj_attention_forward(
         cache_k = cache_k.permute(0, 2, 1, 3)
         cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
-
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
@@ -34,7 +34,8 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
 
@@ -79,7 +80,9 @@ def gptneox_attention_forward(
 
     # Compute token offset for rotary embeddings (when decoding)
     seq_len = key.shape[-2]
+    enough_kv_room = True
     if has_layer_past:
+        enough_kv_room = is_enough_kv_cache_room_4_31(layer_past, seq_len=seq_len)
         seq_len += layer_past[0].shape[-2]
 
     use_fuse_rope = query.device.type == "xpu"
@@ -101,7 +104,7 @@ def gptneox_attention_forward(
     if has_layer_past:
         past_key = layer_past[0]
         past_value = layer_past[1]
-        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+        if not enough_kv_room:
             # allocate new
             new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
@@ -43,7 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
@@ -73,7 +74,9 @@ def internlm_attention_forward(
         .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -95,7 +98,7 @@ def internlm_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 bsz,
@@ -43,9 +43,9 @@ from torch import nn
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention
@@ -77,7 +77,7 @@ def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
             cache_k = past_key_value[0].transpose(2, 3)
             cache_v = past_key_value[1]
             kv_seq_len += cache_k.shape[-2]
-            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
                 # allocate new
                 new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                            kv_n_heads,  # Support GQA
@@ -89,7 +89,7 @@ def qwen_attention_forward_vl(
         # value = torch.cat((past_value, value), dim=1)
         cache_k = layer_past[0].transpose(1, 2)
         cache_v = layer_past[1].transpose(1, 2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -60,7 +60,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
     new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
     return new_cache_k, new_cache_v
 
 
@@ -194,17 +194,21 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
 
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > \
-        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
+        past_key_value.key_cache[idx].stride()[1] >= \
+        (past_key_value.key_cache[idx].size(2) + seq_len) * \
         past_key_value.key_cache[idx].size(3)
 
 
 def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > \
-        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] >= \
+        (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)
 
 
 def use_flash_attention(query, key):
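A quick way to exercise the patched helper (assuming a bigdl-llm build with this change installed; the shapes below are made up): simulate the pre-allocated cache with `narrow`, then probe it with different incoming lengths. With a capacity of 256 tokens and 200 already cached, the check flips exactly when the incoming sequence no longer fits:

import torch
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31

bs, n_head, head_dim, capacity, cached = 1, 32, 128, 256, 200
buf = torch.empty(bs, n_head, capacity, head_dim)
cache_k = buf.narrow(2, 0, cached)             # stride(1) keeps the full capacity
past_key_value = (cache_k, cache_k)

print(is_enough_kv_cache_room_4_31(past_key_value, seq_len=1))    # True:  200 + 1  <= 256
print(is_enough_kv_cache_room_4_31(past_key_value, seq_len=56))   # True:  200 + 56 <= 256
print(is_enough_kv_cache_room_4_31(past_key_value, seq_len=57))   # False: 200 + 57 >  256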