[LLM] Avoid KV Cache OOM when seq len is larger than 1 (#10006)

* Avoid OOM during multi-round streaming chat with the kv cache.
* For llama-like kv caches, i.e., [bs, n_head, seq_len, head_dim], use is_enough_kv_cache_room_4_31.
* Other models need to compare the allocated kv cache size with kv_len instead (a sketch of the check follows below).
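For llama-like caches the room check boils down to comparing the capacity encoded in the key cache's stride along dim 1 (allocated length times head_dim) against the space needed once the incoming seq_len tokens are appended, which is what is_enough_kv_cache_room_4_31 in bigdl.llm.transformers.models.utils does (see the last diff hunk below). A minimal sketch of that logic; the helper name has_room_for and the toy shapes are illustrative, not part of this patch:

```python
import torch


def has_room_for(cache_k: torch.Tensor, seq_len: int) -> bool:
    # cache_k: [bs, n_head, cached_len, head_dim], carved out of a buffer that
    # was over-allocated along dim 2, so stride(1) encodes the allocated
    # capacity per head (allocated_len * head_dim), not just the used length.
    allocated_per_head = cache_k.stride(1)
    needed_per_head = (cache_k.size(2) + seq_len) * cache_k.size(3)
    return allocated_per_head >= needed_per_head


# toy cache: room for 8 tokens, 6 already filled
buffer = torch.empty(1, 2, 8, 4)
cache_k = buffer[:, :, :6, :]            # the view keeps the parent's strides
print(has_room_for(cache_k, seq_len=1))  # True:  6 + 1 <= 8
print(has_room_for(cache_k, seq_len=4))  # False: 6 + 4 >  8, so extend first
```

The condition removed in this commit, `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)`, only reallocated when the cache had no spare room beyond the tokens already stored; it never looked at how many tokens are arriving, so with seq_len > 1 (multi-round streaming chat) the remaining room could be smaller than the incoming chunk. The new checks include seq_len / kv_len in the comparison.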
Qiyuan Gong 2024-01-26 17:30:08 +08:00 committed by GitHub
parent e5ae6f2c13
commit 9e18ea187f
15 changed files with 52 additions and 30 deletions


@@ -42,7 +42,8 @@ import torch
import torch.utils.checkpoint
from torch import nn
-from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
+append_kv_cache, is_enough_kv_cache_room_4_31
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.utils.common import log4Error
@@ -72,7 +73,9 @@ def aquila_attention_forward(
.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -89,7 +92,7 @@ def aquila_attention_forward(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads, # Support GQA


@@ -26,7 +26,8 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+append_kv_cache, is_enough_kv_cache_room_4_31
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -55,7 +56,9 @@ def baichuan_attention_forward_7b(
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -76,7 +79,7 @@ def baichuan_attention_forward_7b(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
@@ -163,7 +166,9 @@ def baichuan_attention_forward_13b(
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
# if past_key_value is not None:
@@ -174,7 +179,7 @@ def baichuan_attention_forward_13b(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,


@@ -24,7 +24,8 @@ import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+append_kv_cache, is_enough_kv_cache_room_4_31
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -104,7 +105,9 @@ def baichuan_attention_forward_7b(
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -125,7 +128,7 @@ def baichuan_attention_forward_7b(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
@@ -216,7 +219,9 @@ def baichuan_attention_forward_13b(
)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
# if past_key_value is not None:
@@ -227,7 +232,7 @@ def baichuan_attention_forward_13b(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
if device.type == 'xpu':
torch.xpu.empty_cache()
# allocate new


@@ -121,7 +121,7 @@ def bloom_attention_forward(
# reuse k, v, self_attention
cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_length * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
batch_size,


@@ -66,7 +66,7 @@ def attention_fn(
cache_k = cache_k.permute(1, 2, 0, 3)
cache_v = cache_v.permute(1, 2, 0, 3)
past_length = cache_k.size(2)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads_per_partition,


@@ -419,7 +419,7 @@ def chatglm2_attention_forward_8eb45c(
cache_v = cache_v.permute(1, 2, 0, 3)
past_length = cache_k.size(2)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
if device.type == "xpu" and batch_size > 1: # use beam_search for generation.
# If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring


@@ -150,7 +150,7 @@ def chatglm2_32k_attention_forward(
cache_v = cache_v.permute(1, 2, 0, 3)
past_length = cache_k.size(2)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads_per_partition,


@@ -97,7 +97,7 @@ def rw_attention_forward_7b(
# reuse k, v, self_attention
cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_length * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
@@ -276,7 +276,7 @@ def rw_attention_forward_40b(
# reuse k, v, self_attention
cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_length * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
@@ -450,7 +450,7 @@ def falcon_attention_forward(
# reuse k, v, self_attention
cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_length * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
batch_size,


@@ -142,8 +142,7 @@ def gptj_attention_forward(
cache_k = cache_k.permute(0, 2, 1, 3)
cache_v = cache_v.permute(0, 2, 1, 3)
past_length = cache_k.size(2)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads,
self.head_dim,


@@ -34,7 +34,8 @@
import torch
from typing import Optional, Tuple
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+append_kv_cache, is_enough_kv_cache_room_4_31
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -79,7 +80,9 @@ def gptneox_attention_forward(
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
+enough_kv_room = True
if has_layer_past:
+enough_kv_room = is_enough_kv_cache_room_4_31(layer_past, seq_len=seq_len)
seq_len += layer_past[0].shape[-2]
use_fuse_rope = query.device.type == "xpu"
@@ -101,7 +104,7 @@ def gptneox_attention_forward(
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
-if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+if not enough_kv_room:
# allocate new
new_past_key, new_past_value = extend_kv_cache(bsz,
self.num_attention_heads,


@@ -43,7 +43,8 @@ import torch
import torch.utils.checkpoint
from torch import nn
from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+append_kv_cache, is_enough_kv_cache_room_4_31
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -73,7 +74,9 @@ def internlm_attention_forward(
.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
+enough_kv_room = True
if past_key_value is not None:
+enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -95,7 +98,7 @@ def internlm_attention_forward(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
bsz,


@@ -43,9 +43,9 @@ from torch import nn
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
from bigdl.llm.transformers.models.utils import use_flash_attention


@@ -77,7 +77,7 @@ def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
cache_k = past_key_value[0].transpose(2, 3)
cache_v = past_key_value[1]
kv_seq_len += cache_k.shape[-2]
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
kv_n_heads, # Support GQA


@@ -89,7 +89,7 @@ def qwen_attention_forward_vl(
# value = torch.cat((past_value, value), dim=1)
cache_k = layer_past[0].transpose(1, 2)
cache_v = layer_past[1].transpose(1, 2)
-if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,


@@ -60,7 +60,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
return new_cache_k, new_cache_v
@@ -194,17 +194,21 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
# to determinate if is enough kv cache room in transformers==4.36
# seq_len for current seq len
# For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-past_key_value.key_cache[idx].stride()[1] > \
-(past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
+past_key_value.key_cache[idx].stride()[1] >= \
+(past_key_value.key_cache[idx].size(2) + seq_len) * \
past_key_value.key_cache[idx].size(3)
def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
# to determinate if is enough kv cache room in transformers between 4.31 and 4.35
# seq_len for current seq len
# For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
return past_key_value is not None and \
-past_key_value[0].stride()[1] > \
-(past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
+past_key_value[0].stride()[1] >= \
+(past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)
def use_flash_attention(query, key):
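To close, here is a condensed, self-contained sketch (pure PyTorch, no bigdl-llm imports) of the check-extend-append flow that the per-model forwards above implement with extend_kv_cache/append_kv_cache; ALLOC_BLOCK, grow_cache and update_kv are illustrative stand-ins, not names from this patch:

```python
import torch

ALLOC_BLOCK = 256  # stand-in for the block-sized headroom used when growing


def grow_cache(cache: torch.Tensor, extra: int) -> torch.Tensor:
    # Re-allocate with spare room along dim 2 and copy the old contents in;
    # the returned view has the old length but the bigger buffer's strides.
    bs, n_head, cur, head_dim = cache.shape
    bigger = torch.empty(bs, n_head, cur + extra, head_dim, dtype=cache.dtype)
    view = bigger[:, :, :cur, :]
    view.copy_(cache)
    return view


def update_kv(cache_k, cache_v, key_states, value_states):
    seq_len = key_states.size(2)
    # room check: allocated capacity per head vs. space needed after appending
    enough = cache_k.stride(1) >= (cache_k.size(2) + seq_len) * cache_k.size(3)
    if not enough:
        cache_k = grow_cache(cache_k, seq_len + ALLOC_BLOCK)
        cache_v = grow_cache(cache_v, seq_len + ALLOC_BLOCK)
    # append in place by re-viewing the (now large enough) buffer
    new_len = cache_k.size(2) + seq_len
    new_size = (cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3))
    new_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
    new_k[:, :, cache_k.size(2):, :] = key_states
    new_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
    new_v[:, :, cache_v.size(2):, :] = value_states
    return new_k, new_v


# toy roll-out: prefill 5 tokens, then append 3 at once (seq len larger than 1)
k = torch.zeros(1, 2, 0, 4)
v = torch.zeros(1, 2, 0, 4)
k, v = update_kv(k, v, torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4))
k, v = update_kv(k, v, torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))
print(k.shape)  # torch.Size([1, 2, 8, 4])
```

The patch's change is the condition itself: the old `stride()[1] <= size(2) * size(3)` test ignored how many tokens are being appended, while the new checks compare the allocated room against the full kv length, so a multi-token append triggers re-allocation when needed.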