simplify quantize kv cache api (#10011)

Yishuo Wang 2024-01-29 09:23:57 +08:00 committed by GitHub
parent a3322e2a6c
commit d720554d43
3 changed files with 26 additions and 49 deletions
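
In short: `quantize_kv_cache` is renamed to `use_quantize_kv_cache`, `extend_fp8_kv_cache` is removed, `init_fp8_kv_cache` loses its `max_length` parameter and sizes its own buffer, and `append_fp8_kv_cache` grows the cache on its own when the pre-allocated block is full. The signature delta, paraphrased from the hunks below for orientation (bodies elided):

# bigdl.llm.transformers.models.utils -- before this commit
# def quantize_kv_cache(linear, x) -> bool: ...
# def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device): ...
# def extend_fp8_kv_cache(k_cache, v_cache, max_length, device): ...
# def append_fp8_kv_cache(k_cache, v_cache, key, value): ...

# bigdl.llm.transformers.models.utils -- after this commit
# def use_quantize_kv_cache(linear, x) -> bool: ...
# def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
#     ...  # allocates current_length + FP8_KV_ALLOC_LENGTH slots internally
# def append_fp8_kv_cache(k_cache, v_cache, key, value):
#     ...  # re-allocates a bigger cache itself once the old one runs out of room
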


@@ -23,10 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -189,7 +187,7 @@ def chatglm2_model_forward(
 def chatglm2_attention_forward(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
-    if quantize_kv_cache(self.query_key_value, hidden_states):
+    if use_quantize_kv_cache(self.query_key_value, hidden_states):
         forward_function = chatglm2_quantized_attention_forward_8eb45c
     else:
         forward_function = chatglm2_attention_forward_8eb45c
@@ -263,9 +261,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
     if use_cache:
         k_cache, v_cache = init_fp8_kv_cache(batch_size,
                                              n_kv_head,
+                                             seq_len,
                                              head_dim,
-                                             0,
-                                             seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
                                              query_layer.device)
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
     else:
@@ -274,15 +271,6 @@ def chatglm2_quantized_attention_forward_8eb45c(
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
         kv_seq_len = seq_len + k_cache.size(2)
-        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-            k_cache, v_cache = extend_fp8_kv_cache(
-                k_cache, v_cache,
-                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                device=query_layer.device,
-            )
-            if query_layer.device.type == 'xpu':
-                torch.xpu.empty_cache()
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

     if seq_len != 1:
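
A subtlety of the new `init_fp8_kv_cache` signature used in the ChatGLM2 hunks above: `current_length` (here `seq_len`) only sets how much headroom the backing buffer gets (`current_length + FP8_KV_ALLOC_LENGTH`); the returned views start logically empty, and the `append_fp8_kv_cache` call right after it is what actually writes the prefill tokens. A minimal sketch of that prefill path, assuming a bigdl-llm build with this patch and toy CPU tensors (`bs`, `n_kv_head`, etc. are made-up sizes; the real code targets xpu):

import torch
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache

bs, n_kv_head, seq_len, head_dim = 1, 2, 6, 4
key_layer = torch.randn(bs, n_kv_head, seq_len, head_dim)
value_layer = torch.randn(bs, n_kv_head, seq_len, head_dim)

k_cache, v_cache = init_fp8_kv_cache(bs, n_kv_head, seq_len, head_dim,
                                     key_layer.device)
print(k_cache.shape)                  # [1, 2, 0, 4]: empty until the first append
print(k_cache.stride(1) // head_dim)  # 518: seq_len + 512 slots were pre-allocated

k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
print(k_cache.shape)                  # [1, 2, 6, 4], stored as fp8 bytes (uint8)
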


@@ -37,9 +37,9 @@ except ImportError:
     rearrange = None
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
@@ -146,7 +146,7 @@ def qwen_attention_forward(
     else:
         causal_mask = None

-    if quantize_kv_cache(self.c_attn, hidden_states):
+    if use_quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
         # query, key, value's shape: [bs, num_heads, seq_len, head_dim]
@@ -158,8 +158,7 @@
         if use_cache:
-            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
-                query.size(0), self.num_heads, self.head_dim,
-                0, max_cache_length,
+                query.size(0), self.num_heads, kv_seq_len, self.head_dim,
                 device=query.device,
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
@@ -169,17 +168,6 @@ def qwen_attention_forward(
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
-            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-                # allocate new
-                k_cache, v_cache = extend_fp8_kv_cache(
-                    k_cache, v_cache,
-                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                    device=query.device,
-                )
-                # empty cache to reduce gpu memory
-                if v_cache.device.type == 'xpu':
-                    torch.xpu.empty_cache()
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

         attn_output, attn_weight = core_attn(
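
The deleted block above is the point of the Qwen change: the decode path no longer checks `k_cache.stride(1)` or calls `extend_fp8_kv_cache` itself, because `append_fp8_kv_cache` now re-allocates internally once the pre-allocated block is exhausted (see the `bigdl.llm.transformers.models.utils` hunks below). A rough sketch of that behaviour, assuming a bigdl-llm build with this patch, CPU tensors, and made-up sizes:

import torch
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache

bs, n_head, head_dim = 1, 2, 4
k = torch.randn(bs, n_head, 8, head_dim)
v = torch.randn(bs, n_head, 8, head_dim)

# Prefill 8 tokens; the buffer is sized 8 + FP8_KV_ALLOC_LENGTH (512) internally.
k_cache, v_cache = init_fp8_kv_cache(bs, n_head, 8, head_dim, k.device)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, k, v)
capacity = k_cache.stride(1) // head_dim
print(capacity)  # 520

# Decode: keep appending single tokens; once the capacity is exceeded, append
# grows the cache transparently -- no extend/stride bookkeeping at the call site.
for _ in range(capacity):
    tk = torch.randn(bs, n_head, 1, head_dim)
    tv = torch.randn(bs, n_head, 1, head_dim)
    k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, tk, tv)

print(k_cache.shape[2])               # 528 cached positions
print(k_cache.stride(1) // head_dim)  # larger than 520: a bigger buffer was allocated
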


@@ -20,7 +20,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type

+FP8_KV_ALLOC_LENGTH = 512
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
@@ -65,7 +65,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v


-def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
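
The rename keeps the `BIGDL_QUANTIZE_KV_CACHE` escape hatch: if the variable is set, `"1"` forces the fp8 KV cache on and any other value forces it off; if it is unset, the decision falls through to the heuristic whose last line (a device/qtype check) is visible as context in the next hunk. A small usage sketch, assuming a bigdl-llm build with this patch; the plain `torch.nn.Linear` is only a stand-in to satisfy the signature (the unset-variable path expects a bigdl quantized layer with a `qtype` attribute):

import os
import torch
from bigdl.llm.transformers.models.utils import use_quantize_kv_cache

linear = torch.nn.Linear(64, 64)   # stand-in; real callers pass e.g. self.query_key_value
x = torch.randn(1, 8, 64)

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"
print(use_quantize_kv_cache(linear, x))   # True: the env var overrides everything

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "0"
print(use_quantize_kv_cache(linear, x))   # False

del os.environ["BIGDL_QUANTIZE_KV_CACHE"]  # unset: fall back to the built-in heuristic
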
@@ -74,38 +74,39 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]


-def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
                                   dtype=torch.uint8, device=device)
-    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
-    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
                                          v_cache_storage.stride(), storage_offset=0)
     return k_cache, v_cache.transpose(-1, -2)


-def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
-    batch_size, num_heads, cur_length, head_dim = k_cache.shape
-    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
-                                                 cur_length, max_length, device)
-    new_k_cache[:] = k_cache
-    new_v_cache[:] = v_cache
-    return new_k_cache, new_v_cache
-
-
 def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
-    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
-    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                     head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+
     fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
     new_k_cache[:, :, cur_length:new_length, :] = fp8_key
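
For readers who don't want to trace the strides: the cache is a `uint8` buffer with `FP8_KV_ALLOC_LENGTH` slots of headroom, exposed through a zero-length `as_strided` view; each append re-views the same storage at the new length, and only when `stride(1)` shows the storage is full does it copy everything into a freshly allocated, larger buffer. A simplified, self-contained re-implementation of just that growth mechanism (not the library code: it skips the fp8 conversion and the transposed v-cache layout):

import torch

ALLOC_LENGTH = 512  # mirrors FP8_KV_ALLOC_LENGTH above


def init_cache(bs, heads, current_length, head_dim, device="cpu"):
    max_length = current_length + ALLOC_LENGTH
    storage = torch.empty(bs, heads, max_length, head_dim,
                          dtype=torch.uint8, device=device)
    # Zero-length view over the big buffer; appends extend it in place.
    return storage.as_strided((bs, heads, 0, head_dim), storage.stride())


def append_cache(cache, new_bytes):
    bs, heads, cur, head_dim = cache.shape
    new_len = cur + new_bytes.size(2)
    new_size = (bs, heads, new_len, head_dim)
    if cache.stride(1) < new_len * head_dim:       # storage exhausted: grow it
        bigger = init_cache(bs, heads, new_len, head_dim, cache.device)
        bigger = bigger.as_strided(new_size, bigger.stride())
        bigger[:, :, :cur, :] = cache              # carry the old entries over
        cache = bigger
    else:                                          # still fits: just re-view the storage
        cache = cache.as_strided(new_size, cache.stride())
    cache[:, :, cur:new_len, :] = new_bytes
    return cache


cache = init_cache(1, 2, 4, 8)
for _ in range(600):                               # enough appends to outgrow the headroom
    cache = append_cache(cache, torch.zeros(1, 2, 1, 8, dtype=torch.uint8))
print(cache.shape[2])                              # 600: entries survived the re-allocation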