LLM: refactor kv cache (#9030)

* refactor utils

* address code review comments; update all models

* small fix
Ruonan Wang 2023-09-21 21:28:03 +08:00 committed by GitHub
parent 868511cf02
commit b943d73844
12 changed files with 93 additions and 108 deletions
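
In short: the single create_kv_cache helper in bigdl/llm/transformers/models/utils.py is split into init_kv_cache (first allocation) and extend_kv_cache (re-allocation of a larger block, which first empties the XPU allocator cache), and every model's attention forward switches to the new names, dropping the per-model `if device.type == 'xpu': torch.xpu.empty_cache()` boilerplate. The two GPU examples also stop passing optimize_model=False to from_pretrained. A condensed sketch of the resulting helper surface (signatures taken from the utils.py hunk at the bottom of this diff; init_kv_cache's body is elided here):

import torch

def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # pre-allocate (batch_size, num_heads, max_length, head_dim) storage and return
    # views sized to current_length; body elided, see the utils.py hunk below
    ...

def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # empty cache to reduce gpu memory, then delegate the larger allocation to init_kv_cache
    if device.type == 'xpu':
        torch.xpu.empty_cache()
    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)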

View file

@@ -44,7 +44,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')

View file

@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
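
With optimize_model=False removed, both example scripts fall back to the library default for that flag, so the loaded model goes through the normal optimization path that the refactored KV-cache helpers serve. For reference, a sketch of the resulting load (model_path is a placeholder; the ipex import mirrors what the GPU examples typically do):

import intel_extension_for_pytorch as ipex  # needed for the 'xpu' device
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "path/to/model"  # placeholder; use the example's model path
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,       # convert relevant layers to INT4
                                             trust_remote_code=True,
                                             use_cache=True)          # enable the KV cache this PR refactors
model = model.to('xpu')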

View file

@@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -70,10 +70,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -89,7 +87,7 @@ def baichuan_attention_forward_7b(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                            self.num_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
@@ -170,10 +168,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -189,7 +185,7 @@ def baichuan_attention_forward_13b(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                            self.num_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
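
Every model file in this diff follows the same call-site pattern, shown for baichuan above and repeated below for baichuan2, bloom, chatglm, chatglm2, falcon, gptj, gptneox and llama: on the first pass init_kv_cache pre-allocates the prompt length plus KV_CACHE_ALLOC_BLOCK_LENGTH of headroom; on later passes new tokens are appended in place until the stride check shows the block is full, at which point extend_kv_cache allocates a larger block (emptying the XPU cache first). The condensed, hypothetical helper below illustrates that flow; the copy-back assignments and the exact new max length are assumed from context, since they fall outside the hunks shown, and the use_cache guard is omitted:

import torch
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def update_kv_cache(past_key_value, key_states, value_states,
                    bsz, num_heads, head_dim, kv_seq_len, device):
    # Not a function from this repo -- a condensed sketch of the inlined per-model logic.
    if past_key_value is not None:
        cache_k, cache_v = past_key_value
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # pre-allocated storage is full: allocate a larger block
            # (extend_kv_cache empties the XPU allocator cache before allocating)
            new_cache_k, new_cache_v = extend_kv_cache(bsz, num_heads, head_dim,
                                                       cache_k.size(2),
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,  # assumed
                                                       dtype=cache_k.dtype, device=device)
            new_cache_k[:] = cache_k   # copy-back assumed; outside the hunks shown
            new_cache_v[:] = cache_v
            cache_k, cache_v = new_cache_k, new_cache_v
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
    else:
        # first forward pass: allocate the prompt length plus one block of headroom
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(bsz, num_heads, head_dim, kv_seq_len,
                                                         max_cache_length,
                                                         dtype=key_states.dtype, device=device)
        new_key_states[:] = key_states   # copy assumed; outside the hunks shown
        new_value_states[:] = value_states
        key_states, value_states = new_key_states, new_value_states
    return key_states, value_states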

View file

@@ -26,7 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from transformers.utils import logging, ContextManagers

 logger = logging.get_logger(__name__)
@@ -82,10 +82,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -101,7 +99,7 @@ def baichuan_attention_forward_7b(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                            self.num_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
@@ -182,7 +180,7 @@ def baichuan_attention_forward_13b(
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
-           new_cache_k, new_cache_v = create_kv_cache(bsz,
+           new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                       self.num_heads,
                                                       self.head_dim,
                                                       cache_k.size(2),
@@ -198,7 +196,7 @@ def baichuan_attention_forward_13b(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                            self.num_heads,
                                                            self.head_dim,
                                                            kv_seq_len,

View file

@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,10 +107,8 @@ def bloom_attention_forward(
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -128,7 +126,7 @@ def bloom_attention_forward(
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,

View file

@@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 def rotate_half(x):
@@ -67,10 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -84,7 +82,7 @@ def attention_fn(
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
                                                  self.hidden_size_per_attention_head, cur_length,
                                                  max_cache_length,
                                                  dtype=query_layer.dtype, device=device)

View file

@@ -20,7 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -151,10 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -172,7 +170,7 @@ def chatglm2_attention_forward_8eb45c(
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
                                                  self.hidden_size_per_attention_head, cur_length,
                                                  max_cache_length,
                                                  dtype=query_layer.dtype, device=device)

View file

@@ -38,7 +38,7 @@ from typing import Optional, Tuple
 import torch
 from torch.nn import functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -98,10 +98,8 @@ def rw_attention_forward_7b(
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_kv,
                 self.head_dim,
@@ -119,7 +117,7 @@ def rw_attention_forward_7b(
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_kv,
             self.head_dim,
@@ -280,7 +278,7 @@ def rw_attention_forward_40b(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -298,7 +296,7 @@ def rw_attention_forward_40b(
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -454,7 +452,7 @@ def falcon_attention_forward(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -472,7 +470,7 @@ def falcon_attention_forward(
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,

View file

@@ -19,8 +19,8 @@
 import torch
 from typing import Optional, Tuple, Union
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-    apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    apply_rotary_pos_emb, append_kv_cache
 from transformers.utils.import_utils import is_torch_fx_proxy
@@ -144,9 +144,7 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
                                                        past_length,
@@ -160,7 +158,7 @@ def gptj_attention_forward(
         key, value = append_kv_cache(cache_k, cache_v, key, value)
     elif use_cache:
-        key_cache, value_cache = create_kv_cache(batch_size,
+        key_cache, value_cache = init_kv_cache(batch_size,
                                                  self.num_attention_heads,
                                                  self.head_dim,
                                                  kv_seq_len,

View file

@@ -34,7 +34,7 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -90,10 +90,8 @@ def gptneox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_past_key, new_past_value = create_kv_cache(bsz,
+            new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
                                                            self.head_size,
                                                            past_key.size(2),
@@ -108,7 +106,7 @@ def gptneox_attention_forward(
         key, value = append_kv_cache(past_key, past_value, key, value)
     elif use_cache:
         max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key, new_value = create_kv_cache(bsz,
+        new_key, new_value = init_kv_cache(bsz,
                                              self.num_attention_heads,
                                              self.head_size,
                                              seq_len,

View file

@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
@@ -112,10 +112,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -131,7 +129,7 @@ def llama_attention_forward_4_31(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                            self.num_key_value_heads,
                                                            self.head_dim,
                                                            kv_seq_len,

View file

@@ -18,7 +18,7 @@ import torch
 from bigdl.llm.utils.common import invalidInputError


-def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
@@ -37,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
     return key_cache, value_cache


+def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    # empty cache to reduce gpu memory
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
+    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
+
+
 def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_size = (cache_k.size(0),
                 cache_k.size(1),
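
A minimal usage sketch of the three helpers defined above, runnable on CPU. Shapes are illustrative, and it assumes init_kv_cache returns current_length-sized views backed by max_length storage and that append_kv_cache returns the grown views, as the definitions above imply:

import torch
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

bsz, num_heads, head_dim = 1, 4, 8

# prompt step: reserve room for 5 tokens plus 256 of headroom, then copy the prompt K/V in
k = torch.randn(bsz, num_heads, 5, head_dim)
v = torch.randn(bsz, num_heads, 5, head_dim)
cache_k, cache_v = init_kv_cache(bsz, num_heads, head_dim,
                                 current_length=5, max_length=5 + 256,
                                 dtype=k.dtype, device=k.device)
cache_k[:] = k
cache_v[:] = v

# decode step: append one new K/V slice into the already reserved storage
new_k = torch.randn(bsz, num_heads, 1, head_dim)
new_v = torch.randn(bsz, num_heads, 1, head_dim)
cache_k, cache_v = append_kv_cache(cache_k, cache_v, new_k, new_v)
print(cache_k.shape)  # expected: torch.Size([1, 4, 6, 8])

# once the reserved block runs out, callers switch to extend_kv_cache, which empties the
# XPU allocator cache first (a no-op on CPU) and then allocates a larger block
bigger_k, bigger_v = extend_kv_cache(bsz, num_heads, head_dim,
                                     current_length=cache_k.size(2),
                                     max_length=cache_k.size(2) + 256,
                                     dtype=cache_k.dtype, device=cache_k.device)
bigger_k[:] = cache_k
bigger_v[:] = cache_v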