add comments for changes in model forward

Author: Huang, Xinshengzi   Date: 2024-08-22 14:29:27 +08:00
parent 42398a0045
commit ce7de77085


@@ -101,6 +101,7 @@ def baichuan_model_7b_forward(
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    # IPEX-LLM OPT: compress kv and quantize kv
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
@@ -128,6 +129,7 @@ def baichuan_model_7b_forward(
     past_key_values_length = 0
     if past_key_values is not None:
+        # IPEX-LLM OPT: compress kv
         if isinstance(past_key_values, DynamicCompressCache):
             past_key_values_length = past_key_values.get_seq_length()
         else:
@@ -164,12 +166,14 @@ def baichuan_model_7b_forward(
     all_self_attns = () if output_attentions else None
     next_decoder_cache = () if use_cache else None
+    # IPEX-LLM OPT: compress kv
     use_compresskv = isinstance(past_key_values, DynamicCompressCache)
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
+        # IPEX-LLM OPT: compress kv
         if not use_compresskv:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -190,6 +194,7 @@ def baichuan_model_7b_forward(
                 None,
             )
         else:
+            # IPEX-LLM OPT: compress kv
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
@@ -202,6 +207,7 @@ def baichuan_model_7b_forward(
         hidden_states = layer_outputs[0]
         if use_cache:
+            # IPEX-LLM OPT: compress kv
            if use_compresskv:
                next_decoder_cache = past_key_values
            else:
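
For context, a minimal stand-alone sketch of the compress-kv branching that the new comments annotate. The `DynamicCompressCache` class and `should_use_compresskv` helper below are simplified stand-ins (assumptions), not the real IPEX-LLM implementations referenced in the diff; only the dispatch pattern mirrors the changed code.

# Sketch only: reproduces the compress-kv dispatch commented on in this diff.
# The cache class and the heuristic are placeholders, not IPEX-LLM internals.

class DynamicCompressCache:
    """Stand-in cache object that reports its own sequence length."""
    def __init__(self, seq_length=0):
        self._seq_length = seq_length

    def get_seq_length(self):
        return self._seq_length


def should_use_compresskv(inputs, seq_length, threshold=2048):
    # Assumed heuristic: only compress the kv cache for long prompts.
    return seq_length >= threshold


def past_kv_length(past_key_values):
    # Mirrors the branch marked "# IPEX-LLM OPT: compress kv":
    # a compress-kv cache reports its length itself, while a legacy
    # tuple-of-tuples cache exposes it via the key tensor shape.
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, DynamicCompressCache):
        return past_key_values.get_seq_length()
    # Assumed legacy layout: key tensor is (batch, num_heads, seq_len, head_dim).
    return past_key_values[0][0].shape[2]

Hypothetical usage: past_kv_length(DynamicCompressCache(seq_length=512)) returns 512, while a legacy tuple cache falls through to the tensor-shape branch.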