LLM: Support quantize kv cache in mistral. (#10261)
* init * update quantize kv.
This commit is contained in:
parent
db0d129226
commit
a4de3095f3
1 changed files with 205 additions and 0 deletions
|
|
@ -43,6 +43,8 @@ from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
|
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||||
|
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
|
||||||
apply_rotary_pos_emb_no_cache_xpu
|
apply_rotary_pos_emb_no_cache_xpu
|
||||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||||
|
|
@ -128,6 +130,209 @@ def mistral_attention_forward(
|
||||||
output_attentions: bool=False,
|
output_attentions: bool=False,
|
||||||
use_cache: bool=False,
|
use_cache: bool=False,
|
||||||
padding_mask: Optional[torch.Tensor]=None,
|
padding_mask: Optional[torch.Tensor]=None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||||
|
forward_function = mistral_attention_forward_quantized
|
||||||
|
else:
|
||||||
|
forward_function = mistral_attention_forward_original
|
||||||
|
return forward_function(
|
||||||
|
self=self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_attention_forward_quantized(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor]=None,
|
||||||
|
position_ids: Optional[torch.LongTensor]=None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]]=None,
|
||||||
|
output_attentions: bool=False,
|
||||||
|
use_cache: bool=False,
|
||||||
|
padding_mask: Optional[torch.Tensor]=None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, hidden_size = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
# for flash attention
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
||||||
|
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
|
||||||
|
use_fuse_rope,
|
||||||
|
enough_kv_room,
|
||||||
|
bsz * q_len)
|
||||||
|
|
||||||
|
if decoding_fast_path:
|
||||||
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
tmp_cache_k, tmp_cache_v = init_kv_cache(
|
||||||
|
bsz,
|
||||||
|
self.num_key_value_heads,
|
||||||
|
self.head_dim,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
import linear_q4_0
|
||||||
|
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||||
|
self.q_proj.weight,
|
||||||
|
self.k_proj.weight,
|
||||||
|
self.v_proj.weight,
|
||||||
|
position_ids,
|
||||||
|
tmp_cache_k, tmp_cache_v,
|
||||||
|
self.q_proj.weight.qtype,
|
||||||
|
0,
|
||||||
|
self.head_dim)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
if use_fuse_rope:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||||
|
key_states,
|
||||||
|
position_ids,
|
||||||
|
"mistral")
|
||||||
|
else:
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids, "mistral")
|
||||||
|
|
||||||
|
if not self.training and not hidden_states.requires_grad:
|
||||||
|
fsdp_flag = use_flash_attention(query_states, key_states)
|
||||||
|
else:
|
||||||
|
fsdp_flag = False
|
||||||
|
if fsdp_flag:
|
||||||
|
attention_dtype = torch.float16 # use fp16 for flash attention
|
||||||
|
else:
|
||||||
|
attention_dtype = original_dtype
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is None:
|
||||||
|
attn_weights = torch.matmul(query_states.to(key_states.dtype),
|
||||||
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention weights should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||||
|
f" but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
if use_cache:
|
||||||
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
|
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||||
|
device=query_states.device
|
||||||
|
)
|
||||||
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||||
|
key_states, value_states)
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
else:
|
||||||
|
k_cache, v_cache = past_key_value
|
||||||
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||||
|
key_states, value_states)
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
|
||||||
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
query_states.dtype)
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||||
|
|
||||||
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention weights should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
invalidInputError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||||
|
f" but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
|
||||||
|
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||||
|
value_states.transpose(-1, -2))
|
||||||
|
|
||||||
|
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||||
|
if attn_output.size() != attn_output_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"`attn_output` should be of size {attn_output_size},"
|
||||||
|
f" but is {attn_output.size()}")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_attention_forward_original(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor]=None,
|
||||||
|
position_ids: Optional[torch.LongTensor]=None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]]=None,
|
||||||
|
output_attentions: bool=False,
|
||||||
|
use_cache: bool=False,
|
||||||
|
padding_mask: Optional[torch.Tensor]=None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, hidden_size = hidden_states.size()
|
bsz, q_len, hidden_size = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue