[LLM] Avoid KV Cache OOM when seq len is larger than 1 (#10006)

* Avoid OOM during multi-round streaming chat with the KV cache.
* For LLaMA-like KV caches, i.e. [bs, n_head, seq_len, head_dim], use is_enough_kv_cache_room_4_31 (a simplified sketch of this check follows the commit metadata below).
* Other models instead compare the allocated KV cache size with kv_len.
Qiyuan Gong 2024-01-26 17:30:08 +08:00 committed by GitHub
parent e5ae6f2c13
commit 9e18ea187f
15 changed files with 52 additions and 30 deletions
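
As a toy illustration of the stride-based room test this commit standardizes on (the helper below mirrors is_enough_kv_cache_room_4_31 in simplified form; the tensor is a stand-in, not a real model cache): for a LLaMA-like cache of shape [bs, n_head, seq_len, head_dim] that was pre-allocated with headroom along the sequence dimension, stride(1) encodes the allocated capacity, so there is enough room exactly when the cached length plus the incoming seq_len still fits.

```python
import torch

def is_enough_kv_cache_room(past_key, seq_len=1):
    # LLaMA-like cache layout: [bs, n_head, cached_seq_len, head_dim].
    # The cache is pre-allocated with headroom along dim 2, so for a contiguous
    # allocation stride(1) == allocated_seq_len * head_dim: there is enough room
    # iff the cached length plus the incoming seq_len still fits.
    return past_key is not None and \
        past_key.stride()[1] >= (past_key.size(2) + seq_len) * past_key.size(3)

# Toy example: reserve storage for 16 tokens, expose a 12-token view over it.
allocated = torch.zeros(1, 2, 16, 4)  # [bs, n_head, capacity, head_dim]
cache_k = allocated.as_strided((1, 2, 12, 4), allocated.stride())

print(is_enough_kv_cache_room(cache_k, seq_len=1))   # True:  12 + 1 <= 16
print(is_enough_kv_cache_room(cache_k, seq_len=8))   # False: 12 + 8 > 16
```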

View file

@@ -42,7 +42,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
-from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error
@@ -72,7 +73,9 @@ def aquila_attention_forward(
         .transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -89,7 +92,7 @@ def aquila_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,  # Support GQA

View file

@@ -26,7 +26,8 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -55,7 +56,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -76,7 +79,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -163,7 +166,9 @@ def baichuan_attention_forward_13b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     # if past_key_value is not None:
@@ -174,7 +179,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,

View file

@@ -24,7 +24,8 @@ import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -104,7 +105,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -125,7 +128,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -216,7 +219,9 @@ def baichuan_attention_forward_13b(
     )
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     # if past_key_value is not None:
@@ -227,7 +232,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new

View file

@@ -121,7 +121,7 @@ def bloom_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,

View file

@@ -66,7 +66,7 @@ def attention_fn(
         cache_k = cache_k.permute(1, 2, 0, 3)
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

View file

@@ -419,7 +419,7 @@ def chatglm2_attention_forward_8eb45c(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
                 # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring

View file

@@ -150,7 +150,7 @@ def chatglm2_32k_attention_forward(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

View file

@@ -97,7 +97,7 @@ def rw_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -276,7 +276,7 @@ def rw_attention_forward_40b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -450,7 +450,7 @@ def falcon_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,

View file

@@ -142,8 +142,7 @@ def gptj_attention_forward(
         cache_k = cache_k.permute(0, 2, 1, 3)
         cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
-
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,

View file

@@ -34,7 +34,8 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -79,7 +80,9 @@ def gptneox_attention_forward(
     # Compute token offset for rotary embeddings (when decoding)
     seq_len = key.shape[-2]
+    enough_kv_room = True
     if has_layer_past:
+        enough_kv_room = is_enough_kv_cache_room_4_31(layer_past, seq_len=seq_len)
         seq_len += layer_past[0].shape[-2]
     use_fuse_rope = query.device.type == "xpu"
@@ -101,7 +104,7 @@ def gptneox_attention_forward(
     if has_layer_past:
         past_key = layer_past[0]
         past_value = layer_past[1]
-        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+        if not enough_kv_room:
             # allocate new
             new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,

View file

@@ -43,7 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -73,7 +74,9 @@ def internlm_attention_forward(
         .transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -95,7 +98,7 @@ def internlm_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 bsz,

View file

@@ -77,7 +77,7 @@ def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
         cache_k = past_key_value[0].transpose(2, 3)
         cache_v = past_key_value[1]
         kv_seq_len += cache_k.shape[-2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        kv_n_heads,  # Support GQA

View file

@@ -89,7 +89,7 @@ def qwen_attention_forward_vl(
         # value = torch.cat((past_value, value), dim=1)
         cache_k = layer_past[0].transpose(1, 2)
         cache_v = layer_past[1].transpose(1, 2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                       self.num_heads,

View file

@@ -60,7 +60,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
     new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
     return new_cache_k, new_cache_v
@@ -194,17 +194,21 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > \
-        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
+        past_key_value.key_cache[idx].stride()[1] >= \
+        (past_key_value.key_cache[idx].size(2) + seq_len) * \
         past_key_value.key_cache[idx].size(3)


 def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > \
-        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] >= \
+        (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)


 def use_flash_attention(query, key):
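
For context, here is a minimal sketch of the extend-or-append pattern the hunks above apply in each attention forward. The helpers and the KV_CACHE_ALLOC_BLOCK_LENGTH value below are simplified stand-ins for the repo's init_kv_cache / extend_kv_cache / append_kv_cache, not their exact signatures, and only the key cache is shown.

```python
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed headroom value, for illustration only


def is_enough_kv_cache_room(cache_k, seq_len=1):
    # Same stride-based test as is_enough_kv_cache_room_4_31 above,
    # for a LLaMA-like cache of shape [bs, n_head, cached_len, head_dim].
    return cache_k is not None and \
        cache_k.stride()[1] >= (cache_k.size(2) + seq_len) * cache_k.size(3)


def allocate_cache(bs, n_head, head_dim, capacity, like):
    # Simplified stand-in for init_kv_cache/extend_kv_cache: reserve storage
    # for `capacity` tokens but expose an empty (length-0) view over it.
    storage = torch.empty(bs, n_head, capacity, head_dim,
                          dtype=like.dtype, device=like.device)
    return storage.as_strided((bs, n_head, 0, head_dim), storage.stride())


def append_cache(cache, new_states):
    # Simplified stand-in for append_kv_cache: grow the view over the existing
    # storage and copy the new tokens in, with no re-allocation.
    new_len = cache.size(2) + new_states.size(2)
    out = cache.as_strided((cache.size(0), cache.size(1), new_len, cache.size(3)),
                           cache.stride(), storage_offset=0)
    out[:, :, cache.size(2):new_len, :] = new_states
    return out


def update_k_cache(cache_k, key_states):
    # The decision pattern the commit applies in each attention forward:
    # re-allocate only when the room check says the incoming tokens do not fit.
    bs, n_head, cur_len, head_dim = key_states.shape
    if not is_enough_kv_cache_room(cache_k, seq_len=cur_len):
        cached = 0 if cache_k is None else cache_k.size(2)
        bigger = allocate_cache(bs, n_head, head_dim,
                                cached + cur_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                like=key_states)
        if cache_k is not None:
            bigger = append_cache(bigger, cache_k)  # copy the old tokens over
        cache_k = bigger
    return append_cache(cache_k, key_states)


# Multi-round decode: a 4-token prompt, then two 3-token chunks (seq len > 1).
cache = None
for chunk in (torch.randn(1, 2, 4, 8), torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8)):
    cache = update_k_cache(cache, chunk)
print(cache.shape)  # torch.Size([1, 2, 10, 8])
```

The point of the change is that the room check now accounts for the incoming chunk length (seq_len can be larger than 1 during streaming or multi-round chat), so re-allocation happens before the strided write would run past the reserved storage.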