[LLM] support quantize kv cache to fp8 (#9812)

Yishuo Wang 2024-01-08 09:28:20 +08:00 committed by GitHub
parent 248ae7fad2
commit afaa871144
2 changed files with 212 additions and 67 deletions
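For orientation before the diff: this change adds an fp8 (8-bit) KV cache path for Qwen on Intel GPUs. Below is a minimal usage sketch, not part of the commit itself; it assumes bigdl-llm's low-bit AutoModelForCausalLM loading API and an XPU (Intel GPU) PyTorch/IPEX environment. The model id, prompt, and the explicit environment-variable override are illustrative only.

import os
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 -- enables the 'xpu' device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

# Force the fp8 KV cache path added in this commit (see quantize_kv_cache below);
# leaving the variable unset falls back to the device/qtype heuristic.
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat",  # placeholder model id
                                             load_in_4bit=True,
                                             trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

with torch.inference_mode():
    inputs = tokenizer("What does an fp8 KV cache save?", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))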


@@ -37,7 +37,9 @@ except ImportError:
    rearrange = None
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
    append_fp8_kv_cache, restore_fp8_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -83,29 +85,18 @@ def qwen_attention_forward(
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
kv_seq_len = hidden_states.size()[1]
# query, key, value's shape: [bs, seq_len, num_heads, head_dim]
if rotary_pos_emb_list is not None:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
if query.device.type == 'xpu':
cos, sin = rotary_pos_emb_list[0]
cos = cos[:, -cur_len:, :, :]
sin = sin[:, -cur_len:, :, :]
rot_dim = cos.shape[-1]
query_cur = query[..., :rot_dim]
key_cur = key[..., :rot_dim]
torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
else:
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
@@ -119,62 +110,106 @@ def qwen_attention_forward(
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
bsz, _, n_heads, head_dim = key.size()
query_size, key_size = query.size(1), key.size(1)
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
if layer_past is not None:
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
kv_seq_len += cache_k.shape[2]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=hidden_states.device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v,
key.transpose(1, 2), value.transpose(1, 2))
key = key_states
value = value_states
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key.dtype,
device=hidden_states.device)
new_key_states[:] = key.transpose(1, 2)
new_value_states[:] = value.transpose(1, 2)
key = new_key_states
value = new_value_states
query_size, key_size = query.size(1), key.size(2)
if key_size > self.seq_length and self.use_logn_attn and not self.training:
seq_start = key_size - query_size
seq_end = key_size
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
seq_start = kv_seq_len - query_size
seq_end = kv_seq_len
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)
if query_size == key_size:
if key_size == kv_seq_len:
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
).view(1, 1, key_size, key_size)
else:
causal_mask = None
query = query.transpose(1, 2)
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
if quantize_kv_cache(self.c_attn, hidden_states):
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
if layer_past is None:
# For first token, use original attn
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
if use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, self.head_dim,
0, max_cache_length,
device=query.device,
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
# allocate new
k_cache, v_cache = extend_fp8_kv_cache(
k_cache, v_cache,
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
device=query.device,
)
# empty cache to reduce gpu memory
if v_cache.device.type == 'xpu':
torch.xpu.empty_cache()
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
else:
bsz = key.size(0)
if layer_past is not None:
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
kv_seq_len += cache_k.shape[2]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=hidden_states.device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v,
key.transpose(1, 2), value.transpose(1, 2))
key = key_states
value = value_states
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key.dtype,
device=hidden_states.device)
new_key_states[:] = key.transpose(1, 2)
new_value_states[:] = value.transpose(1, 2)
key = new_key_states
value = new_value_states
query = query.transpose(1, 2)
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
@@ -191,6 +226,54 @@ def qwen_attention_forward(
    return outputs

def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
    if query.size(2) != 1 or query.device.type != 'xpu':
        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
        key, value = restore_fp8_kv_cache(key, value, query.dtype)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
    else:
        import linear_q4_0
        attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)

    if self.scale_attn_weights:
        if self.use_cache_quantization:
            size_temp = value[0].size(-1)
        else:
            size_temp = value.size(-1)
        attn_weights = attn_weights / (size_temp ** 0.5)

    mask_value = torch.finfo(attn_weights.dtype).min
    if causal_mask is not None:
        attn_weights = torch.where(
            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
        )

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    if self.softmax_in_fp32:
        attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
    else:
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    attn_weights = attn_weights.type(query.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    if query.size(2) != 1 or query.device.type != 'xpu':
        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
        attn_output = torch.matmul(attn_weights, value)
    else:
        import linear_q4_0
        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))

    attn_output = attn_output.transpose(1, 2)
    return attn_output, attn_weights
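
An aside, not part of the diff: the masking step in core_attn pushes scores outside the causal triangle to the dtype's minimum before softmax. A toy, self-contained sketch of that pattern (shapes and values are illustrative; the diff passes the scalar mask_value directly, while full_like is used here for portability):

import torch

scores = torch.randn(1, 2, 4, 4)   # [bs, num_heads, q_len, k_len]
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
mask_value = torch.finfo(scores.dtype).min

# positions outside the causal triangle get the dtype minimum, so softmax gives them ~0 weight
scores = torch.where(causal_mask, scores, torch.full_like(scores, mask_value))
probs = torch.nn.functional.softmax(scores, dim=-1)
assert torch.all(probs[..., 0, 1:] < 1e-6)   # the first query token attends only to itself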

def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    x_2d = x.view(-1, x.shape[-1])
    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \


@@ -14,8 +14,10 @@
# limitations under the License.
#
import os
import torch
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.utils import get_ipex_version
@@ -57,6 +59,66 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
    return new_cache_k, new_cache_v

def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
    else:
        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
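
A small sketch, not part of the diff, of how the toggle above resolves: a set BIGDL_QUANTIZE_KV_CACHE always wins, otherwise the device/qtype heuristic decides. The dummy linear layer and its qtype value are illustrative, and the snippet assumes the variable is not already set in the environment:

import os
import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.models.utils import quantize_kv_cache

linear = torch.nn.Linear(8, 8)
linear.qtype = ggml_tensor_qtype["sym_int4"]   # pretend the layer holds low-bit weights
x_cpu = torch.randn(1, 8)                      # CPU activations

print(quantize_kv_cache(linear, x_cpu))        # False: the heuristic requires an XPU tensor

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"
print(quantize_kv_cache(linear, x_cpu))        # True: the env var overrides the heuristic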

def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
    k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                  dtype=torch.uint8, device=device)
    v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
                                  dtype=torch.uint8, device=device)

    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
                                         k_cache_storage.stride(), storage_offset=0)
    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
                                         v_cache_storage.stride(), storage_offset=0)
    return k_cache, v_cache.transpose(-1, -2)

def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
    batch_size, num_heads, cur_length, head_dim = k_cache.shape
    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
                                                 cur_length, max_length, device)
    new_k_cache[:] = k_cache
    new_v_cache[:] = v_cache
    return new_k_cache, new_v_cache
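
Both fp8 cache helpers above rely on the same pattern: preallocate extra capacity up front, expose an as_strided view of the filled prefix, and only grow the view as tokens arrive, so appends write in place until capacity runs out. A toy fp16 sketch of that pattern, not part of the diff (key-cache layout only, dimensions made up, quantization omitted):

import torch

bs, heads, head_dim, capacity = 1, 2, 4, 8
storage = torch.empty(bs, heads, capacity, head_dim, dtype=torch.half)

cur_len = 0
cache = storage.as_strided((bs, heads, cur_len, head_dim), storage.stride(), storage_offset=0)

new_tokens = torch.randn(bs, heads, 3, head_dim).half()
new_len = cur_len + new_tokens.size(2)

# growing the view is free; only the freshly appended rows are written
cache = cache.as_strided((bs, heads, new_len, head_dim), cache.stride(), storage_offset=0)
cache[:, :, cur_len:new_len, :] = new_tokens
assert cache.data_ptr() == storage.data_ptr()   # still the same preallocated buffer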

def append_fp8_kv_cache(k_cache, v_cache, key, value):
    batch_size, num_heads, cur_length, head_dim = k_cache.shape
    new_length = cur_length + key.size(2)
    new_size = (batch_size, num_heads, new_length, head_dim)

    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)

    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
    new_v_cache[:, :, cur_length:new_length, :] = fp8_value

    return new_k_cache, new_v_cache

def restore_fp8_kv_cache(k_cache, v_cache, dtype):
    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
    new_k_cache = new_k_cache.view(torch.half)
    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
    new_v_cache = new_v_cache.view(torch.half)

    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
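
In other words, the "fp8" here is simply the high byte of each fp16 value: sign, 5 exponent bits, and 2 mantissa bits (an e5m2-style layout). restore_fp8_kv_cache puts that byte back and fixes the low byte to 0x80 (the constant 128), the midpoint of the discarded bits. A round-trip sketch, not part of the diff, on a little-endian host with made-up values:

import torch

x = torch.tensor([0.5, -1.25, 3.0, 100.0], dtype=torch.half)

fp8 = x.view(torch.uint8)[1::2].clone()                    # keep only the high bytes

restored = torch.full(fp8.shape, 128, dtype=torch.int16)   # low byte = 0x80
restored.view(torch.uint8)[1::2] = fp8                     # put the high bytes back
restored = restored.view(torch.half)

print(x)         # 0.5, -1.25, 3.0, 100.0
print(restored)  # coarse: ~2 mantissa bits survive, e.g. 0.5 comes back as 0.5625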

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]