refactor merge_qkv and attention_softmax (#12213)
This commit is contained in:
parent e279148aa0
commit bb247e991b
6 changed files with 17 additions and 116 deletions
@@ -43,8 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -52,26 +52,7 @@ from einops import rearrange


 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "InternLMAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "InternLMAttention")


 def internlm_attention_forward(
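Each model file previously carried its own copy of this fusion logic; the commit replaces those copies with the shared merge_qkv_base from ipex_llm.transformers.models.common. That helper's body is not part of this diff, so the following is only a minimal sketch of what it presumably does, generalized from the deleted code: it should accept either an attention class or a class-name string (both call styles appear below) and tolerate projections without a bias (the StableLM case).

    import torch

    def merge_qkv_base(module: torch.nn.Module, attention_class):
        # Match by class object or by class-name string.
        if isinstance(attention_class, str):
            if module.__class__.__name__ != attention_class:
                return
        elif not isinstance(module, attention_class):
            return

        # Concatenate the q/k/v projection weights row-wise into one matrix.
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)
        qkv_proj = torch.nn.Linear(0, 0, bias=module.q_proj.bias is not None)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        if module.q_proj.bias is not None:
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=0)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj
        del module.q_proj, module.k_proj, module.v_proj

With a shared helper like this in place, each per-model merge_qkv collapses to the one-line calls shown in the hunks below.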
@@ -144,8 +125,7 @@ def internlm_attention_forward
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)

         attn_output = attn_output.transpose(1, 2)
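attention_softmax is also pulled from ipex_llm.transformers.models.common, and its definition is outside this diff. Below is a minimal sketch of the behavior the call sites rely on, assuming it simply wraps the fp32-upcast softmax it replaces; on XPU the real helper may additionally dispatch to a fused in-place kernel when training is False, which is presumably why self.training is passed in.

    import torch

    def attention_softmax(attn_weights: torch.Tensor, training: bool) -> torch.Tensor:
        # Same numerics as the removed lines: softmax in fp32, then cast back
        # to the activation dtype. A non-training fast path (for example an
        # in-place fused kernel) could be added behind the `training` flag.
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)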
@@ -34,6 +34,7 @@
 import math
 import torch

+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -55,26 +56,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids)


 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "PhiAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "PhiAttention")


 def attention_forward(
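Once q_proj/k_proj/v_proj are replaced by a single qkv_proj, the attention forwards (only partially shown in this diff) read one fused projection and slice it apart. A hypothetical sketch of that split, assuming the usual num_heads, num_key_value_heads and head_dim attributes and the q-then-k-then-v concatenation order used above:

    import torch

    def split_fused_qkv(self, hidden_states: torch.Tensor):
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)
        # One head-sized chunk per query head, plus one each for the k and v heads.
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
                       self.head_dim)
        qkv = qkv.transpose(1, 2)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
            dim=1,
        )
        return query_states, key_states, value_states

A single GEMM over the fused weight generally beats three smaller ones, which is the usual motivation for merging the projections in the first place.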
@@ -143,8 +125,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
@@ -43,6 +43,7 @@ import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
@@ -367,26 +368,7 @@ def qwen2_moe_causal_lm_forward(


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2MoeAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Qwen2MoeAttention)


 def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
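These merge_qkv wrappers are meant to run once over the whole module tree when the model is optimized; the conversion code is not in this diff, so the usage below is only a hypothetical illustration (the model id and dtype are placeholders), relying on torch.nn.Module.apply visiting every submodule:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B",
                                                 torch_dtype=torch.float16)
    # merge_qkv is the wrapper defined in the hunk above; the isinstance check
    # inside merge_qkv_base makes it a no-op for every non-attention module.
    model.apply(merge_qkv)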
@@ -45,8 +45,8 @@ from transformers.cache_utils import Cache
 from transformers.models.stablelm.modeling_stablelm import repeat_kv
 from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel

-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -54,29 +54,7 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, StableLmAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-
-        if module.q_proj.bias is not None:
-            qkv_proj = torch.nn.Linear(0, 0, bias=True)
-            new_bias = torch.cat([
-                module.q_proj.bias.data,
-                module.k_proj.bias.data,
-                module.v_proj.bias.data,
-            ], dim=0)
-            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        else:
-            qkv_proj = torch.nn.Linear(0, 0, bias=False)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, StableLmAttention)


 def stablelm_model_forward(
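The StableLM variant above was the only one that had to handle a missing bias, which is why a shared helper needs the bias check from the earlier sketch (note that for 1-D bias tensors the dim=-1 and dim=0 concatenations used in the deleted code are the same thing). A small self-contained check, not part of the commit, that the concatenation trick is exact for both the bias and no-bias cases:

    import torch

    def check_merge(bias: bool) -> bool:
        torch.manual_seed(0)
        hidden, q_out, kv_out = 64, 64, 32
        q = torch.nn.Linear(hidden, q_out, bias=bias)
        k = torch.nn.Linear(hidden, kv_out, bias=bias)
        v = torch.nn.Linear(hidden, kv_out, bias=bias)

        # Build the fused projection the same way merge_qkv did.
        fused = torch.nn.Linear(hidden, q_out + 2 * kv_out, bias=bias)
        fused.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if bias:
            fused.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)

        x = torch.randn(2, 8, hidden)
        expected = torch.cat([q(x), k(x), v(x)], dim=-1)
        return torch.allclose(fused(x), expected, atol=1e-6)

    assert check_merge(bias=True) and check_merge(bias=False)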
@@ -197,8 +175,7 @@ def stablelm_attention_forward(
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = self.attention_dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, value_states)
@@ -40,6 +40,7 @@ import math
 import torch
 import warnings

+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     should_use_fuse_rope, use_sdp, use_sdp_causal
@@ -54,26 +55,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Starcoder2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Starcoder2Attention)


 def attention_forward(
@@ -152,8 +134,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
@@ -26,6 +26,7 @@ from typing import Optional, Tuple
 import torch

 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -239,8 +240,7 @@ def yuan_attention_forward(
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)

         attn_output = attn_output.transpose(1, 2)