Merge pull request #11891 from hxsz1997/baichuan2-compresskv
Add compress_kv for Baichuan2
This commit is contained in:
commit
650e6e6ce4
2 changed files with 201 additions and 8 deletions
|
|
@ -1296,8 +1296,17 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
if model.config.hidden_size in [4096, 2048]:
|
if model.config.hidden_size in [4096, 2048]:
|
||||||
# baichuan-7B and baichuan2-7B
|
# baichuan-7B and baichuan2-7B
|
||||||
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
|
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
|
||||||
|
from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward
|
||||||
|
for i in range(len(model.model.layers)):
|
||||||
|
setattr(model.model.layers[i].self_attn, "layer_idx", i)
|
||||||
convert_forward(model, module.Attention, baichuan_attention_forward_7b)
|
convert_forward(model, module.Attention, baichuan_attention_forward_7b)
|
||||||
convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
|
convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
|
||||||
|
if model.config.vocab_size == 125696:
|
||||||
|
# baichuan2-7B
|
||||||
|
convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
|
||||||
|
elif model.config.vocab_size == 64000:
|
||||||
|
# baichuan-7B
|
||||||
|
convert_forward(model, module.Model, baichuan_model_7b_forward)
|
||||||
elif model.config.hidden_size == 5120:
|
elif model.config.hidden_size == 5120:
|
||||||
# baichuan-13B and baichuan2-13B
|
# baichuan-13B and baichuan2-13B
|
||||||
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
|
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
|
||||||
|
|
|
||||||
|
|
@ -19,17 +19,25 @@
|
||||||
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
|
||||||
|
should_use_compresskv, get_compresskv_attn_mask
|
||||||
from ipex_llm.transformers.models.utils import update_past_key_value
|
from ipex_llm.transformers.models.utils import update_past_key_value
|
||||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||||
|
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||||
|
from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
|
||||||
|
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
||||||
import warnings
|
import warnings
|
||||||
|
import os
|
||||||
|
|
||||||
|
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||||
|
|
||||||
|
|
||||||
def pre_compute_inv_freq(module: torch.nn.Module):
|
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||||
|
|
@ -71,6 +79,161 @@ def baichuan_mlp_forward(
|
||||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_model_7b_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None \
|
||||||
|
else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else
|
||||||
|
self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: compress kv and quantize kv
|
||||||
|
if use_cache:
|
||||||
|
inputs = input_ids if input_ids is not None else inputs_embeds
|
||||||
|
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
|
||||||
|
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
|
||||||
|
if use_compress_kv and not isinstance(past_key_values,
|
||||||
|
DynamicCompressCache):
|
||||||
|
if use_quantize_kv:
|
||||||
|
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
|
||||||
|
the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
log4Error.invalidInputError("You have to specify either decoder_input_ids \
|
||||||
|
or decoder_inputs_embeds")
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
# IPEX-LLM OPT: compress kv
|
||||||
|
if isinstance(past_key_values, DynamicCompressCache):
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
|
else:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long, device=device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: compress kv
|
||||||
|
use_compresskv = isinstance(past_key_values, DynamicCompressCache)
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: compress kv
|
||||||
|
if not use_compresskv:
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, None)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# IPEX-LLM OPT: compress kv
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values if use_compresskv else past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
# IPEX-LLM OPT: compress kv
|
||||||
|
if use_compresskv:
|
||||||
|
next_decoder_cache = past_key_values
|
||||||
|
else:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def baichuan_attention_forward_7b(
|
def baichuan_attention_forward_7b(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
@ -83,6 +246,9 @@ def baichuan_attention_forward_7b(
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
||||||
|
# [CompressKV]
|
||||||
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
qkv = self.W_pack(hidden_states)
|
qkv = self.W_pack(hidden_states)
|
||||||
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||||
qkv = qkv.transpose(1, 2)
|
qkv = qkv.transpose(1, 2)
|
||||||
|
|
@ -92,6 +258,11 @@ def baichuan_attention_forward_7b(
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[2]
|
kv_seq_len = key_states.shape[2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
|
# [CompressKV]
|
||||||
|
if use_compresskv:
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
|
||||||
|
self.layer_idx)
|
||||||
|
else:
|
||||||
kv_seq_len += past_key_value[0].shape[2]
|
kv_seq_len += past_key_value[0].shape[2]
|
||||||
|
|
||||||
# IPEX-LLM OPT: fuse rope
|
# IPEX-LLM OPT: fuse rope
|
||||||
|
|
@ -108,6 +279,17 @@ def baichuan_attention_forward_7b(
|
||||||
|
|
||||||
# IPEX-LLM OPT: kv cache and quantize kv
|
# IPEX-LLM OPT: kv cache and quantize kv
|
||||||
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
|
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
|
||||||
|
|
||||||
|
# [CompressKV]
|
||||||
|
if use_compresskv:
|
||||||
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
|
self.layer_idx,
|
||||||
|
q_len)
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx,
|
||||||
|
query_states, attention_mask, 1,
|
||||||
|
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
|
||||||
|
else:
|
||||||
key_states, value_states = update_past_key_value(
|
key_states, value_states = update_past_key_value(
|
||||||
past_key_value, key_states, value_states,
|
past_key_value, key_states, value_states,
|
||||||
kv_seq_len, use_quantize_kv, device
|
kv_seq_len, use_quantize_kv, device
|
||||||
|
|
@ -127,6 +309,8 @@ def baichuan_attention_forward_7b(
|
||||||
is_causal=True).to(hidden_states.dtype)
|
is_causal=True).to(hidden_states.dtype)
|
||||||
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
|
if use_compresskv:
|
||||||
|
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
attention_mask)
|
attention_mask)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue