Merge pull request #11891 from hxsz1997/baichuan2-compresskv
Add compress_kv for Baichuan2
This commit is contained in:
commit
650e6e6ce4
2 changed files with 201 additions and 8 deletions
|
|
@ -1296,8 +1296,17 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
if model.config.hidden_size in [4096, 2048]:
|
||||
# baichuan-7B and baichuan2-7B
|
||||
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
|
||||
from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward
|
||||
for i in range(len(model.model.layers)):
|
||||
setattr(model.model.layers[i].self_attn, "layer_idx", i)
|
||||
convert_forward(model, module.Attention, baichuan_attention_forward_7b)
|
||||
convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
|
||||
if model.config.vocab_size == 125696:
|
||||
# baichuan2-7B
|
||||
convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
|
||||
elif model.config.vocab_size == 64000:
|
||||
# baichuan-7B
|
||||
convert_forward(model, module.Model, baichuan_model_7b_forward)
|
||||
elif model.config.hidden_size == 5120:
|
||||
# baichuan-13B and baichuan2-13B
|
||||
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
|
||||
|
|
|
|||
|
|
@ -19,17 +19,25 @@
|
|||
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import functional as F
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
|
||||
should_use_compresskv, get_compresskv_attn_mask
|
||||
from ipex_llm.transformers.models.utils import update_past_key_value
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||
from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
|
||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
||||
import warnings
|
||||
import os
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
|
||||
|
||||
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||
|
|
@ -71,6 +79,161 @@ def baichuan_mlp_forward(
|
|||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def baichuan_model_7b_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# IPEX-LLM OPT: compress kv and quantize kv
|
||||
if use_cache:
|
||||
inputs = input_ids if input_ids is not None else inputs_embeds
|
||||
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
|
||||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
|
||||
if use_compress_kv and not isinstance(past_key_values,
|
||||
DynamicCompressCache):
|
||||
if use_quantize_kv:
|
||||
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
|
||||
else:
|
||||
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
|
||||
the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
log4Error.invalidInputError("You have to specify either decoder_input_ids \
|
||||
or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
# IPEX-LLM OPT: compress kv
|
||||
if isinstance(past_key_values, DynamicCompressCache):
|
||||
past_key_values_length = past_key_values.get_seq_length()
|
||||
else:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
|
||||
dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# IPEX-LLM OPT: compress kv
|
||||
use_compresskv = isinstance(past_key_values, DynamicCompressCache)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# IPEX-LLM OPT: compress kv
|
||||
if not use_compresskv:
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
# IPEX-LLM OPT: compress kv
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values if use_compresskv else past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
# IPEX-LLM OPT: compress kv
|
||||
if use_compresskv:
|
||||
next_decoder_cache = past_key_values
|
||||
else:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def baichuan_attention_forward_7b(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -83,6 +246,9 @@ def baichuan_attention_forward_7b(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
||||
# [CompressKV]
|
||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||
|
||||
qkv = self.W_pack(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
|
|
@ -92,7 +258,12 @@ def baichuan_attention_forward_7b(
|
|||
|
||||
kv_seq_len = key_states.shape[2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[2]
|
||||
# [CompressKV]
|
||||
if use_compresskv:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
|
||||
self.layer_idx)
|
||||
else:
|
||||
kv_seq_len += past_key_value[0].shape[2]
|
||||
|
||||
# IPEX-LLM OPT: fuse rope
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
|
|
@ -108,11 +279,22 @@ def baichuan_attention_forward_7b(
|
|||
|
||||
# IPEX-LLM OPT: kv cache and quantize kv
|
||||
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
|
||||
key_states, value_states = update_past_key_value(
|
||||
past_key_value, key_states, value_states,
|
||||
kv_seq_len, use_quantize_kv, device
|
||||
)
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# [CompressKV]
|
||||
if use_compresskv:
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||
self.layer_idx,
|
||||
q_len)
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx,
|
||||
query_states, attention_mask, 1,
|
||||
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
|
||||
else:
|
||||
key_states, value_states = update_past_key_value(
|
||||
past_key_value, key_states, value_states,
|
||||
kv_seq_len, use_quantize_kv, device
|
||||
)
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if self.training:
|
||||
warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
|
||||
|
|
@ -127,6 +309,8 @@ def baichuan_attention_forward_7b(
|
|||
is_causal=True).to(hidden_states.dtype)
|
||||
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||
import xe_addons
|
||||
if use_compresskv:
|
||||
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
|
||||
if use_quantize_kv:
|
||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
|
|
|
|||
Loading…
Reference in a new issue