diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py
index a3dd0cb8..1851d383 100644
--- a/python/llm/src/ipex_llm/transformers/models/internlm.py
+++ b/python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -43,8 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -52,26 +52,7 @@ from einops import rearrange
 
 
 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "InternLMAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "InternLMAttention")
 
 
 def internlm_attention_forward(
@@ -144,8 +125,7 @@ def internlm_attention_forward(
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi.py b/python/llm/src/ipex_llm/transformers/models/phi.py
index 9a5a01a5..43365623 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi.py
@@ -34,6 +34,7 @@ import math
 
 import torch
 
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -55,26 +56,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
 
 
 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "PhiAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "PhiAttention")
 
 
 def attention_forward(
@@ -143,8 +125,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index afb90cfd..d2e2f026 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -43,6 +43,7 @@ import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 
@@ -367,26 +368,7 @@ def qwen2_moe_causal_lm_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2MoeAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Qwen2MoeAttention)
 
 
 def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index bfcb50ec..37639ff9 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -45,8 +45,8 @@ from transformers.cache_utils import Cache
 from transformers.models.stablelm.modeling_stablelm import repeat_kv
 from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 
@@ -54,29 +54,7 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, StableLmAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-
-        if module.q_proj.bias is not None:
-            qkv_proj = torch.nn.Linear(0, 0, bias=True)
-            new_bias = torch.cat([
-                module.q_proj.bias.data,
-                module.k_proj.bias.data,
-                module.v_proj.bias.data,
-            ], dim=0)
-            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        else:
-            qkv_proj = torch.nn.Linear(0, 0, bias=False)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, StableLmAttention)
 
 
 def stablelm_model_forward(
@@ -197,8 +175,7 @@ def stablelm_attention_forward(
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = self.attention_dropout(attn_weights)
 
         attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
index 654d5c0a..9ebb0c5f 100644
--- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py
+++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -40,6 +40,7 @@ import math
 import torch
 import warnings
 
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     should_use_fuse_rope, use_sdp, use_sdp_causal
@@ -54,26 +55,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Starcoder2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Starcoder2Attention)
 
 
 def attention_forward(
@@ -152,8 +134,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index 9f480ad3..339e958b 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -26,6 +26,7 @@ from typing import Optional, Tuple
 import torch
 
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -239,8 +240,7 @@ def yuan_attention_forward(
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
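
Note: the shared helpers `merge_qkv_base` and `attention_softmax` live in `ipex_llm/transformers/models/common.py`, which is not part of this diff. The sketch below is only an illustration of what they are assumed to do, inferred from the per-model code they replace; the real implementations may differ (for example, `attention_softmax` may dispatch to a fused in-place XPU kernel at inference time, which is why it takes the `training` flag).

# Illustrative sketch only -- not the actual contents of common.py.
import torch


def merge_qkv_base(module: torch.nn.Module, attention_class):
    # `attention_class` may be a class (e.g. Qwen2MoeAttention) or a class-name
    # string (e.g. "InternLMAttention"), matching both call styles in this diff.
    if isinstance(attention_class, str):
        matched = module.__class__.__name__ == attention_class
    else:
        matched = isinstance(module, attention_class)
    if not matched:
        return

    # Fuse the separate q/k/v projections into a single qkv_proj linear layer.
    new_weight = torch.cat([
        module.q_proj.weight.data,
        module.k_proj.weight.data,
        module.v_proj.weight.data,
    ], dim=0)
    if module.q_proj.bias is not None:
        qkv_proj = torch.nn.Linear(0, 0, bias=True)
        new_bias = torch.cat([
            module.q_proj.bias.data,
            module.k_proj.bias.data,
            module.v_proj.bias.data,
        ], dim=0)
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    else:
        qkv_proj = torch.nn.Linear(0, 0, bias=False)
    qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    qkv_proj.in_features = new_weight.size(1)
    qkv_proj.out_features = new_weight.size(0)
    module.qkv_proj = qkv_proj

    del module.q_proj, module.k_proj, module.v_proj


def attention_softmax(attn_weights: torch.Tensor, training: bool) -> torch.Tensor:
    # Upcast to fp32 for a numerically stable softmax, then cast back to the
    # original dtype -- the same behaviour as the removed per-model code.
    # `training` is unused in this sketch; the real helper may use it to pick
    # an in-place fused kernel when gradients are not needed.
    return torch.nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(attn_weights.dtype)

Centralizing these two helpers removes six near-identical copies of the q/k/v fusion and fp32-softmax code from the patched model files, so later fixes only need to touch one place.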