[LLM] Add quantize kv_cache for Baichuan2-13B (#10203)
* add quantize kv_cache for baichuan2-13b * style fix
This commit is contained in:
parent
34ee1aa91f
commit
ca1166a0e5
1 changed files with 128 additions and 0 deletions
|
|
@ -24,6 +24,8 @@ import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||||
|
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
||||||
append_kv_cache, is_enough_kv_cache_room_4_31
|
append_kv_cache, is_enough_kv_cache_room_4_31
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
|
||||||
|
|
@ -197,6 +199,132 @@ def baichuan_attention_forward_13b(
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if use_quantize_kv_cache(self.W_pack, hidden_states):
|
||||||
|
forward_function = baichuan_attention_forward_13b_quantized
|
||||||
|
else:
|
||||||
|
forward_function = baichuan_attention_forward_13b_origin
|
||||||
|
return forward_function(
|
||||||
|
self=self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_attention_forward_13b_quantized(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
proj = self.W_pack(hidden_states)
|
||||||
|
proj = (
|
||||||
|
proj.unflatten(-1, (3, self.hidden_size))
|
||||||
|
.unsqueeze(0)
|
||||||
|
.transpose(0, -2)
|
||||||
|
.squeeze(-2)
|
||||||
|
)
|
||||||
|
query_states = (
|
||||||
|
proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
key_states = (
|
||||||
|
proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
value_states = (
|
||||||
|
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
|
||||||
|
if past_key_value is None:
|
||||||
|
# should use origin attn here
|
||||||
|
attn_weights = torch.matmul(query_states,
|
||||||
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if q_len == 1: # inference with cache
|
||||||
|
if len(attention_mask.size()) == 4:
|
||||||
|
attention_mask = attention_mask[:, :, -1:, :]
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask[:, -1:, :]
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(
|
||||||
|
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
|
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||||
|
key_states, value_states)
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
else:
|
||||||
|
k_cache, v_cache = past_key_value
|
||||||
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||||
|
key_states, value_states)
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
query_states.dtype)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||||
|
|
||||||
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if q_len == 1: # inference with cache
|
||||||
|
if len(attention_mask.size()) == 4:
|
||||||
|
attention_mask = attention_mask[:, :, -1:, :]
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask[:, -1:, :]
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights,
|
||||||
|
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||||
|
value_states.transpose(-1, -2))
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_attention_forward_13b_origin(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue