[LLM] Support quantize kv cache for Baichuan2 7B (#10280)
* Add quatized kv cache framework for Baichuan2 7B * Support quantize kv cache for baichuan2 * Small fix * Fix python style
This commit is contained in:
parent
273de341d7
commit
f0ff0eebe1
1 changed files with 104 additions and 0 deletions
|
|
@ -93,6 +93,110 @@ def baichuan_attention_forward_7b(
|
|||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if use_quantize_kv_cache(self.W_pack, hidden_states):
|
||||
forward_function = baichuan_attention_forward_7b_quantized
|
||||
else:
|
||||
forward_function = baichuan_attention_forward_7b_origin
|
||||
return forward_function(
|
||||
self=self,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache
|
||||
)
|
||||
|
||||
|
||||
def baichuan_attention_forward_7b_quantized(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
||||
proj = self.W_pack(hidden_states)
|
||||
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
||||
# batch_size x source_len x hidden_size
|
||||
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# batch_size x target_len x head_size
|
||||
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
# batch_size x source_len x hidden_size
|
||||
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
|
||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||
key_states,
|
||||
position_ids,
|
||||
"baichuan")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "baichuan")
|
||||
if past_key_value is None:
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||
device=device
|
||||
)
|
||||
else:
|
||||
k_cache, v_cache = past_key_value
|
||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.dtype == torch.bool:
|
||||
attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
||||
|
||||
scaling_factor = 1 / math.sqrt(query_states.size(-1))
|
||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||
query_states.dtype)
|
||||
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.query_key_fp8_matmul(query_states * scaling_factor, key_states)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_output += attention_mask
|
||||
attn_output = torch.softmax(attn_output, -1)
|
||||
attn_output = attn_output.to(hidden_states.dtype)
|
||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||
attn_output = torch.matmul(attn_output, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_output,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def baichuan_attention_forward_7b_origin(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
|
|
|||
Loading…
Reference in a new issue