Merge pull request #11891 from hxsz1997/baichuan2-compresskv
Add compress_kv for Baichuan2
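For context: _optimize_post applies the compressed KV path automatically when a Baichuan2-7B checkpoint is loaded through ipex_llm, and should_use_compresskv gates it at runtime. A minimal usage sketch, not part of this diff; the environment flag consulted by should_use_compresskv is an assumption here:

# Usage sketch (illustrative, not part of this PR). The exact flag read by
# should_use_compresskv is assumed to be IPEX_LLM_COMPRESS_KV_CACHE.
import os
os.environ.setdefault("IPEX_LLM_COMPRESS_KV_CACHE", "1")   # assumed flag name

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",
    load_in_4bit=True,            # low-bit load; _optimize_post patches the forwards
    trust_remote_code=True,
).half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                          trust_remote_code=True)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    out = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))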
commit 650e6e6ce4
2 changed files with 201 additions and 8 deletions
@@ -1296,8 +1296,17 @@ def _optimize_post(model, lightweight_bmm=False):
        if model.config.hidden_size in [4096, 2048]:
            # baichuan-7B and baichuan2-7B
            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
            from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward
            for i in range(len(model.model.layers)):
                setattr(model.model.layers[i].self_attn, "layer_idx", i)
            convert_forward(model, module.Attention, baichuan_attention_forward_7b)
            convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
            if model.config.vocab_size == 125696:
                # baichuan2-7B
                convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
            elif model.config.vocab_size == 64000:
                # baichuan-7B
                convert_forward(model, module.Model, baichuan_model_7b_forward)
        elif model.config.hidden_size == 5120:
            # baichuan-13B and baichuan2-13B
            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
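convert_forward, used throughout this hunk, rebinds the forward of every instance of the given module class, and the new loop stamps each self_attn with its layer_idx so the shared compress cache can be indexed per layer. A rough sketch of what such a helper does, illustrative rather than the exact ipex_llm implementation:

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Walk the module tree and rebind forward on every instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)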
@@ -19,17 +19,25 @@
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py

import math
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
    should_use_compresskv, get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import update_past_key_value
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
from ipex_llm.transformers.models.utils import mlp_fusion_check
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
import warnings
import os

KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def pre_compute_inv_freq(module: torch.nn.Module):
@@ -71,6 +79,161 @@ def baichuan_mlp_forward(
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def baichuan_model_7b_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None \
        else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else
        self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # IPEX-LLM OPT: compress kv and quantize kv
    if use_cache:
        inputs = input_ids if input_ids is not None else inputs_embeds
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
        if use_compress_kv and not isinstance(past_key_values,
                                              DynamicCompressCache):
            if use_quantize_kv:
                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
                          the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        log4Error.invalidInputError("You have to specify either decoder_input_ids \
                                     or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        # IPEX-LLM OPT: compress kv
        if isinstance(past_key_values, DynamicCompressCache):
            past_key_values_length = past_key_values.get_seq_length()
        else:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
                                    dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    # IPEX-LLM OPT: compress kv
    use_compresskv = isinstance(past_key_values, DynamicCompressCache)

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # IPEX-LLM OPT: compress kv
        if not use_compresskv:
            past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            # IPEX-LLM OPT: compress kv
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values if use_compresskv else past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # IPEX-LLM OPT: compress kv
            if use_compresskv:
                next_decoder_cache = past_key_values
            else:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                     if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
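The decoder loop above is the core behavioural change: with compress kv enabled, every layer receives the one shared DynamicCompressCache (its attention indexes it via the layer_idx set in _optimize_post) and next_decoder_cache simply aliases that object, while the legacy path keeps passing per-layer (key, value) tuples. A condensed restatement of that branch, as a sketch only (output_attentions handling omitted):

# Condensed sketch of the dispatch in the loop above (not a drop-in replacement).
for idx, decoder_layer in enumerate(self.layers):
    if use_compresskv:
        layer_past = past_key_values              # one shared DynamicCompressCache
    elif past_key_values is not None:
        layer_past = past_key_values[idx]         # legacy per-layer (key, value) tuple
    else:
        layer_past = None
    layer_outputs = decoder_layer(hidden_states, attention_mask=attention_mask,
                                  position_ids=position_ids, past_key_value=layer_past,
                                  use_cache=use_cache)
    hidden_states = layer_outputs[0]
    if use_cache:
        # the compress cache updates itself inside attention; legacy tuples are re-collected
        next_decoder_cache = past_key_values if use_compresskv else \
            next_decoder_cache + (layer_outputs[1],)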


def baichuan_attention_forward_7b(
    self,
    hidden_states: torch.Tensor,
@@ -83,6 +246,9 @@ def baichuan_attention_forward_7b(
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

    # [CompressKV]
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

    qkv = self.W_pack(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
@@ -92,7 +258,12 @@ def baichuan_attention_forward_7b(

    kv_seq_len = key_states.shape[2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[2]
        # [CompressKV]
        if use_compresskv:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                           self.layer_idx)
        else:
            kv_seq_len += past_key_value[0].shape[2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
@@ -108,11 +279,22 @@ def baichuan_attention_forward_7b(

    # IPEX-LLM OPT: kv cache and quantize kv
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # [CompressKV]
    if use_compresskv:
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
                                                      self.layer_idx,
                                                      q_len)
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx,
            query_states, attention_mask, 1,
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
    else:
        key_states, value_states = update_past_key_value(
            past_key_value, key_states, value_states,
            kv_seq_len, use_quantize_kv, device
        )
        past_key_value = (key_states, value_states) if use_cache else None

    if self.training:
        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
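For reviewers unfamiliar with compressed KV caching: the past_key_value.update(...) call above passes query_states and attention_mask into the cache so it can keep only the most-attended prompt positions plus a recent window before appending the new tokens. A purely conceptual sketch of that idea follows; it is not the DynamicCompressCache implementation, and the names and selection heuristic are illustrative:

import torch

def compress_kv_sketch(key_states, value_states, query_states,
                       keep_ratio=0.25, recent=32):
    # key/value/query: [bsz, n_heads, seq_len, head_dim]. Score each cached position
    # by how much the most recent queries attend to it, then keep the top fraction.
    bsz, n_heads, seq_len, head_dim = key_states.shape
    scores = torch.matmul(query_states[:, :, -recent:],
                          key_states.transpose(2, 3)) / head_dim ** 0.5
    scores = scores.softmax(dim=-1).sum(dim=2)               # [bsz, n_heads, seq_len]
    n_keep = min(seq_len, max(recent, int(seq_len * keep_ratio)))
    idx = scores.topk(n_keep, dim=-1).indices.sort(dim=-1).values
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    return key_states.gather(2, idx), value_states.gather(2, idx)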
@@ -127,6 +309,8 @@ def baichuan_attention_forward_7b(
                                                     is_causal=True).to(hidden_states.dtype)
    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if use_compresskv:
            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
        if use_quantize_kv:
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                            attention_mask)
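xe_addons.sdp and xe_addons.sdp_fp8 are fused XPU kernels; functionally, the non-quantized branch corresponds to PyTorch's built-in scaled dot product attention. A portable sketch, assuming attention_mask is the usual additive mask:

from torch.nn import functional as F

# Portable equivalent of the non-quantized sdp branch (sketch only).
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attention_mask,    # additive mask broadcastable to [bsz, heads, q_len, kv_len]
)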