simplify quantize kv cache api (#10011)
This commit is contained in:
parent
a3322e2a6c
commit
d720554d43
3 changed files with 26 additions and 49 deletions
|
|
@ -23,10 +23,8 @@ from typing import Optional, Tuple, List
|
|||
import torch.nn.functional as F
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
|
||||
append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import use_flash_attention
|
||||
from bigdl.llm.transformers.models.llama import get_ipex_version
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -189,7 +187,7 @@ def chatglm2_model_forward(
|
|||
def chatglm2_attention_forward(
|
||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||
):
|
||||
if quantize_kv_cache(self.query_key_value, hidden_states):
|
||||
if use_quantize_kv_cache(self.query_key_value, hidden_states):
|
||||
forward_function = chatglm2_quantized_attention_forward_8eb45c
|
||||
else:
|
||||
forward_function = chatglm2_attention_forward_8eb45c
|
||||
|
|
@ -263,9 +261,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
|
|||
if use_cache:
|
||||
k_cache, v_cache = init_fp8_kv_cache(batch_size,
|
||||
n_kv_head,
|
||||
seq_len,
|
||||
head_dim,
|
||||
0,
|
||||
seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
|
||||
query_layer.device)
|
||||
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
|
||||
else:
|
||||
|
|
@ -274,15 +271,6 @@ def chatglm2_quantized_attention_forward_8eb45c(
|
|||
v_cache = v_cache.permute(1, 2, 0, 3)
|
||||
# k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
|
||||
|
||||
kv_seq_len = seq_len + k_cache.size(2)
|
||||
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
|
||||
k_cache, v_cache = extend_fp8_kv_cache(
|
||||
k_cache, v_cache,
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
device=query_layer.device,
|
||||
)
|
||||
if query_layer.device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
|
||||
|
||||
if seq_len != 1:
|
||||
|
|
|
|||
|
|
@ -37,9 +37,9 @@ except ImportError:
|
|||
rearrange = None
|
||||
|
||||
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
|
||||
append_fp8_kv_cache, restore_fp8_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half
|
||||
from bigdl.llm.transformers.models.utils import mlp_fusion_check
|
||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
||||
|
|
@ -146,7 +146,7 @@ def qwen_attention_forward(
|
|||
else:
|
||||
causal_mask = None
|
||||
|
||||
if quantize_kv_cache(self.c_attn, hidden_states):
|
||||
if use_quantize_kv_cache(self.c_attn, hidden_states):
|
||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
||||
|
||||
|
|
@ -158,8 +158,7 @@ def qwen_attention_forward(
|
|||
if use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
query.size(0), self.num_heads, self.head_dim,
|
||||
0, max_cache_length,
|
||||
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
|
||||
device=query.device,
|
||||
)
|
||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||
|
|
@ -169,17 +168,6 @@ def qwen_attention_forward(
|
|||
v_cache = v_cache.transpose(1, 2)
|
||||
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
||||
|
||||
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
|
||||
# allocate new
|
||||
k_cache, v_cache = extend_fp8_kv_cache(
|
||||
k_cache, v_cache,
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
device=query.device,
|
||||
)
|
||||
# empty cache to reduce gpu memory
|
||||
if v_cache.device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||
|
||||
attn_output, attn_weight = core_attn(
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from bigdl.llm.utils.common import invalidInputError
|
|||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
||||
|
||||
|
||||
FP8_KV_ALLOC_LENGTH = 512
|
||||
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
|
||||
SYM_INT8 = ggml_tensor_qtype["sym_int8"]
|
||||
FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
|
||||
|
|
@ -65,7 +65,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
|||
return new_cache_k, new_cache_v
|
||||
|
||||
|
||||
def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
|
||||
return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
|
||||
else:
|
||||
|
|
@ -74,38 +74,39 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
|||
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
|
||||
|
||||
|
||||
def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
|
||||
def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
|
||||
max_length = current_length + FP8_KV_ALLOC_LENGTH
|
||||
|
||||
k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
|
||||
dtype=torch.uint8, device=device)
|
||||
|
||||
k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
|
||||
k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
|
||||
k_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
return k_cache, v_cache.transpose(-1, -2)
|
||||
|
||||
|
||||
def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
|
||||
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
|
||||
cur_length, max_length, device)
|
||||
new_k_cache[:] = k_cache
|
||||
new_v_cache[:] = v_cache
|
||||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def append_fp8_kv_cache(k_cache, v_cache, key, value):
|
||||
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||
new_length = cur_length + key.size(2)
|
||||
new_size = (batch_size, num_heads, new_length, head_dim)
|
||||
|
||||
new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
|
||||
if k_cache.stride(1) < new_length * k_cache.size(3):
|
||||
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
|
||||
head_dim, key.device)
|
||||
new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
|
||||
new_k_cache[:, :, :cur_length, :] = k_cache
|
||||
new_v_cache[:, :, :cur_length, :] = v_cache
|
||||
else:
|
||||
new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
|
||||
|
||||
fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
|
||||
new_k_cache[:, :, cur_length:new_length, :] = fp8_key
|
||||
|
|
|
|||
Loading…
Reference in a new issue