optimize qwen2 memory usage again (#11520)
parent 8f376e5192
commit 7cb09a8eac
1 changed file with 25 additions and 43 deletions
@@ -73,43 +73,6 @@ def qwen2_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-):
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = (
-        self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input)
-    )
-    if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
-            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
-    return qwen2_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def qwen2_model_forward_internal(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else \
         self.config.output_attentions
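Note on this hunk: the thin qwen2_model_forward wrapper is deleted and its parameter list is merged straight into the body that used to be qwen2_model_forward_internal, so the KV-cache decision is made once, inside a single forward. As a rough illustration of how such a replacement forward ends up on a model, the sketch below binds it with plain types.MethodType; this binding mechanism is only an assumption for illustration, not ipex-llm's own conversion helper, and the checkpoint id is just an example.

import types

from transformers import AutoModelForCausalLM


def qwen2_model_forward(self, input_ids=None, **kwargs):
    ...  # stand-in for the merged forward shown in this diff


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
# Bind the replacement onto the inner Qwen2Model instance so `self` resolves to it.
model.model.forward = types.MethodType(qwen2_model_forward, model.model)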
@@ -144,11 +107,21 @@ def qwen2_model_forward_internal(
 
     past_key_values_length = 0
 
+    # ipex-llm changes start
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    use_quantize_kv = (
+        self.config.hidden_size != 3584  # disable quantize kv in specific model
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+    )
+
     if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
+    # ipex-llm changes end
 
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
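Note on this hunk: the added block chooses the KV-cache container per forward call. FP8 KV cache is used whenever use_quantize_kv_cache approves it, except for the 3584-hidden-size configuration (the Qwen2-7B shape), which stays on the unquantized DynamicNormalCache. Below is a minimal standalone restatement of that rule as a hypothetical helper; the import paths are assumptions about where ipex-llm keeps these symbols and may differ between releases.

# Assumed import locations; adjust to the actual ipex-llm release.
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from ipex_llm.transformers.models.utils import use_quantize_kv_cache


def select_kv_cache(self, past_key_values, inputs):
    # Skip FP8 KV cache for hidden_size == 3584; otherwise let the heuristic
    # on the first MLP projection decide whether quantized cache is worthwhile.
    use_quantize_kv = (
        self.config.hidden_size != 3584
        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
    )
    if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    return use_quantize_kv, past_key_values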
@@ -176,7 +149,15 @@ def qwen2_model_forward_internal(
                 "the input. "
             )
 
-    if self._attn_implementation == "flash_attention_2":
+    # ipex-llm changes start: don't generate `attention_mask` in specific cases
+    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
+        seq_length, seq_length + past_key_values_length,
+        self.config.hidden_size // self.config.num_attention_heads,
+        inputs_embeds, self.training
+    ):
+        attention_mask = None
+    # ipex-llm changes end
+    elif self._attn_implementation == "flash_attention_2":
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and
                                             0 in attention_mask) else None
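Note on this hunk: when the call is a single decode step (seq_length == 1) or a batch-of-one prefill that use_sdp_causal accepts, the 4-D causal mask is simply not built and attention_mask is passed as None, letting the fused SDPA path apply causality itself. The saving is the full q_len x kv_len mask buffer; a quick back-of-envelope check, assuming an fp16 mask and an illustrative 4K-token prompt:

import torch

# Size of the (batch, 1, q_len, kv_len) causal mask that is no longer allocated.
batch_size, seq_len = 1, 4096
mask = torch.empty(batch_size, 1, seq_len, seq_len, dtype=torch.float16)
print(f"{mask.element_size() * mask.nelement() / 1024 ** 2:.0f} MiB")  # -> 32 MiB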
@@ -251,10 +232,11 @@ def qwen2_model_forward_internal(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)
 
+    # ipex-llm changes start: remove `to_legacy_cache`
     next_cache = None
     if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
-            next_decoder_cache
+        next_cache = next_decoder_cache
+    # ipex-llm changes end
 
     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache,
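Note on this hunk: with to_legacy_cache() gone, past_key_values in the returned output is the cache object itself (DynamicNormalCache or DynamicFp8Cache) rather than a tuple of per-layer tuples. The sketch below shows what that interface looks like using the stock transformers DynamicCache as a stand-in, on the assumption that the ipex-llm caches follow the same Cache API; shapes are arbitrary example values.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# One layer's keys/values: (batch, num_kv_heads, seq_len, head_dim).
k = torch.zeros(1, 4, 16, 128)
v = torch.zeros(1, 4, 16, 128)
cache.update(k, v, layer_idx=0)

print(type(cache.key_cache[0]))      # per-layer tensors live on the Cache object
print(len(cache.to_legacy_cache()))  # the legacy tuple view is still available on demand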