optimize qwen2 memory usage again (#11520)

This commit is contained in:
Yishuo Wang 2024-07-05 17:32:34 +08:00 committed by GitHub
parent 8f376e5192
commit 7cb09a8eac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -73,43 +73,6 @@ def qwen2_model_forward(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input)
)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
return qwen2_model_forward_internal(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def qwen2_model_forward_internal(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else \ output_attentions = output_attentions if output_attentions is not None else \
self.config.output_attentions self.config.output_attentions
@ -144,11 +107,21 @@ def qwen2_model_forward_internal(
past_key_values_length = 0 past_key_values_length = 0
# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
)
if use_cache: if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache) if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
if use_legacy_cache: past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(past_key_values) if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length) past_key_values_length = past_key_values.get_usable_length(seq_length)
# ipex-llm changes end
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
@ -176,7 +149,15 @@ def qwen2_model_forward_internal(
"the input. " "the input. "
) )
if self._attn_implementation == "flash_attention_2": # ipex-llm changes start: don't generate `attention_mask` in specific cases
if seq_length == 1 or batch_size == 1 and use_sdp_causal(
seq_length, seq_length + past_key_values_length,
self.config.hidden_size // self.config.num_attention_heads,
inputs_embeds, self.training
):
attention_mask = None
# ipex-llm changes end
elif self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers # 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and attention_mask = attention_mask if (attention_mask is not None and
0 in attention_mask) else None 0 in attention_mask) else None
@ -251,10 +232,11 @@ def qwen2_model_forward_internal(
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
# ipex-llm changes start: remove `to_legacy_cache`
next_cache = None next_cache = None
if use_cache: if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \ next_cache = next_decoder_cache
next_decoder_cache # ipex-llm changes end
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, return tuple(v for v in [hidden_states, next_cache,