simplify quantize kv cache api (#10011)

commit d720554d43
parent a3322e2a6c

3 changed files with 26 additions and 49 deletions

@@ -23,10 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
-from bigdl.llm.transformers.models.utils import use_flash_attention
-from bigdl.llm.transformers.models.llama import get_ipex_version
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -189,7 +187,7 @@ def chatglm2_model_forward(
 def chatglm2_attention_forward(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
-    if quantize_kv_cache(self.query_key_value, hidden_states):
+    if use_quantize_kv_cache(self.query_key_value, hidden_states):
         forward_function = chatglm2_quantized_attention_forward_8eb45c
     else:
         forward_function = chatglm2_attention_forward_8eb45c
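
The hunk above is the caller-side half of the rename: chatglm2_attention_forward asks use_quantize_kv_cache whether the FP8 KV-cache path should be taken and dispatches to the matching forward implementation. A minimal, self-contained sketch of that dispatch pattern follows (the stub predicate and stub forwards are mine, not BigDL code); the same pattern is applied to the Qwen forward further down.

# --- illustrative sketch, not part of the diff ---
def _use_quantize_kv_cache(linear, x):
    return False  # stub; the real heuristic lives in models/utils.py (see the last hunks)

def _quantized_forward(hidden_states):
    return "fp8 kv cache path"

def _plain_forward(hidden_states):
    return "regular kv cache path"

def attention_forward(linear, hidden_states):
    # same shape as chatglm2_attention_forward: choose the implementation once, then delegate
    if _use_quantize_kv_cache(linear, hidden_states):
        forward_function = _quantized_forward
    else:
        forward_function = _plain_forward
    return forward_function(hidden_states)

print(attention_forward(linear=None, hidden_states=None))  # -> regular kv cache path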

@@ -263,9 +261,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(batch_size,
                                                  n_kv_head,
+                                                 seq_len,
                                                  head_dim,
-                                                 0,
-                                                 seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
                                                  query_layer.device)
             k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
     else:
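
Under the old signature the caller had to pass both a starting length and a pre-computed maximum capacity; under the new one it only passes the current sequence length, and the headroom is chosen inside init_fp8_kv_cache. A hedged sketch of the call-site change (the stub below only mimics the new 5-argument signature; variable names follow the hunk):

# --- illustrative sketch, not part of the diff ---
import torch

def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
    # stub with the new 5-argument signature; the real implementation lives in
    # bigdl.llm.transformers.models.utils (see the utils hunks at the end of this diff)
    k = torch.empty(batch_size, num_heads, 0, head_dim, dtype=torch.uint8, device=device)
    v = torch.empty(batch_size, num_heads, 0, head_dim, dtype=torch.uint8, device=device)
    return k, v

batch_size, n_kv_head, seq_len, head_dim = 1, 2, 16, 64
device = torch.device("cpu")

# before: init_fp8_kv_cache(batch_size, n_kv_head, head_dim,
#                           0, seq_len + KV_CACHE_ALLOC_MIN_LENGTH, device)
# after:  only the current length is supplied; capacity handling moves inside the helper
k_cache, v_cache = init_fp8_kv_cache(batch_size, n_kv_head, seq_len, head_dim, device)
print(k_cache.shape)  # torch.Size([1, 2, 0, 64]) -- empty cache, ready for append_fp8_kv_cache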

@@ -274,15 +271,6 @@ def chatglm2_quantized_attention_forward_8eb45c(
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]

-        kv_seq_len = seq_len + k_cache.size(2)
-        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-            k_cache, v_cache = extend_fp8_kv_cache(
-                k_cache, v_cache,
-                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                device=query_layer.device,
-            )
-            if query_layer.device.type == 'xpu':
-                torch.xpu.empty_cache()
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

     if seq_len != 1:
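
With extend_fp8_kv_cache gone, the per-step capacity check, re-allocation and torch.xpu.empty_cache() call disappear from the model code; append_fp8_kv_cache is now expected to grow the cache itself when the pre-allocated storage runs out. The next hunks apply the same change to the Qwen attention path. Below is a simplified, self-contained stand-in (float16 instead of FP8, a single cache tensor instead of a k/v pair, names are mine) that shows the "zero-length view over pre-allocated storage, re-allocate inside append" pattern:

# --- simplified stand-in to show the pattern, not part of the diff ---
import torch

ALLOC_STEP = 512  # plays the role of FP8_KV_ALLOC_LENGTH

def init_cache(bs, heads, current_length, head_dim, device="cpu"):
    storage = torch.empty(bs, heads, current_length + ALLOC_STEP, head_dim,
                          dtype=torch.float16, device=device)
    # a zero-length view over pre-allocated storage
    return storage.as_strided((bs, heads, 0, head_dim), storage.stride(), storage_offset=0)

def append_cache(cache, new_states):
    bs, heads, cur_len, head_dim = cache.shape
    new_len = cur_len + new_states.size(2)
    if cache.stride(1) < new_len * head_dim:
        # capacity exhausted: re-allocate inside append, not in the caller
        bigger = init_cache(bs, heads, new_len, head_dim, cache.device)
        bigger = bigger.as_strided((bs, heads, cur_len, head_dim),
                                   bigger.stride(), storage_offset=0)
        bigger[:] = cache
        cache = bigger
    cache = cache.as_strided((bs, heads, new_len, head_dim), cache.stride(), storage_offset=0)
    cache[:, :, cur_len:new_len, :] = new_states.half()
    return cache

cache = init_cache(1, 2, current_length=4, head_dim=8)
for _ in range(600):                     # long enough to cross the ALLOC_STEP boundary once
    cache = append_cache(cache, torch.randn(1, 2, 1, 8))
print(cache.shape)                       # torch.Size([1, 2, 600, 8])

The common decode step stays allocation-free: as long as the reserved capacity suffices, append only re-views the same storage and copies in the new positions.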

@@ -37,9 +37,9 @@ except ImportError:
     rearrange = None

 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError

@@ -146,7 +146,7 @@ def qwen_attention_forward(
     else:
         causal_mask = None

-    if quantize_kv_cache(self.c_attn, hidden_states):
+    if use_quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
         # query, key, value's shape: [bs, num_heads, seq_len, head_dim]


@@ -158,8 +158,7 @@ def qwen_attention_forward(
         if use_cache:
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
-                query.size(0), self.num_heads, self.head_dim,
-                0, max_cache_length,
+                query.size(0), self.num_heads, kv_seq_len, self.head_dim,
                 device=query.device,
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

@@ -169,17 +168,6 @@ def qwen_attention_forward(
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]

-            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-                # allocate new
-                k_cache, v_cache = extend_fp8_kv_cache(
-                    k_cache, v_cache,
-                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                    device=query.device,
-                )
-                # empty cache to reduce gpu memory
-                if v_cache.device.type == 'xpu':
-                    torch.xpu.empty_cache()
-
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

         attn_output, attn_weight = core_attn(
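
The remaining hunks are in the shared utils module (bigdl.llm.transformers.models.utils) that both model files import from. Before reaching them, note that the two models feed the same cache helpers from different native layouts, as the permute/transpose lines and shape comments above show; a small check of those layout conversions (shapes are illustrative):

# --- layout sketch, not part of the diff ---
import torch

seq_len, bs, n_head, head_dim = 5, 1, 2, 8

# chatglm2 caches come in as [seq_len, bs, n_head, head_dim];
# permute(1, 2, 0, 3) -> [bs, n_head, seq_len, head_dim]
chatglm2_kv = torch.randn(seq_len, bs, n_head, head_dim)
print(chatglm2_kv.permute(1, 2, 0, 3).shape)   # torch.Size([1, 2, 5, 8])

# qwen tensors come in as [bs, seq_len, n_head, head_dim];
# transpose(1, 2) -> [bs, n_head, seq_len, head_dim]
qwen_kv = torch.randn(bs, seq_len, n_head, head_dim)
print(qwen_kv.transpose(1, 2).shape)           # torch.Size([1, 2, 5, 8])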

@@ -20,7 +20,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type

-
+FP8_KV_ALLOC_LENGTH = 512
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
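
FP8_KV_ALLOC_LENGTH replaces the caller-supplied max_length: each (re)allocation reserves the current length plus a fixed 512-position headroom. Illustrative arithmetic (the numbers are made up):

# --- illustrative arithmetic, not part of the diff ---
FP8_KV_ALLOC_LENGTH = 512

prompt_len, head_dim = 33, 128
capacity = prompt_len + FP8_KV_ALLOC_LENGTH   # 545 positions reserved at prefill
# append_fp8_kv_cache re-allocates once new_length * head_dim exceeds k_cache.stride(1),
# i.e. only after 512 generated tokens in this example
print(capacity, capacity * head_dim)          # 545 69760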

@@ -65,7 +65,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v


-def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
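
Only the gate's name changes here; the behaviour stays the same: an explicit BIGDL_QUANTIZE_KV_CACHE environment variable wins, anything else falls through to the heuristic in the else branch (only partly visible in this diff). A usage sketch, assuming a bigdl-llm build that includes this commit:

# --- usage sketch, not part of the diff ---
import os

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"    # force the FP8 KV cache on
# os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "0"  # any value other than "1" forces it off

from bigdl.llm.transformers.models.utils import use_quantize_kv_cache
# use_quantize_kv_cache(linear, hidden_states) is consulted per attention layer;
# with the variable set as above it returns True regardless of the fallback heuristic.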

@@ -74,38 +74,39 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]


-def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
+
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)

     v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
                                   dtype=torch.uint8, device=device)

-    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                           k_cache_storage.stride(), storage_offset=0)

-    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
                                           v_cache_storage.stride(), storage_offset=0)

     return k_cache, v_cache.transpose(-1, -2)


-def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
-    batch_size, num_heads, cur_length, head_dim = k_cache.shape
-    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
-                                                 cur_length, max_length, device)
-    new_k_cache[:] = k_cache
-    new_v_cache[:] = v_cache
-    return new_k_cache, new_v_cache
-
-
 def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)

-    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
-    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                     head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)

     fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
     new_k_cache[:, :, cur_length:new_length, :] = fp8_key
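
Putting the new utils API together, a usage sketch of the simplified flow (assumes a bigdl-llm build that includes this commit; shapes follow the [bs, num_heads, seq_len, head_dim] convention used in the comments above, and the printed shape is what I would expect rather than a verified output):

# --- usage sketch, not part of the diff ---
import torch
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache

bs, num_heads, head_dim, device = 1, 2, 64, torch.device("cpu")
key = torch.randn(bs, num_heads, 16, head_dim)
value = torch.randn(bs, num_heads, 16, head_dim)

# prefill: size the cache from the prompt length, then append the first keys/values
k_cache, v_cache = init_fp8_kv_cache(bs, num_heads, key.size(2), head_dim, device)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key, value)

# decode: one token at a time; growth (in steps of FP8_KV_ALLOC_LENGTH) now happens
# inside append_fp8_kv_cache instead of in each model's forward
for _ in range(3):
    new_key = torch.randn(bs, num_heads, 1, head_dim)
    new_value = torch.randn(bs, num_heads, 1, head_dim)
    k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, new_key, new_value)

print(k_cache.shape)  # expected: torch.Size([1, 2, 19, 64])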