add comments on the change in model forward
parent 42398a0045
commit ce7de77085
1 changed file with 6 additions and 0 deletions
@@ -101,6 +101,7 @@ def baichuan_model_7b_forward(
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+    # IPEX-LLM OPT: compress kv and quantize kv
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
@@ -128,6 +129,7 @@ def baichuan_model_7b_forward(
     past_key_values_length = 0

     if past_key_values is not None:
+        # IPEX-LLM OPT: compress kv
         if isinstance(past_key_values, DynamicCompressCache):
             past_key_values_length = past_key_values.get_seq_length()
         else:
@@ -164,12 +166,14 @@ def baichuan_model_7b_forward(
     all_self_attns = () if output_attentions else None
     next_decoder_cache = () if use_cache else None

+    # IPEX-LLM OPT: compress kv
     use_compresskv = isinstance(past_key_values, DynamicCompressCache)

     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

+        # IPEX-LLM OPT: compress kv
         if not use_compresskv:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -190,6 +194,7 @@ def baichuan_model_7b_forward(
                 None,
             )
         else:
+            # IPEX-LLM OPT: compress kv
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
@@ -202,6 +207,7 @@ def baichuan_model_7b_forward(
         hidden_states = layer_outputs[0]

         if use_cache:
+            # IPEX-LLM OPT: compress kv
             if use_compresskv:
                 next_decoder_cache = past_key_values
             else:
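For readers skimming the hunks: the new comments all mark the same compress-kv dispatch pattern in baichuan_model_7b_forward, where a compressed cache is detected once per forward pass and each decoder layer then either indexes the legacy per-layer tuple cache or passes the shared DynamicCompressCache through. The snippet below is a minimal, self-contained sketch of that control flow; ToyCompressCache, should_compress, ToyLayer, and toy_forward are hypothetical stand-ins that only mirror the shape of the diff, not the actual ipex-llm implementation.

# Hypothetical sketch of the compress-kv dispatch that the new comments mark.
# ToyCompressCache / should_compress / ToyLayer are stand-ins, not ipex-llm APIs.
import torch


class ToyCompressCache:
    """Stand-in for a shared compressed kv cache (cf. DynamicCompressCache)."""

    def __init__(self):
        self._seq_len = 0

    def get_seq_length(self):
        return self._seq_len


def should_compress(inputs, seq_len):
    """Stand-in for should_use_compresskv: only compress long prompts."""
    return seq_len >= 512


class ToyLayer(torch.nn.Module):
    def forward(self, hidden_states, past_key_value=None, use_cache=False):
        # A real decoder layer would attend here and update/return its kv state.
        return hidden_states, past_key_value


def toy_forward(layers, inputs_embeds, past_key_values=None, use_cache=True):
    hidden_states = inputs_embeds

    # Decide once per forward pass whether the shared compressed cache is in use.
    use_compresskv = isinstance(past_key_values, ToyCompressCache)

    next_decoder_cache = () if use_cache else None
    for idx, decoder_layer in enumerate(layers):
        if not use_compresskv:
            # Legacy path: per-layer (key, value) tuples indexed by layer id.
            past_key_value = past_key_values[idx] if past_key_values is not None else None
        else:
            # Compressed path: every layer shares the same cache object.
            past_key_value = past_key_values

        hidden_states, present = decoder_layer(hidden_states, past_key_value, use_cache)

        if use_cache:
            if use_compresskv:
                # The shared cache is updated in place, so return it as-is.
                next_decoder_cache = past_key_values
            else:
                # Legacy path accumulates one entry per layer.
                next_decoder_cache += (present,)

    return hidden_states, next_decoder_cache


# Usage, mirroring the first hunk: pick the cache type from the prompt length.
layers = torch.nn.ModuleList([ToyLayer(), ToyLayer()])
inputs_embeds = torch.zeros(1, 600, 16)
cache = ToyCompressCache() if should_compress(inputs_embeds, inputs_embeds.shape[1]) else None
hidden, new_cache = toy_forward(layers, inputs_embeds, past_key_values=cache)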