refactor merge_qkv and attention_softmax (#12213)

Yishuo Wang 2024-10-16 15:58:14 +08:00 committed by GitHub
parent e279148aa0
commit bb247e991b
6 changed files with 17 additions and 116 deletions

Changed file 1/6: InternLM attention module

@@ -43,8 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -52,26 +52,7 @@ from einops import rearrange
 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "InternLMAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "InternLMAttention")


 def internlm_attention_forward(
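All of the deleted per-model merge_qkv bodies in this commit collapse into the shared merge_qkv_base helper imported from ipex_llm.transformers.models.common. Its source is not part of this diff; below is a minimal sketch of what such a helper has to do, pieced together from the deleted code and from the two calling styles visible here (class-name string for InternLM and Phi, class object for the other models). The matching logic and optional-bias branch are assumptions, mirroring the old per-model variants rather than ipex-llm's actual implementation.

import torch


def merge_qkv_base(module: torch.nn.Module, attention_class):
    # Match either a class object (e.g. Qwen2MoeAttention) or a class-name
    # string (e.g. "InternLMAttention"); both forms appear in this diff.
    if isinstance(attention_class, str):
        matched = module.__class__.__name__ == attention_class
    else:
        matched = isinstance(module, attention_class)
    if not matched:
        return

    # Stack the three projection weights into one [3 * out, in] matrix.
    new_weight = torch.cat([
        module.q_proj.weight.data,
        module.k_proj.weight.data,
        module.v_proj.weight.data,
    ], dim=0)
    qkv_proj = torch.nn.Linear(0, 0, bias=module.q_proj.bias is not None)
    qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    if module.q_proj.bias is not None:
        # Concatenate biases only when the projections carry one; StableLM
        # configs may disable it, hence the old per-model branch.
        new_bias = torch.cat([
            module.q_proj.bias.data,
            module.k_proj.bias.data,
            module.v_proj.bias.data,
        ], dim=0)
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    qkv_proj.in_features = new_weight.size(1)
    qkv_proj.out_features = new_weight.size(0)
    module.qkv_proj = qkv_proj
    del module.q_proj, module.k_proj, module.v_proj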
@@ -144,8 +125,7 @@ def internlm_attention_forward(
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2)
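Likewise, the inline fp32-upcast softmax deleted above is replaced at every call site by attention_softmax from the same common module. The sketch below reproduces exactly what the deleted lines computed; the real helper presumably keys off the training flag to pick a fused inference path on XPU, which is an assumption and not shown in this commit.

import torch


def attention_softmax(attn_weights: torch.Tensor, training: bool):
    # Same numerics as the removed inline code: run the softmax in fp32 for
    # stability, then cast back to the dtype of the incoming attention scores.
    # `training` is unused in this sketch; it is plumbed through so an
    # inference-only fused kernel could be selected without touching call sites
    # (assumed, not part of this diff).
    return torch.nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(attn_weights.dtype)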

Changed file 2/6: Phi attention module

@@ -34,6 +34,7 @@
 import math
 import torch
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -55,26 +56,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "PhiAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "PhiAttention")


 def attention_forward(
@@ -143,8 +125,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)

Changed file 3/6: Qwen2-MoE attention module

@@ -43,6 +43,7 @@ import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
@@ -367,26 +368,7 @@ def qwen2_moe_causal_lm_forward(
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2MoeAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Qwen2MoeAttention)


 def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
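Note the two calling conventions: InternLM and Phi pass the attention class by name ("InternLMAttention", "PhiAttention"), likely because those classes are not reliably importable at patch time, while Qwen2MoeAttention here, and StableLmAttention and Starcoder2Attention below, are imported and passed as class objects. Either way the thin per-model merge_qkv wrappers keep their old signature, so they can still be applied module by module. A toy, self-contained usage example built on the merge_qkv_base sketch above (ToyAttention and the wiring are illustrative, not part of ipex-llm):

from torch import nn


class ToyAttention(nn.Module):
    # Stand-in for a HuggingFace attention class; purely illustrative.
    def __init__(self, hidden=8):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=True)
        self.k_proj = nn.Linear(hidden, hidden, bias=True)
        self.v_proj = nn.Linear(hidden, hidden, bias=True)


def merge_qkv(module: nn.Module):
    # Same shape as the per-model wrappers kept in this diff: match by class
    # object here (matching by name string, "ToyAttention", also works with
    # the merge_qkv_base sketch above).
    merge_qkv_base(module, ToyAttention)


model = nn.Sequential(ToyAttention(), ToyAttention())
model.apply(merge_qkv)  # walks every submodule and fuses q/k/v in place

assert all(hasattr(m, "qkv_proj") and not hasattr(m, "q_proj")
           for m in model)

After model.apply, each attention block exposes a single qkv_proj whose output is the concatenated query, key and value states produced by one matmul instead of three.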

Changed file 4/6: StableLM attention module

@@ -45,8 +45,8 @@ from transformers.cache_utils import Cache
 from transformers.models.stablelm.modeling_stablelm import repeat_kv
 from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -54,29 +54,7 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, StableLmAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        if module.q_proj.bias is not None:
-            qkv_proj = torch.nn.Linear(0, 0, bias=True)
-            new_bias = torch.cat([
-                module.q_proj.bias.data,
-                module.k_proj.bias.data,
-                module.v_proj.bias.data,
-            ], dim=0)
-            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        else:
-            qkv_proj = torch.nn.Linear(0, 0, bias=False)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, StableLmAttention)


 def stablelm_model_forward(
@@ -197,8 +175,7 @@ def stablelm_attention_forward(
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = self.attention_dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, value_states)

Changed file 5/6: Starcoder2 attention module

@@ -40,6 +40,7 @@ import math
 import torch
 import warnings
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     should_use_fuse_rope, use_sdp, use_sdp_causal
@@ -54,26 +55,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Starcoder2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Starcoder2Attention)


 def attention_forward(
@@ -152,8 +134,7 @@ def attention_forward(
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

Changed file 6/6: Yuan attention module

@@ -26,6 +26,7 @@ from typing import Optional, Tuple
 import torch
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -239,8 +240,7 @@ def yuan_attention_forward(
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = attention_softmax(attn_weights, self.training)
         attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2)
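Because every touched call site previously upcast the scores to fp32 and cast back, the refactor should be numerically a no-op on the default path. A quick self-check against the attention_softmax sketch given earlier (the shapes, dtype and toy tensor are arbitrary, and the comparison is against the sketch, not ipex-llm's actual helper):

import torch

# Toy half-precision attention scores, standing in for attn_weights above.
attn_weights = torch.randn(2, 4, 16, 16, dtype=torch.float16)

# Old inline form, as deleted in this commit.
old = torch.nn.functional.softmax(
    attn_weights, dim=-1, dtype=torch.float32
).to(attn_weights.dtype)

# New shared helper (the sketch defined earlier in these notes).
new = attention_softmax(attn_weights, training=False)

assert new.dtype == attn_weights.dtype
assert torch.equal(old, new)  # identical computation, so bit-identical here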