From afaa87114440e56d9c07305bb979510f0b44d440 Mon Sep 17 00:00:00 2001
From: Yishuo Wang <yishuo.wang@intel.com>
Date: Mon, 8 Jan 2024 09:28:20 +0800
Subject: [PATCH] [LLM] support quantize kv cache to fp8 (#9812)

---
 .../src/bigdl/llm/transformers/models/qwen.py | 217 ++++++++++++------
 .../bigdl/llm/transformers/models/utils.py    |  62 +++++
 2 files changed, 212 insertions(+), 67 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index d107ac61..c2c6ef0e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -37,7 +37,9 @@ except ImportError:
     rearrange = None

 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
+    append_fp8_kv_cache, restore_fp8_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype

@@ -83,29 +85,18 @@ def qwen_attention_forward(
     query = self._split_heads(query, self.num_heads, self.head_dim)
     key = self._split_heads(key, self.num_heads, self.head_dim)
     value = self._split_heads(value, self.num_heads, self.head_dim)
-
-    kv_seq_len = hidden_states.size()[1]
+    # query, key, value's shape: [bs, seq_len, num_heads, head_dim]

     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            if query.device.type == 'xpu':
-                cos, sin = rotary_pos_emb_list[0]
-                cos = cos[:, -cur_len:, :, :]
-                sin = sin[:, -cur_len:, :, :]
-                rot_dim = cos.shape[-1]
-                query_cur = query[..., :rot_dim]
-                key_cur = key[..., :rot_dim]
-                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
-                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
-            else:
-                rotary_pos_emb = rotary_pos_emb_list[0]
-                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
-                query = apply_rotary_pos_emb(query, q_pos_emb)
-                key = apply_rotary_pos_emb(key, k_pos_emb)
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = apply_rotary_pos_emb(query, q_pos_emb)
+            key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
@@ -119,62 +110,106 @@ def qwen_attention_forward(
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)

-    bsz, _, n_heads, head_dim = key.size()
+    query_size, key_size = query.size(1), key.size(1)
+    kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)

-    if layer_past is not None:
-        cache_k, cache_v = layer_past[0], layer_past[1]
-        cache_k = cache_k.transpose(1, 2)
-        cache_v = cache_v.transpose(1, 2)
-        kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=hidden_states.device)
-            new_cache_k[:] = cache_k
-            new_cache_v[:] = cache_v
-            cache_k = new_cache_k
-            cache_v = new_cache_v
-
-        key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                   key.transpose(1, 2), value.transpose(1, 2))
-        key = key_states
-        value = value_states
-    elif use_cache:
-        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = init_kv_cache(bsz,
-                                                         self.num_heads,
-                                                         self.head_dim,
-                                                         kv_seq_len,
-                                                         max_cache_length,
-                                                         dtype=key.dtype,
-                                                         device=hidden_states.device)
-        new_key_states[:] = key.transpose(1, 2)
-        new_value_states[:] = value.transpose(1, 2)
-        key = new_key_states
-        value = new_value_states
-
-    query_size, key_size = query.size(1), key.size(2)
-    if key_size > self.seq_length and self.use_logn_attn and not self.training:
-        seq_start = key_size - query_size
-        seq_end = key_size
+    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
+        seq_start = kv_seq_len - query_size
+        seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)

-    if query_size == key_size:
+    if key_size == kv_seq_len:
         causal_mask = torch.tril(
             torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
         ).view(1, 1, key_size, key_size)
     else:
         causal_mask = None
-    query = query.transpose(1, 2)
-    attn_output, attn_weight = self._attn(
-        query, key, value, causal_mask, attention_mask, head_mask
-    )
+
+    if quantize_kv_cache(self.c_attn, hidden_states):
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        # query, key, value's shape: [bs, num_heads, seq_len, head_dim]
+
+        if layer_past is None:
+            # For first token, use original attn
+            attn_output, attn_weight = self._attn(
+                query, key, value, causal_mask, attention_mask, head_mask
+            )
+            if use_cache:
+                max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+                k_cache, v_cache = init_fp8_kv_cache(
+                    query.size(0), self.num_heads, self.head_dim,
+                    0, max_cache_length,
+                    device=query.device,
+                )
+                key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
+        else:
+            k_cache, v_cache = layer_past[0], layer_past[1]
+            k_cache = k_cache.transpose(1, 2)
+            v_cache = v_cache.transpose(1, 2)
+            # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
+
+            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+                # allocate new
+                k_cache, v_cache = extend_fp8_kv_cache(
+                    k_cache, v_cache,
+                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    device=query.device,
+                )
+                # empty cache to reduce gpu memory
+                if v_cache.device.type == 'xpu':
+                    torch.xpu.empty_cache()
+
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
+
+            attn_output, attn_weight = core_attn(
+                self, query, key, value, causal_mask, attention_mask, head_mask
+            )
+
+    else:
+        bsz = key.size(0)
+        if layer_past is not None:
+            cache_k, cache_v = layer_past[0], layer_past[1]
+            cache_k = cache_k.transpose(1, 2)
+            cache_v = cache_v.transpose(1, 2)
+            kv_seq_len += cache_k.shape[2]
+            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+                # allocate new
+                new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           cache_k.size(2),
+                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           dtype=cache_k.dtype,
+                                                           device=hidden_states.device)
+                new_cache_k[:] = cache_k
+                new_cache_v[:] = cache_v
+                cache_k = new_cache_k
+                cache_v = new_cache_v
+
+            key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                       key.transpose(1, 2), value.transpose(1, 2))
+            key = key_states
+            value = value_states
+        elif use_cache:
+            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_key_states, new_value_states = init_kv_cache(bsz,
+                                                             self.num_heads,
+                                                             self.head_dim,
+                                                             kv_seq_len,
+                                                             max_cache_length,
+                                                             dtype=key.dtype,
+                                                             device=hidden_states.device)
+            new_key_states[:] = key.transpose(1, 2)
+            new_value_states[:] = value.transpose(1, 2)
+            key = new_key_states
+            value = new_value_states
+
+        query = query.transpose(1, 2)
+
+        attn_output, attn_weight = self._attn(
+            query, key, value, causal_mask, attention_mask, head_mask
+        )
+
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
@@ -191,6 +226,54 @@ def qwen_attention_forward(
     return outputs


+def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    if query.size(2) != 1 or query.device.type != 'xpu':
+        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
+        key, value = restore_fp8_kv_cache(key, value, query.dtype)
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+    else:
+        import linear_q4_0
+        attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
+
+    if self.scale_attn_weights:
+        if self.use_cache_quantization:
+            size_temp = value[0].size(-1)
+        else:
+            size_temp = value.size(-1)
+        attn_weights = attn_weights / (size_temp ** 0.5)
+
+    mask_value = torch.finfo(attn_weights.dtype).min
+    if causal_mask is not None:
+        attn_weights = torch.where(
+            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
+        )
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    if self.softmax_in_fp32:
+        attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    attn_weights = attn_weights.type(query.dtype)
+    attn_weights = self.attn_dropout(attn_weights)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    if query.size(2) != 1 or query.device.type != 'xpu':
+        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
+        attn_output = torch.matmul(attn_weights, value)
+    else:
+        import linear_q4_0
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
+
+    attn_output = attn_output.transpose(1, 2)
+
+    return attn_output, attn_weights
+
+
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 5736bcd1..8502ec63 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -14,8 +14,10 @@
 # limitations under the License.
 #

+import os
 import torch
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version


@@ -57,6 +59,66 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v


+def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
+    else:
+        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
+            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
+
+
+def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
+    k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                  dtype=torch.uint8, device=device)
+
+    v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
+                                  dtype=torch.uint8, device=device)
+
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
+                                         k_cache_storage.stride(), storage_offset=0)
+
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
+                                         v_cache_storage.stride(), storage_offset=0)
+
+    return k_cache, v_cache.transpose(-1, -2)
+
+
+def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
+    batch_size, num_heads, cur_length, head_dim = k_cache.shape
+    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
+                                                 cur_length, max_length, device)
+    new_k_cache[:] = k_cache
+    new_v_cache[:] = v_cache
+    return new_k_cache, new_v_cache
+
+
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
+    batch_size, num_heads, cur_length, head_dim = k_cache.shape
+    new_length = cur_length + key.size(2)
+    new_size = (batch_size, num_heads, new_length, head_dim)
+
+    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
+    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+
+    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
+    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
+    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
+    new_v_cache[:, :, cur_length:new_length, :] = fp8_value
+
+    return new_k_cache, new_v_cache
+
+
+def restore_fp8_kv_cache(k_cache, v_cache, dtype):
+    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
+    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
+    new_k_cache = new_k_cache.view(torch.half)
+    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
+    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
+    new_v_cache = new_v_cache.view(torch.half)
+
+    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., :x.shape[-1] // 2]
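Usage note (illustrative, not part of the patch): the new helpers in utils.py keep only the
high byte of each fp16 key/value element, which is effectively an E5M2-style fp8 encoding
(1 sign bit, 5 exponent bits, 2 mantissa bits) and halves KV-cache memory, while
restore_fp8_kv_cache() expands those bytes back to fp16 before a regular matmul, refilling the
dropped low mantissa byte with 0x80. Whether the fp8 path is taken is decided by
quantize_kv_cache(): the BIGDL_QUANTIZE_KV_CACHE environment variable wins if set; otherwise it
is enabled on XPU when the attention projection is quantized (qtype other than fp16/bf16).
A minimal round trip through the helpers, with made-up sizes and run on CPU for simplicity
(the real attention path additionally uses the fp8 matmuls from linear_q4_0 on XPU), might
look like:

    import torch
    from bigdl.llm.transformers.models.utils import (
        init_fp8_kv_cache, append_fp8_kv_cache, restore_fp8_kv_cache,
    )

    # Example sizes only; real values come from the model config.
    bs, num_heads, head_dim, max_len = 1, 32, 128, 256

    # Allocate an empty uint8-backed cache with room for max_len tokens.
    k_cache, v_cache = init_fp8_kv_cache(bs, num_heads, head_dim,
                                         current_length=0, max_length=max_len,
                                         device='cpu')

    # One decoding step produces key/value states of shape [bs, num_heads, 1, head_dim].
    key = torch.randn(bs, num_heads, 1, head_dim)
    value = torch.randn(bs, num_heads, 1, head_dim)

    # Quantize-and-append: only the high byte of each fp16 element is stored.
    k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key, value)

    # Dequantize back to fp16, as core_attn does before torch.matmul.
    k_fp16, v_fp16 = restore_fp8_kv_cache(k_cache, v_cache, torch.float16)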