diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 21759ae0..51049cee 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -862,10 +862,10 @@ def convert_bigdl_other_module(model, dtype):
 
 
 def convert_forward(m, target_m, new_forward):
+    if m.__class__ == target_m:
+        bound_method = new_forward.__get__(m, m.__class__)
+        setattr(m, "forward", bound_method)
     for _, sub_m in m.named_children():
-        if sub_m.__class__ == target_m:
-            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
-            setattr(sub_m, "forward", bound_method)
         convert_forward(sub_m, target_m, new_forward)
 
 
@@ -1298,9 +1298,13 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
+        convert_forward(model,
+                        module.Qwen2ForCausalLM,
+                        qwen2_causal_lm_forward)
         convert_forward(model,
                         module.Qwen2RMSNorm,
                         llama_rms_norm_forward)
@@ -1319,10 +1323,14 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_moeblock_forward
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
+        from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2MoeModel,
                         qwen2moe_model_forward)
+        convert_forward(model,
+                        module.Qwen2MoeForCausalLM,
+                        qwen2_moe_causal_lm_forward)
         convert_forward(model,
                         module.Qwen2MoeRMSNorm,
                         llama_rms_norm_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 80cd4299..bf2293b6 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -41,6 +41,7 @@ import math
 from typing import Optional, Tuple, Union, List
 
 import torch
+from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -53,7 +54,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache
 from transformers import logging
 
@@ -266,6 +267,74 @@ def qwen2_model_forward_internal(
     )
 
 
+def qwen2_causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
+    # logits = logits.float()
+    # ipex-llm changes end
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, Qwen2Attention):
         new_weight = torch.cat([
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index be159316..9258624a 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -40,6 +40,7 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
@@ -47,9 +48,9 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from transformers.models.qwen2_moe.modeling_qwen2_moe import (
     _prepare_4d_causal_attention_mask_for_sdpa,
     _prepare_4d_causal_attention_mask,
-    Qwen2MoeAttention,
+    load_balancing_loss_func, Qwen2MoeAttention,
 )
-from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache
 from transformers import logging
 
@@ -274,6 +275,96 @@ def qwen2_moe_model_forward_internal(
     )
 
 
+def qwen2_moe_causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
+    # logits = logits.float()
+    # ipex-llm changes end
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            # make sure to reside in the same device
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
+
+
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, Qwen2MoeAttention):
         new_weight = torch.cat([
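
Note (not part of the patch): a minimal sketch of why `convert_forward` now checks the module itself before recursing. The previous version only inspected `named_children()`, so a module whose own class matches `target_m` was never rebound; the added check presumably lets the new `convert_forward(model, module.Qwen2ForCausalLM, ...)` and `convert_forward(model, module.Qwen2MoeForCausalLM, ...)` calls patch the top-level model object. `Outer`, `Inner`, and `new_outer_forward` below are illustrative stand-ins, not the real model classes or forwards.

```python
import torch.nn as nn

def convert_forward(m, target_m, new_forward):
    # patched behaviour: check the module itself first, then recurse into children
    if m.__class__ == target_m:
        bound_method = new_forward.__get__(m, m.__class__)
        setattr(m, "forward", bound_method)
    for _, sub_m in m.named_children():
        convert_forward(sub_m, target_m, new_forward)

class Inner(nn.Module):
    def forward(self, x):
        return x

class Outer(nn.Module):          # stands in for a top-level *ForCausalLM module
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x)

def new_outer_forward(self, x):  # stands in for the replacement forward
    return "patched"

model = Outer()
convert_forward(model, Outer, new_outer_forward)
assert model(None) == "patched"  # the top-level forward is now replaced as well
```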