add compress_kv for baichuan2

Huang, Xinshengzi  2024-08-22 10:59:08 +08:00
parent 2946420e14
commit 86248b0505
2 changed files with 187 additions and 8 deletions

python/llm/src/ipex_llm/transformers/convert.py

@@ -1294,8 +1294,12 @@ def _optimize_post(model, lightweight_bmm=False):
     if model.config.hidden_size in [4096, 2048]:
         # baichuan-7B and baichuan2-7B
         from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
+        from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward
+        for i in range(len(model.model.layers)):
+            setattr(model.model.layers[i].self_attn, "layer_idx", i)
         convert_forward(model, module.Attention, baichuan_attention_forward_7b)
         convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
+        convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
     elif model.config.hidden_size == 5120:
         # baichuan-13B and baichuan2-13B
         from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
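For orientation: convert_forward replaces the forward method of every matching module instance in the model, and the setattr loop gives each self_attn its layer index so the attention forward can address its own slot in the shared compressed cache. A minimal sketch of the rebinding pattern (not ipex-llm's actual implementation):

    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
        # rebind forward on every instance of target_cls in the module tree
        for _, sub_module in model.named_modules():
            if isinstance(sub_module, target_cls):
                sub_module.forward = types.MethodType(new_forward, sub_module)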

python/llm/src/ipex_llm/transformers/models/baichuan.py

@@ -19,18 +19,25 @@
 # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 import math
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
+    should_use_compresskv, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
+from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
+from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 import warnings
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


 def pre_compute_inv_freq(module: torch.nn.Module):
     if module.__class__.__name__ == "RotaryEmbedding":
@@ -70,6 +77,153 @@ def baichuan_mlp_forward(
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


+def baichuan_model_7b_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if use_cache:
+        inputs = input_ids if input_ids is not None else inputs_embeds
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        if isinstance(past_key_values, DynamicCompressCache):
+            past_key_values_length = past_key_values.get_seq_length()
+        else:
+            past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        attention_mask = torch.ones(
+            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        )
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+    )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    use_compresskv = isinstance(past_key_values, DynamicCompressCache)
+    # if not past_key_values and not use_compresskv:
+    #     past_key_values = [None for _ in range(self.num_layers)]
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not use_compresskv:
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                None,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values if use_compresskv else past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            if use_compresskv:
+                next_decoder_cache = past_key_values
+            else:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )


 def baichuan_attention_forward_7b(
     self,
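Note on the model-level change above: when compress kv applies, baichuan_model_7b_forward wraps any legacy tuple-of-tuples past_key_values into a DynamicCompressCache (or DynamicCompressFp8Cache when the kv cache is also quantized), and reads the cached length via get_seq_length() instead of past_key_values[0][0].shape[2]. A rough analogy using the stock transformers Cache API (DynamicCache here only stands in for ipex-llm's DynamicCompressCache, which adds compression on top):

    import torch
    from transformers.cache_utils import DynamicCache

    # legacy layout: one (key, value) tuple per layer,
    # each of shape (batch, num_heads, seq_len, head_dim)
    legacy = tuple(
        (torch.randn(1, 32, 16, 128), torch.randn(1, 32, 16, 128))
        for _ in range(2)
    )
    cache = DynamicCache.from_legacy_cache(legacy)
    assert cache.get_seq_length() == 16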
@@ -83,6 +237,8 @@ def baichuan_attention_forward_7b(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device

+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.W_pack(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -92,6 +248,10 @@ def baichuan_attention_forward_7b(
     kv_seq_len = key_states.shape[2]
     if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[2]
+        if use_compresskv:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
+                                                           self.layer_idx)
+        else:
+            kv_seq_len += past_key_value[0].shape[2]

     # IPEX-LLM OPT: fuse rope
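get_usable_length follows the transformers Cache API: the attention layer asks how many cached positions already exist for its layer index before appending the new tokens. With the stock DynamicCache for illustration:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    cache.update(torch.randn(1, 32, 16, 128),
                 torch.randn(1, 32, 16, 128), layer_idx=0)
    # 16 tokens cached for layer 0, so a 1-token decode step reports 16
    assert cache.get_usable_length(1, layer_idx=0) == 16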
@@ -108,12 +268,23 @@ def baichuan_attention_forward_7b(
     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
-    key_states, value_states = update_past_key_value(
-        past_key_value, key_states, value_states,
-        kv_seq_len, use_quantize_kv, device
-    )
-    past_key_value = (key_states, value_states) if use_cache else None
+    if use_quantize_kv or (not use_compresskv):
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, device
+        )
+        past_key_value = (key_states, value_states) if use_cache else None
+    else:
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_idx,
+                                                      q_len)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx,
+            query_states, attention_mask, 1,
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)

     if self.training:
         warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
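The else branch delegates cache growth to the cache object itself: DynamicCompressCache.update extends the transformers Cache.update signature with the query states, attention mask and the compression parameters visible in the call above. For comparison, the stock API it builds on:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    k = torch.randn(1, 32, 1, 128)
    v = torch.randn(1, 32, 1, 128)
    # returns the full cached key/value tensors after appending this step
    key_states, value_states = cache.update(k, v, layer_idx=0)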
@@ -130,6 +301,10 @@ def baichuan_attention_forward_7b(
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
+        elif use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
+            attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                        attention_mask)
         else:
             attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                         attention_mask)
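Taken together, compress kv for Baichuan2-7B is opt-in at runtime. A hypothetical end-to-end usage sketch (the IPEX_LLM_COMPRESS_KV_CACHE environment variable is an assumption about what should_use_compresskv checks; the rest is the standard ipex-llm loading path):

    import os
    os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"  # assumed gate for should_use_compresskv

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    model_path = "baichuan-inc/Baichuan2-7B-Chat"
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True).to("xpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
        output = model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))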