[LLM] support quantize kv cache to fp8 (#9812)
This commit is contained in:
parent
248ae7fad2
commit
afaa871144
2 changed files with 212 additions and 67 deletions
|
|
@ -37,7 +37,9 @@ except ImportError:
|
|||
rearrange = None
|
||||
|
||||
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
|
||||
append_fp8_kv_cache, restore_fp8_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
|
||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
|
||||
|
|
@ -83,29 +85,18 @@ def qwen_attention_forward(
|
|||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
kv_seq_len = hidden_states.size()[1]
|
||||
# query, key, value's shape: [bs, seq_len, num_heads, head_dim]
|
||||
|
||||
if rotary_pos_emb_list is not None:
|
||||
cur_len = query.shape[1]
|
||||
if len(rotary_pos_emb_list) == 1:
|
||||
if query.device.type == 'xpu':
|
||||
cos, sin = rotary_pos_emb_list[0]
|
||||
cos = cos[:, -cur_len:, :, :]
|
||||
sin = sin[:, -cur_len:, :, :]
|
||||
rot_dim = cos.shape[-1]
|
||||
query_cur = query[..., :rot_dim]
|
||||
key_cur = key[..., :rot_dim]
|
||||
torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
|
||||
torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
|
||||
else:
|
||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||
# Slice the pos emb for current inference
|
||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||
# Slice the pos emb for current inference
|
||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||
else:
|
||||
query_list = []
|
||||
key_list = []
|
||||
|
|
@ -119,62 +110,106 @@ def qwen_attention_forward(
|
|||
query = torch.cat(query_list, dim=0)
|
||||
key = torch.cat(key_list, dim=0)
|
||||
|
||||
bsz, _, n_heads, head_dim = key.size()
|
||||
query_size, key_size = query.size(1), key.size(1)
|
||||
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
||||
|
||||
if layer_past is not None:
|
||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||
cache_k = cache_k.transpose(1, 2)
|
||||
cache_v = cache_v.transpose(1, 2)
|
||||
kv_seq_len += cache_k.shape[2]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=hidden_states.device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
cache_k = new_cache_k
|
||||
cache_v = new_cache_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||
key.transpose(1, 2), value.transpose(1, 2))
|
||||
key = key_states
|
||||
value = value_states
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
max_cache_length,
|
||||
dtype=key.dtype,
|
||||
device=hidden_states.device)
|
||||
new_key_states[:] = key.transpose(1, 2)
|
||||
new_value_states[:] = value.transpose(1, 2)
|
||||
key = new_key_states
|
||||
value = new_value_states
|
||||
|
||||
query_size, key_size = query.size(1), key.size(2)
|
||||
if key_size > self.seq_length and self.use_logn_attn and not self.training:
|
||||
seq_start = key_size - query_size
|
||||
seq_end = key_size
|
||||
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
||||
seq_start = kv_seq_len - query_size
|
||||
seq_end = kv_seq_len
|
||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||
query = query * logn_tensor.expand_as(query)
|
||||
if query_size == key_size:
|
||||
if key_size == kv_seq_len:
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
||||
).view(1, 1, key_size, key_size)
|
||||
else:
|
||||
causal_mask = None
|
||||
query = query.transpose(1, 2)
|
||||
|
||||
attn_output, attn_weight = self._attn(
|
||||
query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
if quantize_kv_cache(self.c_attn, hidden_states):
|
||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
||||
|
||||
if layer_past is None:
|
||||
# For first token, use original attn
|
||||
attn_output, attn_weight = self._attn(
|
||||
query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
if use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
query.size(0), self.num_heads, self.head_dim,
|
||||
0, max_cache_length,
|
||||
device=query.device,
|
||||
)
|
||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||
else:
|
||||
k_cache, v_cache = layer_past[0], layer_past[1]
|
||||
k_cache = k_cache.transpose(1, 2)
|
||||
v_cache = v_cache.transpose(1, 2)
|
||||
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
||||
|
||||
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
|
||||
# allocate new
|
||||
k_cache, v_cache = extend_fp8_kv_cache(
|
||||
k_cache, v_cache,
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
device=query.device,
|
||||
)
|
||||
# empty cache to reduce gpu memory
|
||||
if v_cache.device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||
|
||||
attn_output, attn_weight = core_attn(
|
||||
self, query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
|
||||
else:
|
||||
bsz = key.size(0)
|
||||
if layer_past is not None:
|
||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||
cache_k = cache_k.transpose(1, 2)
|
||||
cache_v = cache_v.transpose(1, 2)
|
||||
kv_seq_len += cache_k.shape[2]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=hidden_states.device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
cache_k = new_cache_k
|
||||
cache_v = new_cache_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||
key.transpose(1, 2), value.transpose(1, 2))
|
||||
key = key_states
|
||||
value = value_states
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
max_cache_length,
|
||||
dtype=key.dtype,
|
||||
device=hidden_states.device)
|
||||
new_key_states[:] = key.transpose(1, 2)
|
||||
new_value_states[:] = value.transpose(1, 2)
|
||||
key = new_key_states
|
||||
value = new_value_states
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
|
||||
attn_output, attn_weight = self._attn(
|
||||
query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
|
||||
context_layer = self._merge_heads(
|
||||
attn_output, self.num_heads, self.head_dim
|
||||
)
|
||||
|
|
@ -191,6 +226,54 @@ def qwen_attention_forward(
|
|||
return outputs
|
||||
|
||||
|
||||
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
||||
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||
key, value = restore_fp8_kv_cache(key, value, query.dtype)
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
|
||||
|
||||
if self.scale_attn_weights:
|
||||
if self.use_cache_quantization:
|
||||
size_temp = value[0].size(-1)
|
||||
else:
|
||||
size_temp = value.size(-1)
|
||||
attn_weights = attn_weights / (size_temp ** 0.5)
|
||||
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
if causal_mask is not None:
|
||||
attn_weights = torch.where(
|
||||
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
if self.softmax_in_fp32:
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
|
||||
else:
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = attn_weights.type(query.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
|
||||
|
|
|
|||
|
|
@ -14,8 +14,10 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import torch
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.transformers.utils import get_ipex_version
|
||||
|
||||
|
||||
|
|
@ -57,6 +59,66 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
|||
return new_cache_k, new_cache_v
|
||||
|
||||
|
||||
def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
|
||||
return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
|
||||
else:
|
||||
return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
|
||||
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
|
||||
|
||||
|
||||
def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
|
||||
k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
|
||||
dtype=torch.uint8, device=device)
|
||||
|
||||
k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
|
||||
k_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
return k_cache, v_cache.transpose(-1, -2)
|
||||
|
||||
|
||||
def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
|
||||
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
|
||||
cur_length, max_length, device)
|
||||
new_k_cache[:] = k_cache
|
||||
new_v_cache[:] = v_cache
|
||||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def append_fp8_kv_cache(k_cache, v_cache, key, value):
|
||||
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||
new_length = cur_length + key.size(2)
|
||||
new_size = (batch_size, num_heads, new_length, head_dim)
|
||||
|
||||
new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
|
||||
|
||||
fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
|
||||
new_k_cache[:, :, cur_length:new_length, :] = fp8_key
|
||||
fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
|
||||
new_v_cache[:, :, cur_length:new_length, :] = fp8_value
|
||||
|
||||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def restore_fp8_kv_cache(k_cache, v_cache, dtype):
|
||||
new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
|
||||
new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
|
||||
new_k_cache = new_k_cache.view(torch.half)
|
||||
new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
|
||||
new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
|
||||
new_v_cache = new_v_cache.view(torch.half)
|
||||
|
||||
return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
|
|
|
|||
Loading…
Reference in a new issue