LLM: refactor kv cache (#9030)

* refactor utils

* meet code review; update all models

* small fix
Ruonan Wang 2023-09-21 21:28:03 +08:00 committed by GitHub
parent 868511cf02
commit b943d73844
12 changed files with 93 additions and 108 deletions
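
In short, this refactor splits the per-model helper create_kv_cache into init_kv_cache (first allocation of a pre-sized cache) and extend_kv_cache (re-allocation once the pre-allocated buffer is full, which now also empties the XPU cache internally instead of every model doing it inline). Below is a rough sketch of the caller-side pattern the model files in this diff converge on; the helper name update_kv_cache and the tensor shapes are illustrative only, not code from the repository:

import torch
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def update_kv_cache(past_key_value, key_states, value_states,
                    bsz, num_heads, head_dim, kv_seq_len, device):
    # Illustrative only: each model file inlines this logic in its attention forward.
    if past_key_value is not None:
        cache_k, cache_v = past_key_value
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # Pre-allocated buffer is full: grow it. extend_kv_cache empties the
            # XPU cache first, a step each model used to perform inline.
            new_cache_k, new_cache_v = extend_kv_cache(
                bsz, num_heads, head_dim, cache_k.size(2),
                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=key_states.dtype, device=device)
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k, cache_v = new_cache_k, new_cache_v
        return append_kv_cache(cache_k, cache_v, key_states, value_states)
    # First forward pass: allocate a fresh cache with headroom and fill it.
    key_cache, value_cache = init_kv_cache(
        bsz, num_heads, head_dim, kv_seq_len,
        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
        dtype=key_states.dtype, device=device)
    key_cache[:] = key_states
    value_cache[:] = value_states
    return key_cache, value_cache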

View file

@@ -44,7 +44,6 @@ if __name__ == '__main__':
# which convert the relevant layers in the model into INT4 format
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_4bit=True,
- optimize_model=False,
trust_remote_code=True,
use_cache=True)
model = model.to('xpu')

View file

@@ -42,7 +42,6 @@ if __name__ == '__main__':
# which convert the relevant layers in the model into INT4 format
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_4bit=True,
- optimize_model=False,
trust_remote_code=True,
use_cache=True)
model = model.to('xpu')

View file

@@ -26,7 +26,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -70,10 +70,8 @@ def baichuan_attention_forward_7b(
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(bsz,
+ new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
@@ -89,7 +87,7 @@ def baichuan_attention_forward_7b(
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(bsz,
+ new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
@@ -170,10 +168,8 @@ def baichuan_attention_forward_13b(
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(bsz,
+ new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
@@ -189,7 +185,7 @@ def baichuan_attention_forward_13b(
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(bsz,
+ new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
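
The check cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3) used throughout these attention forwards detects when the pre-allocated cache is full: init_kv_cache allocates max_length rows and, judging by this check in the callers, returns a view narrowed to current_length, so the stride along the head dimension stays at max_length * head_dim while size(2) * size(3) is only current_length * head_dim. A minimal, self-contained sketch with made-up sizes:

import torch

bsz, heads, head_dim = 1, 2, 4
max_length, current_length = 8, 5

# What init_kv_cache effectively returns: a max_length-row buffer narrowed to current_length rows.
storage = torch.empty(bsz, heads, max_length, head_dim)
cache_k = storage.as_strided((bsz, heads, current_length, head_dim), storage.stride())

print(cache_k.stride()[1])                                        # 32 == max_length * head_dim
print(cache_k.size(2) * cache_k.size(3))                          # 20 == current_length * head_dim
print(cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3))   # False -> free rows remain

full_k = storage.as_strided((bsz, heads, max_length, head_dim), storage.stride())
print(full_k.stride()[1] <= full_k.size(2) * full_k.size(3))      # True -> time to call extend_kv_cache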

View file

@@ -26,7 +26,7 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from transformers.utils import logging, ContextManagers
logger = logging.get_logger(__name__)
@@ -82,10 +82,8 @@ def baichuan_attention_forward_7b(
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(bsz,
+ new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
@@ -101,7 +99,7 @@ def baichuan_attention_forward_7b(
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(bsz,
+ new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
@@ -182,7 +180,7 @@ def baichuan_attention_forward_13b(
if device.type == 'xpu':
torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(bsz,
+ new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
@@ -198,7 +196,7 @@ def baichuan_attention_forward_13b(
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(bsz,
+ new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,

View file

@@ -37,7 +37,7 @@ from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch.nn import functional as F
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,10 +107,8 @@ def bloom_attention_forward(
cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(
+ new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
self.num_heads,
self.head_dim,
@@ -128,7 +126,7 @@ def bloom_attention_forward(
elif use_cache:
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(
+ new_key_states, new_value_states = init_kv_cache(
batch_size,
self.num_heads,
self.head_dim,

View file

@@ -22,7 +22,7 @@ import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from typing import Optional, Tuple
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
def rotate_half(x):
@@ -67,10 +67,8 @@ def attention_fn(
cache_v = cache_v.permute(1, 2, 0, 3)
past_length = cache_k.size(2)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_cache_k, new_cache_v = create_kv_cache(batch_size,
+ new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
past_length,
@@ -84,7 +82,7 @@ def attention_fn(
elif use_cache:
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+ KV_CACHE_ALLOC_BLOCK_LENGTH
- key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+ key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, cur_length,
max_cache_length,
dtype=query_layer.dtype, device=device)

View file

@@ -20,7 +20,7 @@
import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -151,10 +151,8 @@ def chatglm2_attention_forward_8eb45c(
past_length = cache_k.size(2)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_cache_k, new_cache_v = create_kv_cache(batch_size,
+ new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
past_length,
@@ -172,7 +170,7 @@ def chatglm2_attention_forward_8eb45c(
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+ KV_CACHE_ALLOC_BLOCK_LENGTH
- key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+ key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, cur_length,
max_cache_length,
dtype=query_layer.dtype, device=device)
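
The ChatGLM paths size the freshly initialized cache as max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) + KV_CACHE_ALLOC_BLOCK_LENGTH. A quick worked example; KV_CACHE_ALLOC_MIN_LENGTH is defined elsewhere in the file and not visible in these hunks, so 512 is an assumed value:

KV_CACHE_ALLOC_BLOCK_LENGTH = 256
KV_CACHE_ALLOC_MIN_LENGTH = 512   # assumed for illustration, not shown in the hunk
cur_length = 33                   # prompt length at prefill time
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) + KV_CACHE_ALLOC_BLOCK_LENGTH
print(max_cache_length)           # 768 slots: prompt plus headroom before extend_kv_cache is needed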

View file

@@ -38,7 +38,7 @@ from typing import Optional, Tuple
import torch
from torch.nn import functional as F
from bigdl.llm.utils.common import invalidInputError
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -98,10 +98,8 @@ def rw_attention_forward_7b(
cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(
+ new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
self.num_kv,
self.head_dim,
@@ -119,7 +117,7 @@ def rw_attention_forward_7b(
elif use_cache:
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(
+ new_key_states, new_value_states = init_kv_cache(
batch_size,
self.num_kv,
self.head_dim,
@@ -280,7 +278,7 @@ def rw_attention_forward_40b(
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(
+ new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
self.num_heads,
self.head_dim,
@@ -298,7 +296,7 @@ def rw_attention_forward_40b(
elif use_cache:
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(
+ new_key_states, new_value_states = init_kv_cache(
batch_size,
self.num_heads,
self.head_dim,
@@ -454,7 +452,7 @@ def falcon_attention_forward(
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(
+ new_cache_k, new_cache_v = extend_kv_cache(
batch_size,
self.num_heads,
self.head_dim,
@@ -472,7 +470,7 @@ def falcon_attention_forward(
elif use_cache:
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(
+ new_key_states, new_value_states = init_kv_cache(
batch_size,
self.num_heads,
self.head_dim,

View file

@@ -19,8 +19,8 @@
import torch
from typing import Optional, Tuple, Union
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-     apply_rotary_pos_emb
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+     apply_rotary_pos_emb, append_kv_cache
from transformers.utils.import_utils import is_torch_fx_proxy
@@ -144,9 +144,7 @@ def gptj_attention_forward(
past_length = cache_k.size(2)
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
- new_cache_k, new_cache_v = create_kv_cache(batch_size,
+ new_cache_k, new_cache_v = extend_kv_cache(batch_size,
self.num_attention_heads,
self.head_dim,
past_length,
@@ -160,7 +158,7 @@ def gptj_attention_forward(
key, value = append_kv_cache(cache_k, cache_v, key, value)
elif use_cache:
- key_cache, value_cache = create_kv_cache(batch_size,
+ key_cache, value_cache = init_kv_cache(batch_size,
self.num_attention_heads,
self.head_dim,
kv_seq_len,

View file

@@ -34,7 +34,7 @@
import torch
from typing import Optional, Tuple
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -90,10 +90,8 @@ def gptneox_attention_forward(
past_key = layer_past[0]
past_value = layer_past[1]
if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_past_key, new_past_value = create_kv_cache(bsz,
+ new_past_key, new_past_value = extend_kv_cache(bsz,
self.num_attention_heads,
self.head_size,
past_key.size(2),
@@ -108,7 +106,7 @@ def gptneox_attention_forward(
key, value = append_kv_cache(past_key, past_value, key, value)
elif use_cache:
max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key, new_value = create_kv_cache(bsz,
+ new_key, new_value = init_kv_cache(bsz,
self.num_attention_heads,
self.head_size,
seq_len,

View file

@@ -37,7 +37,7 @@ from typing import Optional, Tuple
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
- from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
@@ -112,10 +112,8 @@ def llama_attention_forward_4_31(
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
- if device.type == 'xpu':
-     torch.xpu.empty_cache()
# allocate new
- new_cache_k, new_cache_v = create_kv_cache(bsz,
+ new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
@@ -131,7 +129,7 @@ def llama_attention_forward_4_31(
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
- new_key_states, new_value_states = create_kv_cache(bsz,
+ new_key_states, new_value_states = init_kv_cache(bsz,
self.num_key_value_heads,
self.head_dim,
kv_seq_len,
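
The llama path allocates the cache with self.num_key_value_heads rather than self.num_heads (the "# Support GQA" comment above): with grouped-query attention the cache only has to hold one key/value head per query-head group, and the heads are repeated at attention time. A minimal sketch of that idea; repeat_kv below is an illustrative stand-in, not the function used in the repository:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq, head_dim) -> (bsz, num_kv_heads * n_rep, seq, head_dim)
    bsz, kv_heads, seq, head_dim = x.shape
    if n_rep == 1:
        return x
    return (x[:, :, None, :, :]
            .expand(bsz, kv_heads, n_rep, seq, head_dim)
            .reshape(bsz, kv_heads * n_rep, seq, head_dim))

# 32 query heads sharing 8 cached key/value heads: the cache stores only 8 heads per token.
cached_k = torch.randn(1, 8, 16, 128)
print(repeat_kv(cached_k, 32 // 8).shape)   # torch.Size([1, 32, 16, 128])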

View file

@@ -18,7 +18,7 @@ import torch
from bigdl.llm.utils.common import invalidInputError
- def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+ def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
key_cache_storage = torch.empty(batch_size, num_heads,
max_length, head_dim,
dtype=dtype, device=device)
@@ -37,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
return key_cache, value_cache
+ def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+     # empty cache to reduce gpu memory
+     if device.type == 'xpu':
+         torch.xpu.empty_cache()
+     return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
def append_kv_cache(cache_k, cache_v, key_states, value_states):
new_size = (cache_k.size(0),
cache_k.size(1),
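
The last hunk is truncated here. Judging by how the callers use it, append_kv_cache plausibly widens the narrowed view by the incoming sequence length and copies the new key/value states into the reserved rows, roughly as sketched below; this is an inference from the surrounding code, not the exact body from this commit:

def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Grow the view over the same storage by key_states.size(2) rows...
    new_size = (cache_k.size(0),
                cache_k.size(1),
                cache_k.size(2) + key_states.size(2),
                cache_k.size(3))
    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
    # ...and write the new states into the freshly exposed rows.
    new_cache_k[:, :, cache_k.size(2):, :] = key_states
    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
    new_cache_v[:, :, cache_v.size(2):, :] = value_states
    return new_cache_k, new_cache_v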