Optimize Qwen1.5/Qwen2 memory usage when running long inputs with fp16 (#11403)

Yishuo Wang 2024-06-24 13:43:04 +08:00 committed by GitHub
parent 7507000ef2
commit abe53eaa4f
3 changed files with 174 additions and 6 deletions
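The key change is dropping the `logits.float()` upcast in the Qwen2/Qwen2-MoE causal-LM forwards, so the lm_head output stays in fp16. A rough, illustrative calculation of what that transient buffer costs during a long-prompt prefill (numbers are not from the PR; Qwen's exact vocabulary size varies by model):

    seq_len, vocab_size = 8192, 152_064          # ~152k vocab is typical for Qwen1.5/Qwen2
    fp16_logits = seq_len * vocab_size * 2       # bytes when logits keep the fp16 compute dtype
    fp32_logits = seq_len * vocab_size * 4       # bytes after an extra `logits.float()` upcast
    print(f"fp16 logits: {fp16_logits / 1024**3:.2f} GiB")   # ~2.32 GiB
    print(f"fp32 logits: {fp32_logits / 1024**3:.2f} GiB")   # ~4.64 GiB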

python/llm/src/ipex_llm/transformers/convert.py

@@ -862,10 +862,10 @@ def convert_bigdl_other_module(model, dtype):
def convert_forward(m, target_m, new_forward):
    if m.__class__ == target_m:
        bound_method = new_forward.__get__(m, m.__class__)
        setattr(m, "forward", bound_method)
    for _, sub_m in m.named_children():
        if sub_m.__class__ == target_m:
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_m, new_forward)
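For context, `convert_forward` relies on Python's descriptor protocol: `new_forward.__get__(m, m.__class__)` binds a plain function to one module instance, so `setattr(m, "forward", ...)` overrides the forward of that instance only. A minimal, self-contained sketch with toy names (not part of the PR):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    def patched_forward(self, x):
        return x * 2

    m = Toy()
    m.forward = patched_forward.__get__(m, Toy)   # bind so `self` is filled in automatically
    print(m(torch.tensor(3.0)))                   # tensor(6.) -> nn.Module.__call__ now uses the patch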
@@ -1298,9 +1298,13 @@ def _optimize_post(model, lightweight_bmm=False):
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
        from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
        from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
        convert_forward(model,
                        module.Qwen2Model,
                        qwen2_model_forward)
        convert_forward(model,
                        module.Qwen2ForCausalLM,
                        qwen2_causal_lm_forward)
        convert_forward(model,
                        module.Qwen2RMSNorm,
                        llama_rms_norm_forward)
@@ -1319,10 +1323,14 @@ def _optimize_post(model, lightweight_bmm=False):
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.qwen2_moe import qwen2moe_moeblock_forward
        from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
        from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
        from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
        convert_forward(model,
                        module.Qwen2MoeModel,
                        qwen2moe_model_forward)
        convert_forward(model,
                        module.Qwen2MoeForCausalLM,
                        qwen2_moe_causal_lm_forward)
        convert_forward(model,
                        module.Qwen2MoeRMSNorm,
                        llama_rms_norm_forward)
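These `convert_forward` calls run when a model is optimized through ipex-llm. A hedged end-to-end sketch follows; `optimize_model` is ipex-llm's documented entry point, but the exact arguments for a pure fp16 long-input run may differ, and the model id is only an example:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from ipex_llm import optimize_model

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat",
                                                 torch_dtype=torch.float16)
    model = optimize_model(model)   # internally reaches _optimize_post -> convert_forward(...)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat")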

python/llm/src/ipex_llm/transformers/models/qwen2.py

@@ -41,6 +41,7 @@ import math
from typing import Optional, Tuple, Union, List
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.functional import scaled_dot_product_attention as sdpa
from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -53,7 +54,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers import logging
@@ -266,6 +267,74 @@ def qwen2_model_forward_internal(
    )


def qwen2_causal_lm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
    # logits = logits.float()
    # ipex-llm changes end

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
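A quick, hypothetical sanity check (not part of the PR; `model` and `input_ids` stand for an optimized Qwen2 model and a tokenized long prompt) that the cast removal leaves the logits in the compute dtype:

    with torch.inference_mode():
        out = model(input_ids)
    print(out.logits.dtype)   # expected torch.float16 once `logits.float()` is skipped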
def merge_qkv(module: torch.nn.Module):
    if isinstance(module, Qwen2Attention):
        new_weight = torch.cat([

python/llm/src/ipex_llm/transformers/models/qwen2_moe.py

@@ -40,6 +40,7 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union, List
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
@@ -47,9 +48,9 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
    _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask,
    Qwen2MoeAttention,
    load_balancing_loss_func, Qwen2MoeAttention,
)
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers import logging
@@ -274,6 +275,96 @@ def qwen2_moe_model_forward_internal(
    )


def qwen2_moe_causal_lm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
    # logits = logits.float()
    # ipex-llm changes end

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits if return_dict else outputs[-1],
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            # make sure to reside in the same device
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

    if not return_dict:
        output = (logits,) + outputs[1:]
        if output_router_logits:
            output = (aux_loss,) + output
        return (loss,) + output if loss is not None else output

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
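The `load_balancing_loss_func` call above comes from transformers and implements the Switch-Transformer-style auxiliary loss. A simplified sketch of the idea (single layer, top-1 routing, no attention mask; the real helper handles top-k routing and masking), just to show what gets scaled by `router_aux_loss_coef`:

    import torch
    import torch.nn.functional as F

    def toy_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
        # router_logits: [num_tokens, num_experts]
        probs = F.softmax(router_logits, dim=-1)
        chosen = probs.argmax(dim=-1)                                     # expert picked per token
        tokens_per_expert = F.one_hot(chosen, num_experts).float().mean(dim=0)
        prob_per_expert = probs.mean(dim=0)
        return num_experts * torch.sum(tokens_per_expert * prob_per_expert)

    # qwen2_moe_causal_lm_forward then folds it into the training loss:
    # loss += self.router_aux_loss_coef * aux_loss.to(loss.device)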
def merge_qkv(module: torch.nn.Module):
    if isinstance(module, Qwen2MoeAttention):
        new_weight = torch.cat([