optimize qwen1.5/2 memory usage when running long input with fp16 (#11403)
This commit is contained in:
parent 7507000ef2
commit abe53eaa4f
3 changed files with 174 additions and 6 deletions
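The saving comes from the "# logits = logits.float()" change in the two new causal-LM forwards below: the logits tensor has shape [batch, seq_len, vocab_size], so keeping it in fp16 avoids materializing a second, twice-as-large fp32 copy on long prompts. A back-of-the-envelope sketch (the vocabulary size and prompt length here are illustrative assumptions, not values taken from this commit):

# Rough logits footprint for a single long prompt; both numbers below are assumptions.
vocab_size = 152064                      # approximate Qwen1.5/Qwen2 vocabulary size
seq_len = 8192                           # a "long input" prompt
fp16_gib = seq_len * vocab_size * 2 / 1024 ** 3
fp32_gib = seq_len * vocab_size * 4 / 1024 ** 3
print(f"fp16 logits (kept):    {fp16_gib:.2f} GiB")   # ~2.3 GiB, produced by lm_head anyway
print(f"fp32 upcast (avoided): {fp32_gib:.2f} GiB")   # ~4.6 GiB extra from logits.float()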
@@ -862,10 +862,10 @@ def convert_bigdl_other_module(model, dtype):


 def convert_forward(m, target_m, new_forward):
+    if m.__class__ == target_m:
+        bound_method = new_forward.__get__(m, m.__class__)
+        setattr(m, "forward", bound_method)
     for _, sub_m in m.named_children():
-        if sub_m.__class__ == target_m:
-            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
-            setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)


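In short, convert_forward now also patches the module it is called on, not just matching children; that is what lets the top-level Qwen2ForCausalLM / Qwen2MoeForCausalLM objects registered below pick up the new causal-LM forwards. A minimal, self-contained sketch of the mechanism (the toy module and replacement forward are invented for illustration; the helper itself is copied from the hunk above):

import torch

def convert_forward(m, target_m, new_forward):
    # same logic as the updated helper above
    if m.__class__ == target_m:
        bound_method = new_forward.__get__(m, m.__class__)
        setattr(m, "forward", bound_method)
    for _, sub_m in m.named_children():
        convert_forward(sub_m, target_m, new_forward)

class ToyLM(torch.nn.Module):            # stand-in for a *ForCausalLM root module
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

def patched_forward(self, x):            # stand-in for e.g. qwen2_causal_lm_forward
    return self.proj(x) * 2

model = ToyLM()
convert_forward(model, ToyLM, patched_forward)   # the root module itself is patched now
print(model(torch.ones(1, 4)))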
@@ -1298,9 +1298,13 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
+        convert_forward(model,
+                        module.Qwen2ForCausalLM,
+                        qwen2_causal_lm_forward)
         convert_forward(model,
                         module.Qwen2RMSNorm,
                         llama_rms_norm_forward)
@@ -1319,10 +1323,14 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_moeblock_forward
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
+        from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2MoeModel,
                         qwen2moe_model_forward)
+        convert_forward(model,
+                        module.Qwen2MoeForCausalLM,
+                        qwen2_moe_causal_lm_forward)
         convert_forward(model,
                         module.Qwen2MoeRMSNorm,
                         llama_rms_norm_forward)
@@ -41,6 +41,7 @@ import math
 from typing import Optional, Tuple, Union, List

 import torch
+from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa

 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -53,7 +54,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache
 from transformers import logging
@@ -266,6 +267,74 @@ def qwen2_model_forward_internal(
     )


+def qwen2_causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
+    # logits = logits.float()
+    # ipex-llm changes end
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, Qwen2Attention):
         new_weight = torch.cat([
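The only functional change relative to the upstream Qwen2ForCausalLM.forward is the commented-out upcast above, so under fp16 the logits keep the lm_head output dtype. A toy illustration of the allocation this skips (the sizes are assumptions, not Qwen2's real config):

import torch

# What lm_head produces for one long fp16 prompt (toy sizes: batch=1, seq=4096, vocab=152064).
logits = torch.empty(1, 4096, 152064, dtype=torch.float16)
upcast = logits.float()                                   # the extra allocation the patch avoids
print(logits.element_size(), upcast.element_size())       # 2 bytes vs 4 bytes per element
print(f"avoided fp32 copy: {upcast.numel() * upcast.element_size() / 1024 ** 3:.2f} GiB")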
@@ -40,6 +40,7 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
@@ -47,9 +48,9 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache

 from transformers.models.qwen2_moe.modeling_qwen2_moe import (
     _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask,
-    Qwen2MoeAttention,
+    load_balancing_loss_func, Qwen2MoeAttention,
 )
-from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache
 from transformers import logging
@@ -274,6 +275,96 @@ def qwen2_moe_model_forward_internal(
     )


+def qwen2_moe_causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
+    # logits = logits.float()
+    # ipex-llm changes end
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
+
+
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, Qwen2MoeAttention):
         new_weight = torch.cat([
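For reference, a hedged end-to-end sketch of how these patched forwards are typically picked up when running Qwen1.5/Qwen2 with fp16 through ipex-llm; the model id, prompt, device, and keyword arguments follow the usual ipex-llm GPU examples and are illustrative rather than taken from this commit:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "Qwen/Qwen1.5-7B-Chat"                      # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="fp16",  # keep weights/compute in fp16
                                             optimize_model=True,     # applies the convert_forward patches
                                             trust_remote_code=True,
                                             use_cache=True)
model = model.to("xpu")                                  # Intel GPU target commonly used with ipex-llm

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("A very long prompt ...", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))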