From 9e18ea187f6c9a6e11236941f2c4046f42931724 Mon Sep 17 00:00:00 2001
From: Qiyuan Gong
Date: Fri, 26 Jan 2024 17:30:08 +0800
Subject: [PATCH] [LLM] Avoid KV Cache OOM when seq len is larger than 1 (#10006)

* Avoid OOM during multi-round streaming chat with kv cache
* For llama-like kv cache, i.e., [bs, n_head, seq_len, head_dim], use is_enough_kv_cache_room_4_31.
* Other models need to compare kv cache size with kv_len.
---
 .../src/bigdl/llm/transformers/models/aquila.py    |  7 +++++--
 .../src/bigdl/llm/transformers/models/baichuan.py  | 11 ++++++++---
 .../src/bigdl/llm/transformers/models/baichuan2.py | 11 ++++++++---
 .../llm/src/bigdl/llm/transformers/models/bloom.py |  2 +-
 .../src/bigdl/llm/transformers/models/chatglm.py   |  2 +-
 .../src/bigdl/llm/transformers/models/chatglm2.py  |  2 +-
 .../bigdl/llm/transformers/models/chatglm2_32k.py  |  2 +-
 .../src/bigdl/llm/transformers/models/falcon.py    |  6 +++---
 .../llm/src/bigdl/llm/transformers/models/gptj.py  |  3 +--
 .../src/bigdl/llm/transformers/models/gptneox.py   |  7 +++++--
 .../src/bigdl/llm/transformers/models/internlm.py  |  7 +++++--
 .../src/bigdl/llm/transformers/models/mistral.py   |  4 ++--
 .../llm/src/bigdl/llm/transformers/models/mpt.py   |  2 +-
 .../src/bigdl/llm/transformers/models/qwen_vl.py   |  2 +-
 .../llm/src/bigdl/llm/transformers/models/utils.py | 14 +++++++++-----
 15 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/aquila.py b/python/llm/src/bigdl/llm/transformers/models/aquila.py
index 417b20e7..68ca7a01 100644
--- a/python/llm/src/bigdl/llm/transformers/models/aquila.py
+++ b/python/llm/src/bigdl/llm/transformers/models/aquila.py
@@ -42,7 +42,8 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error
@@ -72,7 +73,9 @@ def aquila_attention_forward(
         .transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -89,7 +92,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,  # Support GQA
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
index 2bd8f550..c20ed3d2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -26,7 +26,8 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -55,7 +56,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -76,7 +79,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -163,7 +166,9 @@ def baichuan_attention_forward_13b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]

     # if past_key_value is not None:
@@ -174,7 +179,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 21bf40db..0bbf0038 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -24,7 +24,8 @@ import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -104,7 +105,9 @@ def baichuan_attention_forward_7b(
     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -125,7 +128,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
@@ -216,7 +219,9 @@ def baichuan_attention_forward_13b(
     )

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]

     # if past_key_value is not None:
@@ -227,7 +232,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index ff9d4b6a..d8d5aab0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -121,7 +121,7 @@ def bloom_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
index 9285df44..4adcc722 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -66,7 +66,7 @@ def attention_fn(
         cache_k = cache_k.permute(1, 2, 0, 3)
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 8944d807..67719ad2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -419,7 +419,7 @@ def chatglm2_attention_forward_8eb45c(
         cache_v = cache_v.permute(1, 2, 0, 3)

         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
                 # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2_32k.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2_32k.py
index d7842d90..d85861f2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2_32k.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2_32k.py
@@ -150,7 +150,7 @@ def chatglm2_32k_attention_forward(
         cache_v = cache_v.permute(1, 2, 0, 3)

         past_length = cache_k.size(2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/falcon.py b/python/llm/src/bigdl/llm/transformers/models/falcon.py
index d5fb455c..cdcfcefa 100644
--- a/python/llm/src/bigdl/llm/transformers/models/falcon.py
+++ b/python/llm/src/bigdl/llm/transformers/models/falcon.py
@@ -97,7 +97,7 @@ def rw_attention_forward_7b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -276,7 +276,7 @@ def rw_attention_forward_40b(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
@@ -450,7 +450,7 @@ def falcon_attention_forward(
         # reuse k, v, self_attention
         cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_length * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index 6a4e0aff..7dfee106 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -142,8 +142,7 @@ def gptj_attention_forward(
         cache_k = cache_k.permute(0, 2, 1, 3)
         cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
-
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptneox.py b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
index 2a47d6e9..ca29845a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptneox.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
@@ -34,7 +34,8 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu


@@ -79,7 +80,9 @@ def gptneox_attention_forward(
     # Compute token offset for rotary embeddings (when decoding)
     seq_len = key.shape[-2]
+    enough_kv_room = True
     if has_layer_past:
+        enough_kv_room = is_enough_kv_cache_room_4_31(layer_past, seq_len=seq_len)
         seq_len += layer_past[0].shape[-2]

     use_fuse_rope = query.device.type == "xpu"
@@ -101,7 +104,7 @@
     if has_layer_past:
         past_key = layer_past[0]
         past_value = layer_past[1]
-        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+        if not enough_kv_room:
             # allocate new
             new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
diff --git a/python/llm/src/bigdl/llm/transformers/models/internlm.py b/python/llm/src/bigdl/llm/transformers/models/internlm.py
index b52259fd..53475ed9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/internlm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/internlm.py
@@ -43,7 +43,8 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -73,7 +74,9 @@ def internlm_attention_forward(
         .transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
+    enough_kv_room = True
     if past_key_value is not None:
+        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
         kv_seq_len += past_key_value[0].shape[-2]
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
@@ -95,7 +98,7 @@
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if not enough_kv_room:
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 bsz,
diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 3c653c2a..9f0c5fb6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -43,9 +43,9 @@ from torch import nn
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention
diff --git a/python/llm/src/bigdl/llm/transformers/models/mpt.py b/python/llm/src/bigdl/llm/transformers/models/mpt.py
index fd8e28b7..7b32a4bc 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mpt.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mpt.py
@@ -77,7 +77,7 @@ def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
         cache_k = past_key_value[0].transpose(2, 3)
         cache_v = past_key_value[1]
         kv_seq_len += cache_k.shape[-2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        kv_n_heads,  # Support GQA
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py b/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
index 83cf5870..82fcf90b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
@@ -89,7 +89,7 @@ def qwen_attention_forward_vl(
         # value = torch.cat((past_value, value), dim=1)
         cache_k = layer_past[0].transpose(1, 2)
         cache_v = layer_past[1].transpose(1, 2)
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index b0aefa84..67f4e6e1 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -60,7 +60,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
     new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states

     return new_cache_k, new_cache_v

@@ -194,17 +194,21 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):

 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > \
-        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
+        past_key_value.key_cache[idx].stride()[1] >= \
+        (past_key_value.key_cache[idx].size(2) + seq_len) * \
         past_key_value.key_cache[idx].size(3)


 def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
+    # seq_len for current seq len
+    # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > \
-        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] >= \
+        (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)


 def use_flash_attention(query, key):
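
To make the new condition concrete, the following minimal sketch applies the updated is_enough_kv_cache_room_4_31 check to a llama-style cache of shape [bs, n_head, seq_len, head_dim] that was pre-allocated with extra slots. The helper body is copied from the patched utils.py; the batch size, head count, head dim, and buffer lengths are made-up numbers chosen only for illustration.

    import torch

    def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
        # Same check as the patched utils.py: stride()[1] is the number of
        # elements reserved per head in the pre-allocated buffer, so the cache
        # may be reused only if current length + incoming seq_len still fits.
        return past_key_value is not None and \
            past_key_value[0].stride()[1] >= \
            (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)

    # A cache with room for 48 tokens per head, of which 40 are already filled.
    bs, n_head, head_dim = 1, 32, 128
    buf_k = torch.empty(bs, n_head, 48, head_dim)
    buf_v = torch.empty(bs, n_head, 48, head_dim)
    cache_k = buf_k.as_strided((bs, n_head, 40, head_dim), buf_k.stride())
    cache_v = buf_v.as_strided((bs, n_head, 40, head_dim), buf_v.stride())

    print(is_enough_kv_cache_room_4_31((cache_k, cache_v), seq_len=1))   # True: 41 * 128 <= 6144
    print(is_enough_kv_cache_room_4_31((cache_k, cache_v), seq_len=16))  # False: 56 * 128 > 6144

The old per-model condition, cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3), re-allocated only once the buffer was completely full, so an incoming chunk longer than one token (as in multi-round streaming chat) could be appended into a buffer with too little room. Models whose caches are not laid out as [bs, n_head, seq_len, head_dim] get the same guarantee in this patch by comparing the reserved room against the total kv length instead, e.g. cache_k.stride()[1] < kv_seq_len * cache_k.size(3).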