Merge pull request #11891 from hxsz1997/baichuan2-compresskv

Add compress_kv for Baichuan2
Authored by hxsz1997 on 2024-08-23 06:09:58 +03:00, committed by GitHub
commit 650e6e6ce4
2 changed files with 201 additions and 8 deletions

python/llm/src/ipex_llm/transformers/convert.py

@@ -1296,8 +1296,17 @@ def _optimize_post(model, lightweight_bmm=False):
        if model.config.hidden_size in [4096, 2048]:
            # baichuan-7B and baichuan2-7B
            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
            from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward
            for i in range(len(model.model.layers)):
                setattr(model.model.layers[i].self_attn, "layer_idx", i)
            convert_forward(model, module.Attention, baichuan_attention_forward_7b)
            convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
            if model.config.vocab_size == 125696:
                # baichuan2-7B
                convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
            elif model.config.vocab_size == 64000:
                # baichuan-7B
                convert_forward(model, module.Model, baichuan_model_7b_forward)
        elif model.config.hidden_size == 5120:
            # baichuan-13B and baichuan2-13B
            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
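
For context, convert_forward (defined elsewhere in ipex-llm) rebinds the forward method on every submodule of the given class, which is how the Baichuan forwards registered above take effect. A minimal sketch of the idea, assuming nothing beyond standard PyTorch; this is not the actual ipex-llm implementation:

# --- illustrative sketch, not part of this commit ---
import types

import torch


def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Walk every submodule and rebind `forward` on instances of target_cls,
    # e.g. replace module.Attention.forward with baichuan_attention_forward_7b.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)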

python/llm/src/ipex_llm/transformers/models/baichuan.py

@@ -19,17 +19,25 @@
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
import math
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
    should_use_compresskv, get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import update_past_key_value
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
from ipex_llm.transformers.models.utils import mlp_fusion_check
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
import warnings
import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def pre_compute_inv_freq(module: torch.nn.Module):
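
KV_CACHE_ALLOC_BLOCK_LENGTH above is read from the environment once, at import time, so overriding the default of 256 only works if the variable is set before this module is imported, for example:

# --- illustrative sketch, not part of this commit ---
import os

# Must be set before ipex_llm.transformers.models.baichuan is imported,
# because the module evaluates the variable at import time (default: 256).
os.environ["KV_CACHE_ALLOC_BLOCK_LENGTH"] = "512"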
@@ -71,6 +79,161 @@ def baichuan_mlp_forward(
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def baichuan_model_7b_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None \
        else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else
        self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # IPEX-LLM OPT: compress kv and quantize kv
    if use_cache:
        inputs = input_ids if input_ids is not None else inputs_embeds
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
        if use_compress_kv and not isinstance(past_key_values,
                                              DynamicCompressCache):
            if use_quantize_kv:
                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
                         the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        log4Error.invalidInputError("You have to specify either decoder_input_ids \
                                    or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        # IPEX-LLM OPT: compress kv
        if isinstance(past_key_values, DynamicCompressCache):
            past_key_values_length = past_key_values.get_seq_length()
        else:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
                                    dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    # IPEX-LLM OPT: compress kv
    use_compresskv = isinstance(past_key_values, DynamicCompressCache)

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # IPEX-LLM OPT: compress kv
        if not use_compresskv:
            past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            # IPEX-LLM OPT: compress kv
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values if use_compresskv else past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # IPEX-LLM OPT: compress kv
            if use_compresskv:
                next_decoder_cache = past_key_values
            else:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                     if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
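
A rough usage sketch of the optimized Baichuan2-7B path. Loading the model through ipex_llm.transformers is standard API; the IPEX_LLM_COMPRESS_KV_CACHE environment variable is an assumption here, since the gate is whatever should_use_compresskv checks:

# --- illustrative sketch, not part of this commit ---
import os

# Assumption: the compress-kv gate checked by should_use_compresskv is toggled
# through this environment variable; set it before the model is loaded.
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",
    load_in_4bit=True,        # low-bit weights; the quantized kv path is decided separately
    trust_remote_code=True,
).to("xpu")
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                          trust_remote_code=True)

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

With the patched baichuan_model_7b_forward, the first forward pass then wraps past_key_values into DynamicCompressCache, or DynamicCompressFp8Cache when the quantized kv cache is also enabled.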

def baichuan_attention_forward_7b(
    self,
    hidden_states: torch.Tensor,
@@ -83,6 +246,9 @@ def baichuan_attention_forward_7b(
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

    # [CompressKV]
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

    qkv = self.W_pack(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
@@ -92,7 +258,12 @@
    kv_seq_len = key_states.shape[2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[2]
        # [CompressKV]
        if use_compresskv:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                           self.layer_idx)
        else:
            kv_seq_len += past_key_value[0].shape[2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
@@ -108,11 +279,22 @@
    # IPEX-LLM OPT: kv cache and quantize kv
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, device
    )
    past_key_value = (key_states, value_states) if use_cache else None
    # [CompressKV]
    if use_compresskv:
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
                                                      self.layer_idx,
                                                      q_len)
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx,
            query_states, attention_mask, 1,
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
    else:
        key_states, value_states = update_past_key_value(
            past_key_value, key_states, value_states,
            kv_seq_len, use_quantize_kv, device
        )
        past_key_value = (key_states, value_states) if use_cache else None

    if self.training:
        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
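
The compressed branch above relies on a few DynamicCompressCache methods; the sketch below summarizes the interface as it is exercised in this file, with parameter names inferred from the call sites rather than copied from the class:

# --- illustrative sketch, not part of this commit ---
from typing import Optional, Protocol, Tuple

import torch


class CompressCacheLike(Protocol):
    # Total number of tokens already cached (used for past_key_values_length).
    def get_seq_length(self) -> int: ...

    # Cached length usable by the given layer (used when computing kv_seq_len).
    def get_usable_length(self, new_seq_length: int, layer_idx: int) -> int: ...

    # Appends the new key/value states, compresses the cache as needed, and
    # returns the (possibly compressed) key/value tensors this layer attends over.
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: int, query_states: torch.Tensor,
               attention_mask: Optional[torch.Tensor], num_key_value_groups: int,
               config, enough_kv_room: bool,
               kv_cache_alloc_block_length: int) -> Tuple[torch.Tensor, torch.Tensor]: ...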
@@ -127,6 +309,8 @@
                                                     is_causal=True).to(hidden_states.dtype)
    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if use_compresskv:
            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
        if use_quantize_kv:
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                            attention_mask)
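
Because the compressed cache keeps fewer entries than the nominal kv_seq_len, the attention mask is narrowed to the actual key length before the SDP kernels run. A sketch of what get_compresskv_attn_mask is assumed to do (the real helper lives in ipex_llm.transformers.models.utils):

# --- illustrative sketch, not part of this commit ---
from typing import Optional

import torch


def get_compresskv_attn_mask_sketch(key_states: torch.Tensor,
                                    attention_mask: Optional[torch.Tensor]
                                    ) -> Optional[torch.Tensor]:
    # Keep only the last key_states.size(2) mask columns so the mask width
    # matches the compressed kv length passed to the SDP kernel.
    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
    return attention_mask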