LLM: refactor kv cache (#9030)
* refactor utils
* meet code review; update all models
* small fix

parent 868511cf02
commit b943d73844
12 changed files with 93 additions and 108 deletions
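The heart of the change is in bigdl/llm/transformers/models/utils.py: the old create_kv_cache helper is renamed to init_kv_cache, and a new extend_kv_cache wrapper empties the XPU cache before re-allocating, so the per-model torch.xpu.empty_cache() calls can be dropped. A minimal sketch of the resulting helpers, pieced together from the hunks in this commit (the value-cache allocation and the append_kv_cache body are only partly visible in the diff, so those details are best-effort assumptions, not the exact file contents):

import torch

def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # pre-allocate max_length slots, but hand back strided views of only current_length tokens
    key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                    dtype=dtype, device=device)
    value_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                      dtype=dtype, device=device)
    key_cache = key_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
                                              key_cache_storage.stride(), storage_offset=0)
    value_cache = value_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
                                                 value_cache_storage.stride(), storage_offset=0)
    return key_cache, value_cache

def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # empty cache to reduce gpu memory before the fresh allocation
    if device.type == 'xpu':
        torch.xpu.empty_cache()
    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)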
@@ -44,7 +44,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
@@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -70,10 +70,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -89,13 +87,13 @@ def baichuan_attention_forward_7b(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                           self.num_heads,
                                                           self.head_dim,
                                                           kv_seq_len,
                                                           max_cache_length,
                                                           dtype=key_states.dtype,
                                                           device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -170,10 +168,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -189,13 +185,13 @@ def baichuan_attention_forward_13b(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                           self.num_heads,
                                                           self.head_dim,
                                                           kv_seq_len,
                                                           max_cache_length,
                                                           dtype=key_states.dtype,
                                                           device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
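A note on the check that gates extend_kv_cache in the hunks above and below: a cache built by init_kv_cache is a view into a larger pre-allocated buffer, so its per-head stride (stride()[1]) stays larger than seq_len * head_dim until the reserved slots are used up; once cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3), there is no room left to append in place. A small standalone illustration (the shapes below are made up for the example):

import torch

bsz, heads, seq, head_dim, max_len = 1, 32, 132, 128, 132 + 256

# a cache backed by a larger buffer (as init_kv_cache builds it) still has room to grow
storage = torch.empty(bsz, heads, max_len, head_dim)
roomy = storage.as_strided((bsz, heads, seq, head_dim), storage.stride())
print(roomy.stride()[1], seq * head_dim)   # 49664 > 16896 -> keep appending in place

# a plain contiguous cache of the same visible shape is already full
full = torch.empty(bsz, heads, seq, head_dim)
print(full.stride()[1], seq * head_dim)    # 16896 <= 16896 -> re-allocate via extend_kv_cache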
@@ -26,7 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from transformers.utils import logging, ContextManagers
 logger = logging.get_logger(__name__)
@@ -82,10 +82,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -101,13 +99,13 @@ def baichuan_attention_forward_7b(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                           self.num_heads,
                                                           self.head_dim,
                                                           kv_seq_len,
                                                           max_cache_length,
                                                           dtype=key_states.dtype,
                                                           device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -182,7 +180,7 @@ def baichuan_attention_forward_13b(
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -198,13 +196,13 @@ def baichuan_attention_forward_13b(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                           self.num_heads,
                                                           self.head_dim,
                                                           kv_seq_len,
                                                           max_cache_length,
                                                           dtype=key_states.dtype,
                                                           device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,10 +107,8 @@ def bloom_attention_forward(
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -128,7 +126,7 @@ def bloom_attention_forward(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 def rotate_half(x):
@@ -67,10 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -84,10 +82,10 @@ def attention_fn(
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
                                                self.hidden_size_per_attention_head, cur_length,
                                                max_cache_length,
                                                dtype=query_layer.dtype, device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
@@ -20,7 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -151,10 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -172,10 +170,10 @@ def chatglm2_attention_forward_8eb45c(

            max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
                + KV_CACHE_ALLOC_BLOCK_LENGTH
-           key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+           key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
                                                   self.hidden_size_per_attention_head, cur_length,
                                                   max_cache_length,
                                                   dtype=query_layer.dtype, device=device)
            key_cache[:] = key_layer
            value_cache[:] = value_layer
            key_layer = key_cache
@@ -38,7 +38,7 @@ from typing import Optional, Tuple
 import torch
 from torch.nn import functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -98,10 +98,8 @@ def rw_attention_forward_7b(
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_kv,
                 self.head_dim,
@@ -119,7 +117,7 @@ def rw_attention_forward_7b(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_kv,
             self.head_dim,
@@ -280,7 +278,7 @@ def rw_attention_forward_40b(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -298,7 +296,7 @@ def rw_attention_forward_40b(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -454,7 +452,7 @@ def falcon_attention_forward(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -472,7 +470,7 @@ def falcon_attention_forward(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -19,8 +19,8 @@

 import torch
 from typing import Optional, Tuple, Union
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-    apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    apply_rotary_pos_emb, append_kv_cache
 from transformers.utils.import_utils import is_torch_fx_proxy


@@ -144,9 +144,7 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
                                                        past_length,
@@ -160,13 +158,13 @@ def gptj_attention_forward(
         key, value = append_kv_cache(cache_k, cache_v, key, value)

     elif use_cache:
-        key_cache, value_cache = create_kv_cache(batch_size,
+        key_cache, value_cache = init_kv_cache(batch_size,
                                               self.num_attention_heads,
                                               self.head_dim,
                                               kv_seq_len,
                                               kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                               dtype=key.dtype,
                                               device=device)
         key_cache[:] = key
         value_cache[:] = value
         key = key_cache
@@ -34,7 +34,7 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -90,10 +90,8 @@ def gptneox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_past_key, new_past_value = create_kv_cache(bsz,
+            new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
                                                            self.head_size,
                                                            past_key.size(2),
@@ -108,13 +106,13 @@ def gptneox_attention_forward(
         key, value = append_kv_cache(past_key, past_value, key, value)
     elif use_cache:
         max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key, new_value = create_kv_cache(bsz,
+        new_key, new_value = init_kv_cache(bsz,
                                            self.num_attention_heads,
                                            self.head_size,
                                            seq_len,
                                            max_cache_length,
                                            dtype=key.dtype,
                                            device=device)
         new_key[:] = key
         new_value[:] = value
         key = new_key
@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb


@@ -112,10 +112,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -131,13 +129,13 @@ def llama_attention_forward_4_31(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
+        new_key_states, new_value_states = init_kv_cache(bsz,
                                                           self.num_key_value_heads,
                                                           self.head_dim,
                                                           kv_seq_len,
                                                           max_cache_length,
                                                           dtype=key_states.dtype,
                                                           device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -18,7 +18,7 @@ import torch
 from bigdl.llm.utils.common import invalidInputError


-def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
@@ -27,7 +27,7 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
                                       dtype=dtype, device=device)

     key_cache = key_cache_storage.as_strided((batch_size, num_heads,
                                               current_length, head_dim),
                                              key_cache_storage.stride(),
                                              storage_offset=0)
     value_cache = value_cache_storage.as_strided((batch_size, num_heads,
@@ -37,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
     return key_cache, value_cache


+def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    # empty cache to reduce gpu memory
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
+    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
+
+
 def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_size = (cache_k.size(0),
                 cache_k.size(1),
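Taken together, the model files above all converge on the same calling pattern around these helpers. A condensed, hypothetical sketch of that pattern follows; the function name update_kv_cache, the max_length passed to extend_kv_cache, and the lines that copy the old cache into the extended buffers are illustrative assumptions, since those lines fall outside the hunks shown and no single file is reproduced exactly:

import torch
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def update_kv_cache(past_key_value, key_states, value_states,
                    bsz, num_heads, head_dim, kv_seq_len, use_cache, device):
    if past_key_value is not None:
        cache_k, cache_v = past_key_value
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # reserved room exhausted: allocate a larger buffer (extend_kv_cache empties
            # the XPU cache first) and carry the old entries over
            new_cache_k, new_cache_v = extend_kv_cache(bsz, num_heads, head_dim,
                                                       cache_k.size(2),
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                       dtype=key_states.dtype, device=device)
            new_cache_k[:] = cache_k   # assumed copy-over step, not visible in the hunks
            new_cache_v[:] = cache_v
            cache_k, cache_v = new_cache_k, new_cache_v
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
    elif use_cache:
        # first forward pass: over-allocate by KV_CACHE_ALLOC_BLOCK_LENGTH tokens
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(bsz, num_heads, head_dim,
                                                         kv_seq_len, max_cache_length,
                                                         dtype=key_states.dtype, device=device)
        new_key_states[:] = key_states
        new_value_states[:] = value_states
        key_states, value_states = new_key_states, new_value_states
    return key_states, value_states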