From 86248b0505ad5c1041dfdcad9058a416655c2fe9 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 10:59:08 +0800 Subject: [PATCH 01/14] add compress_kv for baichuan2 --- .../llm/src/ipex_llm/transformers/convert.py | 4 + .../ipex_llm/transformers/models/baichuan.py | 191 +++++++++++++++++- 2 files changed, 187 insertions(+), 8 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index a7da4efc..9fefd634 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1294,8 +1294,12 @@ def _optimize_post(model, lightweight_bmm=False): if model.config.hidden_size in [4096, 2048]: # baichuan-7B and baichuan2-7B from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b + from ipex_llm.transformers.models.baichuan import baichuan_model_7b_forward + for i in range(len(model.model.layers)): + setattr(model.model.layers[i].self_attn, "layer_idx", i) convert_forward(model, module.Attention, baichuan_attention_forward_7b) convert_forward(model, module.RMSNorm, llama_rms_norm_forward) + convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward) elif model.config.hidden_size == 5120: # baichuan-13B and baichuan2-13B from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index c74e9754..83dc215e 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -19,18 +19,25 @@ # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py import math -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch.nn import functional as F -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \ + should_use_compresskv, get_compresskv_attn_mask from ipex_llm.transformers.models.utils import update_past_key_value from ipex_llm.transformers.models.utils import should_use_fuse_rope from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU from ipex_llm.transformers.models.utils import mlp_fusion_check +from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36 +from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache +from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache import warnings +import os +KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) def pre_compute_inv_freq(module: torch.nn.Module): if module.__class__.__name__ == "RotaryEmbedding": @@ -70,6 +77,153 @@ def baichuan_mlp_forward( )) return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +def baichuan_model_7b_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: 
Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if use_cache: + inputs = input_ids if input_ids is not None else inputs_embeds + use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs) + if use_compress_kv and not isinstance(past_key_values, + DynamicCompressCache): + if use_quantize_kv: + past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values) + else: + past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + if isinstance(past_key_values, DynamicCompressCache): + past_key_values_length = past_key_values.get_seq_length() + else: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + use_compresskv = isinstance(past_key_values, DynamicCompressCache) + + # if not past_key_values and not use_compresskv: + # past_key_values = [None for _ in range(self.num_layers)] + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not use_compresskv: + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return 
module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values if use_compresskv else past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + if use_compresskv: + next_decoder_cache = past_key_values + else: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + def baichuan_attention_forward_7b( self, @@ -83,6 +237,8 @@ def baichuan_attention_forward_7b( bsz, q_len, _ = hidden_states.size() device = hidden_states.device + use_compresskv = isinstance(past_key_value, DynamicCompressCache) + qkv = self.W_pack(hidden_states) qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) qkv = qkv.transpose(1, 2) @@ -92,7 +248,11 @@ def baichuan_attention_forward_7b( kv_seq_len = key_states.shape[2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[2] + if use_compresskv: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, + self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[2] # IPEX-LLM OPT: fuse rope if should_use_fuse_rope(hidden_states, position_ids, self.training): @@ -108,11 +268,22 @@ def baichuan_attention_forward_7b( # IPEX-LLM OPT: kv cache and quantize kv use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - key_states, value_states = update_past_key_value( - past_key_value, key_states, value_states, - kv_seq_len, use_quantize_kv, device - ) - past_key_value = (key_states, value_states) if use_cache else None + if use_quantize_kv or (not use_compresskv): + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) + past_key_value = (key_states, value_states) if use_cache else None + + else: + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, + self.layer_idx, + q_len) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, + query_states, attention_mask, 1, + self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH) + if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") @@ -130,6 +301,10 @@ def baichuan_attention_forward_7b( if use_quantize_kv: attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask) + elif use_compresskv: + attention_mask = get_compresskv_attn_mask(key_states, attention_mask) + attn_output = xe_addons.sdp(query_states, key_states, value_states, + attention_mask) else: attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) From 6bb90357880c33b3b17c738887d6e2d9b0924758 Mon Sep 17 00:00:00 
2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:08:48 +0800 Subject: [PATCH 02/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 83dc215e..111dc1e5 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -284,7 +284,6 @@ def baichuan_attention_forward_7b( query_states, attention_mask, 1, self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH) - if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") From 6a5ca17afc9d3bbcc92522e9d599ffcbbe084378 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:09:58 +0800 Subject: [PATCH 03/14] fix typoes --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 111dc1e5..d91eb1e7 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -283,7 +283,7 @@ def baichuan_attention_forward_7b( key_states, value_states, self.layer_idx, query_states, attention_mask, 1, self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH) - + if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") From 4adadddbbcf228d887c27e5acbad82fe173acd7a Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:12:23 +0800 Subject: [PATCH 04/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index d91eb1e7..9d4688df 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -277,8 +277,8 @@ def baichuan_attention_forward_7b( else: enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, - self.layer_idx, - q_len) + self.layer_idx, + q_len) key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, query_states, attention_mask, 1, From 2a0aa9271babe3c411805d5f4012ed13cc395af6 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:23:22 +0800 Subject: [PATCH 05/14] fix typos --- .../ipex_llm/transformers/models/baichuan.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 9d4688df..21d86b06 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -89,9 +89,11 @@ def baichuan_model_7b_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 
+ output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -110,7 +112,8 @@ def baichuan_model_7b_forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \ + the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: @@ -130,9 +133,8 @@ def baichuan_model_7b_forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) + position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, + dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -216,7 +218,8 @@ def baichuan_model_7b_forward( next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, From c6ed1c412dfc096b747753465f1326ae4df7bc82 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:26:49 +0800 Subject: [PATCH 06/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 21d86b06..a4be00ba 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -119,7 +119,8 @@ def baichuan_model_7b_forward( elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + log4Error.invalidInputError("You have to specify either decoder_input_ids \ + or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 From 01ed397e7a5648a1f62dd04b3d58e45707194262 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:31:25 +0800 Subject: [PATCH 07/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index a4be00ba..1f39c0bd 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -97,7 +97,7 @@ def baichuan_model_7b_forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if use_cache: inputs = input_ids if input_ids is not None else inputs_embeds @@ -164,9 +164,6 @@ def 
baichuan_model_7b_forward( use_compresskv = isinstance(past_key_values, DynamicCompressCache) - # if not past_key_values and not use_compresskv: - # past_key_values = [None for _ in range(self.num_layers)] - for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) From 8a5df93de2b82319799aca3443015b04e0fc9a55 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:33:07 +0800 Subject: [PATCH 08/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 1f39c0bd..d454a97c 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -77,6 +77,7 @@ def baichuan_mlp_forward( )) return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + def baichuan_model_7b_forward( self, input_ids: torch.LongTensor = None, From 48a827aa07894e9734a72f394cced9ae0e6ef79d Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 11:35:47 +0800 Subject: [PATCH 09/14] fix typos --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index d454a97c..84416832 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -39,6 +39,7 @@ import os KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) + def pre_compute_inv_freq(module: torch.nn.Module): if module.__class__.__name__ == "RotaryEmbedding": inv_freq = module.inv_freq From 42398a0045132815be77852dc63f5f2a5b5381cb Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 13:17:13 +0800 Subject: [PATCH 10/14] add comment --- .../src/ipex_llm/transformers/models/baichuan.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 84416832..f8d7d186 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -271,14 +271,9 @@ def baichuan_attention_forward_7b( # IPEX-LLM OPT: kv cache and quantize kv use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - if use_quantize_kv or (not use_compresskv): - key_states, value_states = update_past_key_value( - past_key_value, key_states, value_states, - kv_seq_len, use_quantize_kv, device - ) - past_key_value = (key_states, value_states) if use_cache else None - else: + # [CompressKV] + if use_compresskv: enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len) @@ -286,6 +281,12 @@ def baichuan_attention_forward_7b( key_states, value_states, self.layer_idx, query_states, attention_mask, 1, self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH) + else: + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) + past_key_value = (key_states, value_states) if use_cache else None if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") From ce7de77085f605a7ace3a7192638e304094e28df Mon Sep 17 00:00:00 
2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 14:29:27 +0800 Subject: [PATCH 11/14] add comment of change in model forward --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index f8d7d186..af34b3a5 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -101,6 +101,7 @@ def baichuan_model_7b_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # IPEX-LLM OPT: compress kv and quantize kv if use_cache: inputs = input_ids if input_ids is not None else inputs_embeds use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) @@ -128,6 +129,7 @@ def baichuan_model_7b_forward( past_key_values_length = 0 if past_key_values is not None: + # IPEX-LLM OPT: compress kv if isinstance(past_key_values, DynamicCompressCache): past_key_values_length = past_key_values.get_seq_length() else: @@ -164,12 +166,14 @@ def baichuan_model_7b_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + # IPEX-LLM OPT: compress kv use_compresskv = isinstance(past_key_values, DynamicCompressCache) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + # IPEX-LLM OPT: compress kv if not use_compresskv: past_key_value = past_key_values[idx] if past_key_values is not None else None @@ -190,6 +194,7 @@ def baichuan_model_7b_forward( None, ) else: + # IPEX-LLM OPT: compress kv layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -202,6 +207,7 @@ def baichuan_model_7b_forward( hidden_states = layer_outputs[0] if use_cache: + # IPEX-LLM OPT: compress kv if use_compresskv: next_decoder_cache = past_key_values else: From a2be3d75016a40b4a71104ed9811368a0ec70fa0 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 15:11:55 +0800 Subject: [PATCH 12/14] add comment of compress kv in attention forward --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index af34b3a5..3944a948 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -307,13 +307,11 @@ def baichuan_attention_forward_7b( is_causal=True).to(hidden_states.dtype) elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): import xe_addons + if use_compresskv: + attention_mask = get_compresskv_attn_mask(key_states, attention_mask) if use_quantize_kv: attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask) - elif use_compresskv: - attention_mask = get_compresskv_attn_mask(key_states, attention_mask) - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) else: attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) From eb1e65f8a980d7ed9be4bd08ebfe1b59d83ae6e6 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 15:14:47 +0800 Subject: [PATCH 13/14] add comment --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py 
b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 3944a948..9d412792 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -246,6 +246,7 @@ def baichuan_attention_forward_7b( bsz, q_len, _ = hidden_states.size() device = hidden_states.device + # [CompressKV] use_compresskv = isinstance(past_key_value, DynamicCompressCache) qkv = self.W_pack(hidden_states) @@ -257,6 +258,7 @@ def baichuan_attention_forward_7b( kv_seq_len = key_states.shape[2] if past_key_value is not None: + # [CompressKV] if use_compresskv: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) From 4cf03d621230cea1bed6877756acd87ab220ad97 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 18:16:33 +0800 Subject: [PATCH 14/14] update baichuan-7b --- python/llm/src/ipex_llm/transformers/convert.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 9fefd634..02e7f575 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1299,7 +1299,12 @@ def _optimize_post(model, lightweight_bmm=False): setattr(model.model.layers[i].self_attn, "layer_idx", i) convert_forward(model, module.Attention, baichuan_attention_forward_7b) convert_forward(model, module.RMSNorm, llama_rms_norm_forward) - convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward) + if model.config.vocab_size == 125696: + # baichuan2-7B + convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward) + elif model.config.vocab_size == 64000: + # baichuan-7B + convert_forward(model, module.Model, baichuan_model_7b_forward) elif model.config.hidden_size == 5120: # baichuan-13B and baichuan2-13B from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
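
Note (not part of the patch series): the sketch below is a hypothetical usage example showing how the Baichuan2-7B path touched by these patches would typically be exercised end to end. The checkpoint name, prompt, and generation settings are illustrative assumptions; whether the compress-kv branch actually activates is decided by `should_use_compresskv` in `ipex_llm.transformers.models.utils` (e.g. environment-variable gating and a minimum prompt length) and should be checked against that helper in your ipex-llm version. `KV_CACHE_ALLOC_BLOCK_LENGTH` is the environment variable read by the patched attention forward above (default 256).

```python
# Hypothetical sketch: running Baichuan2-7B through ipex-llm so that
# _optimize_post() installs baichuan_model_7b_forward and sets layer_idx
# on each self_attn module, as added in patch 01.
import os
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Read by the patched attention forward when growing the compressed KV cache.
os.environ.setdefault("KV_CACHE_ALLOC_BLOCK_LENGTH", "256")

model_path = "baichuan-inc/Baichuan2-7B-Chat"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# load_in_4bit / optimize_model route the model through ipex-llm's
# conversion step, which replaces Attention / BaichuanModel forwards.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,
    optimize_model=True,
    trust_remote_code=True,
    use_cache=True,
)
# Intel GPU target; the sdp / sdp_fp8 and compress-kv kernels run on xpu.
# Depending on the ipex-llm version, importing intel_extension_for_pytorch
# may be required before .to("xpu") works.
model = model.to("xpu")

prompt = "What is the capital of France?"
inputs = tokenizer(prompt, return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Whether the run uses `DynamicCompressCache` or `DynamicCompressFp8Cache` then follows the dispatch added in patch 01: quantized KV selects the fp8 variant, otherwise the plain compressed cache is used, and non-compress runs fall back to `update_past_key_value` as before.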