LLM: refactor kv cache (#9030)
* refactor utils * meet code review; update all models * small fix
This commit is contained in:
parent
868511cf02
commit
b943d73844
12 changed files with 93 additions and 108 deletions
|
|
@ -44,7 +44,6 @@ if __name__ == '__main__':
|
|||
# which convert the relevant layers in the model into INT4 format
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||
load_in_4bit=True,
|
||||
optimize_model=False,
|
||||
trust_remote_code=True,
|
||||
use_cache=True)
|
||||
model = model.to('xpu')
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ if __name__ == '__main__':
|
|||
# which convert the relevant layers in the model into INT4 format
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||
load_in_4bit=True,
|
||||
optimize_model=False,
|
||||
trust_remote_code=True,
|
||||
use_cache=True)
|
||||
model = model.to('xpu')
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import torch.utils.checkpoint
|
|||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -70,10 +70,8 @@ def baichuan_attention_forward_7b(
|
|||
cache_k = past_key_value[0]
|
||||
cache_v = past_key_value[1]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
|
|
@ -89,7 +87,7 @@ def baichuan_attention_forward_7b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
@ -170,10 +168,8 @@ def baichuan_attention_forward_13b(
|
|||
cache_k = past_key_value[0]
|
||||
cache_v = past_key_value[1]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
|
|
@ -189,7 +185,7 @@ def baichuan_attention_forward_13b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
|
||||
from transformers.utils import logging, ContextManagers
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
@ -82,10 +82,8 @@ def baichuan_attention_forward_7b(
|
|||
cache_k = past_key_value[0]
|
||||
cache_v = past_key_value[1]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
|
|
@ -101,7 +99,7 @@ def baichuan_attention_forward_7b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
@ -182,7 +180,7 @@ def baichuan_attention_forward_13b(
|
|||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
|
|
@ -198,7 +196,7 @@ def baichuan_attention_forward_13b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from typing import Optional, Tuple
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import functional as F
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -107,10 +107,8 @@ def bloom_attention_forward(
|
|||
cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
|
||||
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(
|
||||
new_cache_k, new_cache_v = extend_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
@ -128,7 +126,7 @@ def bloom_attention_forward(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(
|
||||
new_key_states, new_value_states = init_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import torch
|
|||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
|
|
@ -67,10 +67,8 @@ def attention_fn(
|
|||
cache_v = cache_v.permute(1, 2, 0, 3)
|
||||
past_length = cache_k.size(2)
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_cache_k, new_cache_v = create_kv_cache(batch_size,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
past_length,
|
||||
|
|
@ -84,7 +82,7 @@ def attention_fn(
|
|||
elif use_cache:
|
||||
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
||||
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||
key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head, cur_length,
|
||||
max_cache_length,
|
||||
dtype=query_layer.dtype, device=device)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@
|
|||
import torch
|
||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -151,10 +151,8 @@ def chatglm2_attention_forward_8eb45c(
|
|||
past_length = cache_k.size(2)
|
||||
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_cache_k, new_cache_v = create_kv_cache(batch_size,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
past_length,
|
||||
|
|
@ -172,7 +170,7 @@ def chatglm2_attention_forward_8eb45c(
|
|||
|
||||
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
||||
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||
key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head, cur_length,
|
||||
max_cache_length,
|
||||
dtype=query_layer.dtype, device=device)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from typing import Optional, Tuple
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -98,10 +98,8 @@ def rw_attention_forward_7b(
|
|||
cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
|
||||
cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(
|
||||
new_cache_k, new_cache_v = extend_kv_cache(
|
||||
batch_size,
|
||||
self.num_kv,
|
||||
self.head_dim,
|
||||
|
|
@ -119,7 +117,7 @@ def rw_attention_forward_7b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(
|
||||
new_key_states, new_value_states = init_kv_cache(
|
||||
batch_size,
|
||||
self.num_kv,
|
||||
self.head_dim,
|
||||
|
|
@ -280,7 +278,7 @@ def rw_attention_forward_40b(
|
|||
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(
|
||||
new_cache_k, new_cache_v = extend_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
@ -298,7 +296,7 @@ def rw_attention_forward_40b(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(
|
||||
new_key_states, new_value_states = init_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
@ -454,7 +452,7 @@ def falcon_attention_forward(
|
|||
cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(
|
||||
new_cache_k, new_cache_v = extend_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
@ -472,7 +470,7 @@ def falcon_attention_forward(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(
|
||||
new_key_states, new_value_states = init_kv_cache(
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@
|
|||
|
||||
import torch
|
||||
from typing import Optional, Tuple, Union
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
|
||||
apply_rotary_pos_emb
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
||||
apply_rotary_pos_emb, append_kv_cache
|
||||
from transformers.utils.import_utils import is_torch_fx_proxy
|
||||
|
||||
|
||||
|
|
@ -144,9 +144,7 @@ def gptj_attention_forward(
|
|||
past_length = cache_k.size(2)
|
||||
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
new_cache_k, new_cache_v = create_kv_cache(batch_size,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
|
||||
self.num_attention_heads,
|
||||
self.head_dim,
|
||||
past_length,
|
||||
|
|
@ -160,7 +158,7 @@ def gptj_attention_forward(
|
|||
key, value = append_kv_cache(cache_k, cache_v, key, value)
|
||||
|
||||
elif use_cache:
|
||||
key_cache, value_cache = create_kv_cache(batch_size,
|
||||
key_cache, value_cache = init_kv_cache(batch_size,
|
||||
self.num_attention_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@
|
|||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -90,10 +90,8 @@ def gptneox_attention_forward(
|
|||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_past_key, new_past_value = create_kv_cache(bsz,
|
||||
new_past_key, new_past_value = extend_kv_cache(bsz,
|
||||
self.num_attention_heads,
|
||||
self.head_size,
|
||||
past_key.size(2),
|
||||
|
|
@ -108,7 +106,7 @@ def gptneox_attention_forward(
|
|||
key, value = append_kv_cache(past_key, past_value, key, value)
|
||||
elif use_cache:
|
||||
max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key, new_value = create_kv_cache(bsz,
|
||||
new_key, new_value = init_kv_cache(bsz,
|
||||
self.num_attention_heads,
|
||||
self.head_size,
|
||||
seq_len,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from typing import Optional, Tuple
|
|||
import math
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
|
||||
|
||||
|
||||
|
|
@ -112,10 +112,8 @@ def llama_attention_forward_4_31(
|
|||
cache_k = past_key_value[0]
|
||||
cache_v = past_key_value[1]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
|
|
@ -131,7 +129,7 @@ def llama_attention_forward_4_31(
|
|||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_key_value_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import torch
|
|||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
|
||||
def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
|
||||
key_cache_storage = torch.empty(batch_size, num_heads,
|
||||
max_length, head_dim,
|
||||
dtype=dtype, device=device)
|
||||
|
|
@ -37,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
|
|||
return key_cache, value_cache
|
||||
|
||||
|
||||
def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
|
||||
# empty cache to reduce gpu memory
|
||||
if device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
|
||||
|
||||
|
||||
def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
||||
new_size = (cache_k.size(0),
|
||||
cache_k.size(1),
|
||||
|
|
|
|||
Loading…
Reference in a new issue