[LLM] Avoid KV Cache OOM when seq len is larger than 1 (#10006)
* Avoid OOM during multi-round streaming chat with the KV cache.
* For llama-like KV caches, i.e. [bs, n_head, seq_len, head_dim], use is_enough_kv_cache_room_4_31 with the current seq_len.
* Other models compare the allocated KV cache capacity against kv_len directly.
parent e5ae6f2c13
commit 9e18ea187f
15 changed files with 52 additions and 30 deletions
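The fix turns on how the pre-allocated KV cache encodes its capacity: the cache tensor is a narrow view into a larger buffer, so its stride along the head dimension still reflects the allocated length. During multi-round streaming chat a new round arrives as several tokens at once, so the room check has to account for the whole incoming seq_len, not just one token. A minimal standalone sketch in plain PyTorch (sizes are made up; the real helpers live in bigdl.llm.transformers.models.utils) of the capacity test that is_enough_kv_cache_room_4_31 performs for a llama-like [bs, n_head, seq_len, head_dim] cache:

import torch

bs, n_head, head_dim = 1, 32, 128
max_len, cur_len = 260, 256              # allocated capacity vs. tokens already cached

# Pre-allocate the full buffer, then view only the occupied part. The narrowed
# view keeps the parent's strides, so stride(1) still encodes the capacity:
# cache_k.stride(1) == max_len * head_dim, while cache_k.size(2) == cur_len.
buffer_k = torch.empty(bs, n_head, max_len, head_dim)
cache_k = buffer_k[:, :, :cur_len, :]

def enough_kv_room(cache, seq_len=1):
    # room for seq_len new tokens  <=>  cur_len + seq_len <= max_len
    return cache.stride()[1] >= (cache.size(2) + seq_len) * cache.size(3)

print(enough_kv_room(cache_k, seq_len=1))   # True:  256 + 1 <= 260
print(enough_kv_room(cache_k, seq_len=8))   # False: 256 + 8 > 260, must extend the cache first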
@@ -42,7 +42,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 
-from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error
@@ -72,7 +73,9 @@ def aquila_attention_forward(
         .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -89,7 +92,7 @@ def aquila_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,  # Support GQA
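The enough_kv_room pattern above is repeated for every llama-like model in this diff. A small worked example with hypothetical numbers shows why the old in-place test was not sufficient once more than one token arrives: it only triggered a re-allocation when the cache was already completely full, so the following append could run past the pre-allocated buffer.

# Hypothetical numbers: capacity for 260 tokens, 256 already cached, 8 arriving at once.
capacity, occupied, incoming, head_dim = 260, 256, 8, 128

# old call-site test: re-allocate only when the cache is already full
old_extend = capacity * head_dim <= occupied * head_dim   # cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)
# new test: re-allocate whenever the incoming tokens would not fit
new_extend = not (capacity * head_dim >= (occupied + incoming) * head_dim)

print(old_extend)  # False -> cache reused, yet 256 + 8 = 264 > 260 slots: the append no longer fits
print(new_extend)  # True  -> extend_kv_cache allocates a larger buffer before appending

Note that in the hunks above enough_kv_room is computed with seq_len=kv_seq_len before kv_seq_len is extended by the past length, i.e. with the number of incoming tokens only.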
@@ -26,7 +26,8 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
@@ -55,7 +56,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -76,7 +79,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -163,7 +166,9 @@ def baichuan_attention_forward_13b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
 
     # if past_key_value is not None:
@@ -174,7 +179,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -24,7 +24,8 @@ import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -104,7 +105,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -125,7 +128,7 @@ def baichuan_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -216,7 +219,9 @@ def baichuan_attention_forward_13b(
     )
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
 
     # if past_key_value is not None:
@@ -227,7 +232,7 @@ def baichuan_attention_forward_13b(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new
@@ -121,7 +121,7 @@ def bloom_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
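For models whose caches are reshaped or permuted at the call site (bloom above; chatglm, falcon, gptj, mpt and qwen-vl below), the commit keeps a stride-based test but compares the allocated capacity against the full kv length instead of the currently occupied size. Because kv_length (or past_length + cur_length, or kv_seq_len) already includes the incoming tokens, the rewritten condition asks directly whether the upcoming append still fits; a condensed sketch with hypothetical numbers:

# cache_k.stride()[1] equals capacity_in_tokens * head_dim for the pre-allocated buffer
capacity_tokens, head_dim = 260, 128
kv_length = 256 + 8                       # past tokens + incoming tokens

# cache_k.stride()[1] < kv_length * cache_k.size(3)
needs_new_alloc = capacity_tokens * head_dim < kv_length * head_dim
print(needs_new_alloc)                    # True: allocate a bigger cache before appending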
@@ -66,7 +66,7 @@ def attention_fn(
         cache_k = cache_k.permute(1, 2, 0, 3)
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -419,7 +419,7 @@ def chatglm2_attention_forward_8eb45c(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
 
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
                 # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring
@@ -150,7 +150,7 @@ def chatglm2_32k_attention_forward(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
 
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -97,7 +97,7 @@ def rw_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -276,7 +276,7 @@ def rw_attention_forward_40b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -450,7 +450,7 @@ def falcon_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -142,8 +142,7 @@ def gptj_attention_forward(
         cache_k = cache_k.permute(0, 2, 1, 3)
         cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
-
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
@@ -34,7 +34,8 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
 
@@ -79,7 +80,9 @@ def gptneox_attention_forward(
 
     # Compute token offset for rotary embeddings (when decoding)
    seq_len = key.shape[-2]
+    enough_kv_room = True
     if has_layer_past:
+        enough_kv_room = is_enough_kv_cache_room_4_31(layer_past, seq_len=seq_len)
         seq_len += layer_past[0].shape[-2]
 
     use_fuse_rope = query.device.type == "xpu"
@@ -101,7 +104,7 @@ def gptneox_attention_forward(
     if has_layer_past:
         past_key = layer_past[0]
         past_value = layer_past[1]
-        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+        if not enough_kv_room:
             # allocate new
             new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
@@ -43,7 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
@@ -73,7 +74,9 @@ def internlm_attention_forward(
         .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -95,7 +98,7 @@ def internlm_attention_forward(
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 bsz,
@@ -77,7 +77,7 @@ def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
         cache_k = past_key_value[0].transpose(2, 3)
         cache_v = past_key_value[1]
         kv_seq_len += cache_k.shape[-2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        kv_n_heads,  # Support GQA
@@ -89,7 +89,7 @@ def qwen_attention_forward_vl(
         # value = torch.cat((past_value, value), dim=1)
         cache_k = layer_past[0].transpose(1, 2)
         cache_v = layer_past[1].transpose(1, 2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -60,7 +60,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
     new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
     return new_cache_k, new_cache_v
 
 
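Besides the room checks, the hunk above fixes an indexing slip in append_kv_cache: the value-cache slice was bounded by cache_k.size(2) instead of cache_v.size(2), which is benign only while the key and value caches have the same length. A standalone sketch of what the helper does with as_strided, using made-up sizes (new_size is assumed to be the old size grown along the sequence dimension, as the visible lines suggest):

import torch

bs, n_head, head_dim = 1, 2, 4
max_len, cur_len, incoming = 16, 3, 2

buffer_v = torch.zeros(bs, n_head, max_len, head_dim)
cache_v = buffer_v[:, :, :cur_len, :]                       # occupied view; capacity kept in stride(1)
value_states = torch.randn(bs, n_head, incoming, head_dim)

# grow the visible view along dim 2, then write the new tokens right after the cached ones
new_size = (bs, n_head, cur_len + incoming, head_dim)
new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + value_states.size(2), :] = value_states

assert torch.equal(new_cache_v[:, :, cur_len:, :], value_states)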
@@ -194,17 +194,21 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
 
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
+    # seq_len for current seq len
     # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > \
-        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
+        past_key_value.key_cache[idx].stride()[1] >= \
+        (past_key_value.key_cache[idx].size(2) + seq_len) * \
         past_key_value.key_cache[idx].size(3)
 
+
 def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
+    # seq_len for current seq len
     # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > \
-        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] >= \
+        (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)
 
+
 def use_flash_attention(query, key):
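The transformers==4.36 variant applies the same capacity test per layer through the key_cache list of the newer cache object. A toy stand-in (FakeDynamicCache below is a hypothetical minimal mock, not the real transformers class) shows the updated check rejecting an 8-token append that a one-token check would still accept:

import torch

class FakeDynamicCache:
    # minimal stand-in for the transformers>=4.36 cache: a per-layer list of key tensors
    def __init__(self, key_cache):
        self.key_cache = key_cache

buffer_k = torch.empty(1, 32, 260, 128)                # capacity: 260 tokens
past = FakeDynamicCache([buffer_k[:, :, :256, :]])     # layer 0 currently holds 256 tokens

def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
    # same expression as the utils.py hunk above
    return past_key_value is not None and len(past_key_value.key_cache) > idx and \
        past_key_value.key_cache[idx].stride()[1] >= \
        (past_key_value.key_cache[idx].size(2) + seq_len) * \
        past_key_value.key_cache[idx].size(3)

print(is_enough_kv_cache_room_4_36(past, idx=0, seq_len=1))   # True:  256 + 1 <= 260
print(is_enough_kv_cache_room_4_36(past, idx=0, seq_len=8))   # False: 256 + 8 > 260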