[LLM] support quantize kv cache to fp8 (#9812)
This commit is contained in:
parent
248ae7fad2
commit
afaa871144
2 changed files with 212 additions and 67 deletions
|
|
@ -37,7 +37,9 @@ except ImportError:
|
||||||
rearrange = None
|
rearrange = None
|
||||||
|
|
||||||
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import rotate_half
|
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
|
||||||
|
append_fp8_kv_cache, restore_fp8_kv_cache
|
||||||
|
from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
|
||||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
|
||||||
|
|
@ -83,22 +85,11 @@ def qwen_attention_forward(
|
||||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
# query, key, value's shape: [bs, seq_len, num_heads, head_dim]
|
||||||
kv_seq_len = hidden_states.size()[1]
|
|
||||||
|
|
||||||
if rotary_pos_emb_list is not None:
|
if rotary_pos_emb_list is not None:
|
||||||
cur_len = query.shape[1]
|
cur_len = query.shape[1]
|
||||||
if len(rotary_pos_emb_list) == 1:
|
if len(rotary_pos_emb_list) == 1:
|
||||||
if query.device.type == 'xpu':
|
|
||||||
cos, sin = rotary_pos_emb_list[0]
|
|
||||||
cos = cos[:, -cur_len:, :, :]
|
|
||||||
sin = sin[:, -cur_len:, :, :]
|
|
||||||
rot_dim = cos.shape[-1]
|
|
||||||
query_cur = query[..., :rot_dim]
|
|
||||||
key_cur = key[..., :rot_dim]
|
|
||||||
torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
|
|
||||||
torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
|
|
@ -119,8 +110,63 @@ def qwen_attention_forward(
|
||||||
query = torch.cat(query_list, dim=0)
|
query = torch.cat(query_list, dim=0)
|
||||||
key = torch.cat(key_list, dim=0)
|
key = torch.cat(key_list, dim=0)
|
||||||
|
|
||||||
bsz, _, n_heads, head_dim = key.size()
|
query_size, key_size = query.size(1), key.size(1)
|
||||||
|
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
||||||
|
|
||||||
|
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
||||||
|
seq_start = kv_seq_len - query_size
|
||||||
|
seq_end = kv_seq_len
|
||||||
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||||
|
query = query * logn_tensor.expand_as(query)
|
||||||
|
if key_size == kv_seq_len:
|
||||||
|
causal_mask = torch.tril(
|
||||||
|
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
||||||
|
).view(1, 1, key_size, key_size)
|
||||||
|
else:
|
||||||
|
causal_mask = None
|
||||||
|
|
||||||
|
if quantize_kv_cache(self.c_attn, hidden_states):
|
||||||
|
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||||
|
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
if layer_past is None:
|
||||||
|
# For first token, use original attn
|
||||||
|
attn_output, attn_weight = self._attn(
|
||||||
|
query, key, value, causal_mask, attention_mask, head_mask
|
||||||
|
)
|
||||||
|
if use_cache:
|
||||||
|
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
|
query.size(0), self.num_heads, self.head_dim,
|
||||||
|
0, max_cache_length,
|
||||||
|
device=query.device,
|
||||||
|
)
|
||||||
|
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||||
|
else:
|
||||||
|
k_cache, v_cache = layer_past[0], layer_past[1]
|
||||||
|
k_cache = k_cache.transpose(1, 2)
|
||||||
|
v_cache = v_cache.transpose(1, 2)
|
||||||
|
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
||||||
|
|
||||||
|
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
|
||||||
|
# allocate new
|
||||||
|
k_cache, v_cache = extend_fp8_kv_cache(
|
||||||
|
k_cache, v_cache,
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
device=query.device,
|
||||||
|
)
|
||||||
|
# empty cache to reduce gpu memory
|
||||||
|
if v_cache.device.type == 'xpu':
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
|
||||||
|
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||||
|
|
||||||
|
attn_output, attn_weight = core_attn(
|
||||||
|
self, query, key, value, causal_mask, attention_mask, head_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
bsz = key.size(0)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||||
cache_k = cache_k.transpose(1, 2)
|
cache_k = cache_k.transpose(1, 2)
|
||||||
|
|
@ -158,23 +204,12 @@ def qwen_attention_forward(
|
||||||
key = new_key_states
|
key = new_key_states
|
||||||
value = new_value_states
|
value = new_value_states
|
||||||
|
|
||||||
query_size, key_size = query.size(1), key.size(2)
|
|
||||||
if key_size > self.seq_length and self.use_logn_attn and not self.training:
|
|
||||||
seq_start = key_size - query_size
|
|
||||||
seq_end = key_size
|
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
|
||||||
query = query * logn_tensor.expand_as(query)
|
|
||||||
if query_size == key_size:
|
|
||||||
causal_mask = torch.tril(
|
|
||||||
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
|
||||||
).view(1, 1, key_size, key_size)
|
|
||||||
else:
|
|
||||||
causal_mask = None
|
|
||||||
query = query.transpose(1, 2)
|
query = query.transpose(1, 2)
|
||||||
|
|
||||||
attn_output, attn_weight = self._attn(
|
attn_output, attn_weight = self._attn(
|
||||||
query, key, value, causal_mask, attention_mask, head_mask
|
query, key, value, causal_mask, attention_mask, head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
context_layer = self._merge_heads(
|
context_layer = self._merge_heads(
|
||||||
attn_output, self.num_heads, self.head_dim
|
attn_output, self.num_heads, self.head_dim
|
||||||
)
|
)
|
||||||
|
|
@ -191,6 +226,54 @@ def qwen_attention_forward(
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
||||||
|
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||||
|
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||||
|
key, value = restore_fp8_kv_cache(key, value, query.dtype)
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
|
||||||
|
|
||||||
|
if self.scale_attn_weights:
|
||||||
|
if self.use_cache_quantization:
|
||||||
|
size_temp = value[0].size(-1)
|
||||||
|
else:
|
||||||
|
size_temp = value.size(-1)
|
||||||
|
attn_weights = attn_weights / (size_temp ** 0.5)
|
||||||
|
|
||||||
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
|
if causal_mask is not None:
|
||||||
|
attn_weights = torch.where(
|
||||||
|
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
if self.softmax_in_fp32:
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
|
||||||
|
else:
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
attn_weights = attn_weights.type(query.dtype)
|
||||||
|
attn_weights = self.attn_dropout(attn_weights)
|
||||||
|
|
||||||
|
if head_mask is not None:
|
||||||
|
attn_weights = attn_weights * head_mask
|
||||||
|
|
||||||
|
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||||
|
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
|
if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from bigdl.llm.transformers.utils import get_ipex_version
|
from bigdl.llm.transformers.utils import get_ipex_version
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -57,6 +59,66 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
||||||
return new_cache_k, new_cache_v
|
return new_cache_k, new_cache_v
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||||
|
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
|
||||||
|
return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
|
||||||
|
else:
|
||||||
|
return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
|
||||||
|
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
|
||||||
|
|
||||||
|
|
||||||
|
def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
|
||||||
|
k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
|
||||||
|
dtype=torch.uint8, device=device)
|
||||||
|
|
||||||
|
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
|
||||||
|
dtype=torch.uint8, device=device)
|
||||||
|
|
||||||
|
k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
|
||||||
|
k_cache_storage.stride(), storage_offset=0)
|
||||||
|
|
||||||
|
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
|
||||||
|
v_cache_storage.stride(), storage_offset=0)
|
||||||
|
|
||||||
|
return k_cache, v_cache.transpose(-1, -2)
|
||||||
|
|
||||||
|
|
||||||
|
def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
|
||||||
|
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||||
|
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
|
||||||
|
cur_length, max_length, device)
|
||||||
|
new_k_cache[:] = k_cache
|
||||||
|
new_v_cache[:] = v_cache
|
||||||
|
return new_k_cache, new_v_cache
|
||||||
|
|
||||||
|
|
||||||
|
def append_fp8_kv_cache(k_cache, v_cache, key, value):
|
||||||
|
batch_size, num_heads, cur_length, head_dim = k_cache.shape
|
||||||
|
new_length = cur_length + key.size(2)
|
||||||
|
new_size = (batch_size, num_heads, new_length, head_dim)
|
||||||
|
|
||||||
|
new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
|
||||||
|
new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
|
||||||
|
|
||||||
|
fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
|
||||||
|
new_k_cache[:, :, cur_length:new_length, :] = fp8_key
|
||||||
|
fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
|
||||||
|
new_v_cache[:, :, cur_length:new_length, :] = fp8_value
|
||||||
|
|
||||||
|
return new_k_cache, new_v_cache
|
||||||
|
|
||||||
|
|
||||||
|
def restore_fp8_kv_cache(k_cache, v_cache, dtype):
|
||||||
|
new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
|
||||||
|
new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
|
||||||
|
new_k_cache = new_k_cache.view(torch.half)
|
||||||
|
new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
|
||||||
|
new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
|
||||||
|
new_v_cache = new_v_cache.view(torch.half)
|
||||||
|
|
||||||
|
return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
"""Rotates half the hidden dims of the input."""
|
"""Rotates half the hidden dims of the input."""
|
||||||
x1 = x[..., :x.shape[-1] // 2]
|
x1 = x[..., :x.shape[-1] // 2]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue