optimize qwen1.5/2 memory usage when running long input with fp16 (#11403)
This commit is contained in:
parent
7507000ef2
commit
abe53eaa4f
3 changed files with 174 additions and 6 deletions
|
|
@ -862,10 +862,10 @@ def convert_bigdl_other_module(model, dtype):
|
|||
|
||||
|
||||
def convert_forward(m, target_m, new_forward):
|
||||
if m.__class__ == target_m:
|
||||
bound_method = new_forward.__get__(m, m.__class__)
|
||||
setattr(m, "forward", bound_method)
|
||||
for _, sub_m in m.named_children():
|
||||
if sub_m.__class__ == target_m:
|
||||
bound_method = new_forward.__get__(sub_m, sub_m.__class__)
|
||||
setattr(sub_m, "forward", bound_method)
|
||||
convert_forward(sub_m, target_m, new_forward)
|
||||
|
||||
|
||||
|
|
@ -1298,9 +1298,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2Model,
|
||||
qwen2_model_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2ForCausalLM,
|
||||
qwen2_causal_lm_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2RMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
|
|
@ -1319,10 +1323,14 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.qwen2_moe import qwen2moe_moeblock_forward
|
||||
from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
|
||||
from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2MoeModel,
|
||||
qwen2moe_model_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2MoeForCausalLM,
|
||||
qwen2_moe_causal_lm_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2MoeRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ import math
|
|||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
||||
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
|
|
@ -53,7 +54,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
|
|||
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers import logging
|
||||
|
||||
|
|
@ -266,6 +267,74 @@ def qwen2_model_forward_internal(
|
|||
)
|
||||
|
||||
|
||||
def qwen2_causal_lm_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
# ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
|
||||
# logits = logits.float()
|
||||
# ipex-llm changes end
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def merge_qkv(module: torch.nn.Module):
|
||||
if isinstance(module, Qwen2Attention):
|
||||
new_weight = torch.cat([
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from typing import Optional, Tuple, Union, List
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
|
||||
|
|
@ -47,9 +48,9 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
|||
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||
_prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask,
|
||||
Qwen2MoeAttention,
|
||||
load_balancing_loss_func, Qwen2MoeAttention,
|
||||
)
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers import logging
|
||||
|
||||
|
|
@ -274,6 +275,96 @@ def qwen2_moe_model_forward_internal(
|
|||
)
|
||||
|
||||
|
||||
def qwen2_moe_causal_lm_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
# ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
|
||||
# logits = logits.float()
|
||||
# ipex-llm changes end
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits if return_dict else outputs[-1],
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
# make sure to reside in the same device
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if output_router_logits:
|
||||
output = (aux_loss,) + output
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def merge_qkv(module: torch.nn.Module):
|
||||
if isinstance(module, Qwen2MoeAttention):
|
||||
new_weight = torch.cat([
|
||||
|
|
|
|||
Loading…
Reference in a new issue