From d720554d4347758cf9e5082df3543c79d7a1e351 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 29 Jan 2024 09:23:57 +0800
Subject: [PATCH] simplify quantize kv cache api (#10011)

---
 .../bigdl/llm/transformers/models/chatglm2.py | 20 +++--------
 .../src/bigdl/llm/transformers/models/qwen.py | 22 +++----------
 .../bigdl/llm/transformers/models/utils.py    | 33 ++++++++++---------
 3 files changed, 26 insertions(+), 49 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 67719ad2..1986b6ce 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -23,10 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
-from bigdl.llm.transformers.models.utils import use_flash_attention
-from bigdl.llm.transformers.models.llama import get_ipex_version
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -189,7 +187,7 @@ def chatglm2_model_forward(
 def chatglm2_attention_forward(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
-    if quantize_kv_cache(self.query_key_value, hidden_states):
+    if use_quantize_kv_cache(self.query_key_value, hidden_states):
         forward_function = chatglm2_quantized_attention_forward_8eb45c
     else:
         forward_function = chatglm2_attention_forward_8eb45c
@@ -263,9 +261,8 @@
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(batch_size,
                                                  n_kv_head,
+                                                 seq_len,
                                                  head_dim,
-                                                 0,
-                                                 seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
                                                  query_layer.device)
             k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
         else:
@@ -274,15 +271,6 @@
         k_cache = k_cache.permute(1, 2, 0, 3)
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
-        kv_seq_len = seq_len + k_cache.size(2)
-        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-            k_cache, v_cache = extend_fp8_kv_cache(
-                k_cache, v_cache,
-                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                device=query_layer.device,
-            )
-            if query_layer.device.type == 'xpu':
-                torch.xpu.empty_cache()
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

         if seq_len != 1:
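The net effect on the chatglm2 path: the caller passes only the current sequence length to init_fp8_kv_cache and never extends the cache itself; append_fp8_kv_cache (see utils.py below) manages capacity. A toy decode loop sketching that simplified control flow, with stub cache helpers standing in for the real FP8 ones (names and shapes here are illustrative only, not the library API):

import torch

def init_kv(batch, heads, cur_len, head_dim, device):
    # stand-in for init_fp8_kv_cache: only the current length is passed now;
    # over-allocation (cur_len + FP8_KV_ALLOC_LENGTH) is an internal detail of
    # the real helper, so this toy simply returns empty caches
    return (torch.empty(batch, heads, 0, head_dim, device=device),
            torch.empty(batch, heads, 0, head_dim, device=device))

def append_kv(k_cache, v_cache, k, v):
    # stand-in for append_fp8_kv_cache: the cache itself grows (and, when needed,
    # re-allocates) instead of the caller calling extend_fp8_kv_cache
    return torch.cat([k_cache, k], dim=2), torch.cat([v_cache, v], dim=2)

batch, heads, head_dim, device = 1, 2, 8, "cpu"
k = torch.randn(batch, heads, 4, head_dim)
v = torch.randn(batch, heads, 4, head_dim)

# prefill: the first call only needs the prompt length
k_cache, v_cache = init_kv(batch, heads, k.size(2), head_dim, device)
k_cache, v_cache = append_kv(k_cache, v_cache, k, v)

# decode: no stride check, no extend_fp8_kv_cache, no torch.xpu.empty_cache(); just append
for _ in range(3):
    k_t = torch.randn(batch, heads, 1, head_dim)
    v_t = torch.randn(batch, heads, 1, head_dim)
    k_cache, v_cache = append_kv(k_cache, v_cache, k_t, v_t)

print(k_cache.shape)  # torch.Size([1, 2, 7, 8])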
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 64f06626..7ca475af 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -37,9 +37,9 @@ except ImportError:
     rearrange = None

 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
-    append_fp8_kv_cache, restore_fp8_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
@@ -146,7 +146,7 @@ def qwen_attention_forward(
     else:
         causal_mask = None

-    if quantize_kv_cache(self.c_attn, hidden_states):
+    if use_quantize_kv_cache(self.c_attn, hidden_states):
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
         # query, key, value's shape: [bs, num_heads, seq_len, head_dim]

@@ -158,8 +158,7 @@
         if use_cache:
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
-                query.size(0), self.num_heads, self.head_dim,
-                0, max_cache_length,
+                query.size(0), self.num_heads, kv_seq_len, self.head_dim,
                 device=query.device,
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
@@ -169,17 +168,6 @@
             k_cache = k_cache.transpose(1, 2)
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
-            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
-                # allocate new
-                k_cache, v_cache = extend_fp8_kv_cache(
-                    k_cache, v_cache,
-                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                    device=query.device,
-                )
-                # empty cache to reduce gpu memory
-                if v_cache.device.type == 'xpu':
-                    torch.xpu.empty_cache()
-
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

         attn_output, attn_weight = core_attn(
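Both attention forwards now gate this path on use_quantize_kv_cache, the renamed predicate defined in utils.py below, which can be forced on or off through the BIGDL_QUANTIZE_KV_CACHE environment variable. A minimal sketch of that gating, with heuristic() standing in for the device/qtype fallback (only the fp16/bf16 qtype checks of that fallback are visible in the diff below, so the rest is assumed):

import os

def use_quantize_kv_cache_sketch(heuristic) -> bool:
    # BIGDL_QUANTIZE_KV_CACHE, when set, overrides everything:
    # "1" forces the quantized (FP8) KV cache on, any other value forces it off
    flag = os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None)
    if flag is not None:
        return flag == "1"
    # unset: defer to the runtime heuristic (device type and linear qtype in the real code)
    return heuristic()

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "0"
print(use_quantize_kv_cache_sketch(lambda: True))   # False: the env var wins
del os.environ["BIGDL_QUANTIZE_KV_CACHE"]
print(use_quantize_kv_cache_sketch(lambda: True))   # True: falls back to the heuristic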
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 5efdf0d9..1dc1a36a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -20,7 +20,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type

-
+FP8_KV_ALLOC_LENGTH = 512
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
@@ -65,7 +65,7 @@
     return new_cache_k, new_cache_v


-def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
@@ -74,38 +74,39 @@
             linear.qtype != ggml_tensor_qtype["fp16"] and
             linear.qtype != ggml_tensor_qtype["bf16"]


-def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
+
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
                                   dtype=torch.uint8, device=device)

-    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
-    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
                                          v_cache_storage.stride(), storage_offset=0)
     return k_cache, v_cache.transpose(-1, -2)


-def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
-    batch_size, num_heads, cur_length, head_dim = k_cache.shape
-    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
-                                                 cur_length, max_length, device)
-    new_k_cache[:] = k_cache
-    new_v_cache[:] = v_cache
-    return new_k_cache, new_v_cache
-
-
 def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)

-    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
-    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                     head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)

     fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
     new_k_cache[:, :, cur_length:new_length, :] = fp8_key
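The last hunk is truncated here, but the allocation strategy it introduces is already visible: init_fp8_kv_cache over-allocates current_length + FP8_KV_ALLOC_LENGTH slots and returns zero-length as_strided views, and append_fp8_kv_cache widens those views in place until the pre-allocated storage is exhausted, only then re-allocating and copying. A self-contained sketch of that strategy on a toy single-tensor uint8 cache (CPU, small constants; not the library code):

import torch

ALLOC_LENGTH = 4  # stand-in for FP8_KV_ALLOC_LENGTH (512 in the patch), kept small for the demo

def init_cache(batch, heads, cur_len, head_dim, device="cpu"):
    # over-allocate so the next ALLOC_LENGTH appended tokens fit without copying
    storage = torch.empty(batch, heads, cur_len + ALLOC_LENGTH, head_dim,
                          dtype=torch.uint8, device=device)
    # hand back a zero-length view over the big buffer; stride(1) encodes the real capacity
    return storage.as_strided((batch, heads, 0, head_dim), storage.stride())

def append_cache(cache, new_tokens):
    batch, heads, cur_len, head_dim = cache.shape
    new_len = cur_len + new_tokens.size(2)
    new_size = (batch, heads, new_len, head_dim)
    if cache.stride(1) < new_len * head_dim:      # pre-allocated storage exhausted
        bigger = init_cache(batch, heads, new_len, head_dim, cache.device)
        bigger = bigger.as_strided(new_size, bigger.stride())
        bigger[:, :, :cur_len, :] = cache         # copy the old contents once
        cache = bigger
    else:                                         # still fits: just widen the view
        cache = cache.as_strided(new_size, cache.stride())
    cache[:, :, cur_len:new_len, :] = new_tokens
    return cache

cache = init_cache(1, 2, cur_len=3, head_dim=8)
for step in range(10):
    token = torch.full((1, 2, 1, 8), step, dtype=torch.uint8)
    cache = append_cache(cache, token)
print(cache.shape)        # torch.Size([1, 2, 10, 8])
print(cache[0, 0, :, 0])  # tensor([0, 1, ..., 9], dtype=torch.uint8)

The payload written into the real cache comes from key.half().view(torch.uint8)[:, :, :, 1::2], i.e. the high byte of each fp16 value on a little-endian device, which amounts to an e5m2-style truncation to 8 bits.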