refactor merge_qkv and attention_softmax (#12213)
Parent: e279148aa0
Commit: bb247e991b
6 changed files with 17 additions and 116 deletions

This commit removes the hand-rolled merge_qkv bodies and the inline fp32-upcast attention softmax from six model files and routes them through the shared merge_qkv_base and attention_softmax helpers in ipex_llm.transformers.models.common.
File 1 of 6 — InternLM attention module:

@@ -43,8 +43,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -52,26 +52,7 @@ from einops import rearrange


 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "InternLMAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "InternLMAttention")


 def internlm_attention_forward(
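The shared helper merge_qkv_base lives in ipex_llm.transformers.models.common, which this commit view does not include. Below is a rough sketch of what it presumably does, reconstructed from the per-model bodies removed in this commit; the string-or-class dispatch and the optional-bias branch are inferred from the different call sites (InternLM and Phi pass a class name, Qwen2-MoE, StableLM and Starcoder2 pass the class itself), not from the helper's actual source.

import torch
from typing import Union


def merge_qkv_base(module: torch.nn.Module, attention_class: Union[str, type]):
    # Match either by class name (e.g. "InternLMAttention") or by class object
    # (e.g. Qwen2MoeAttention), mirroring the two call styles in this commit.
    if isinstance(attention_class, str):
        is_target = module.__class__.__name__ == attention_class
    else:
        is_target = isinstance(module, attention_class)
    if not is_target:
        return

    # Stack the three projection weights into one [q+k+v, hidden] matrix.
    new_weight = torch.cat([
        module.q_proj.weight.data,
        module.k_proj.weight.data,
        module.v_proj.weight.data,
    ], dim=0)

    if module.q_proj.bias is not None:
        qkv_proj = torch.nn.Linear(0, 0, bias=True)
        new_bias = torch.cat([
            module.q_proj.bias.data,
            module.k_proj.bias.data,
            module.v_proj.bias.data,
        ], dim=0)
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    else:
        qkv_proj = torch.nn.Linear(0, 0, bias=False)

    qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    qkv_proj.in_features = new_weight.size(1)
    qkv_proj.out_features = new_weight.size(0)
    module.qkv_proj = qkv_proj

    del module.q_proj, module.k_proj, module.v_proj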
@@ -144,8 +125,7 @@ def internlm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights,
-                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
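attention_softmax, also from ipex_llm.transformers.models.common, replaces the fp32-upcast softmax that each forward carried inline. A minimal sketch of the semantics implied by the removed lines follows; the training flag matches the self.training passed at every call site and presumably lets the real helper choose a fused or in-place path for inference on XPU, which is an assumption, so the sketch only reproduces the plain upcast-softmax-downcast.

import torch


def attention_softmax(attn_weights: torch.Tensor, training: bool = False) -> torch.Tensor:
    # Upcast to fp32 for a numerically stable softmax, then cast back to the
    # input dtype -- the exact pattern the removed lines implemented inline.
    # `training` is accepted for signature compatibility; the real helper is
    # assumed to use it to pick an optimized inference path.
    return torch.nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(attn_weights.dtype)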
File 2 of 6 — Phi attention module:

@@ -34,6 +34,7 @@
 import math
 import torch

+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
@@ -55,26 +56,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def merge_qkv(module: torch.nn.Module):
-    if module.__class__.__name__ == "PhiAttention":
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, "PhiAttention")


 def attention_forward(
@@ -143,8 +125,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
File 3 of 6 — Qwen2-MoE module:

@@ -43,6 +43,7 @@ import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from typing import Optional, Tuple, Union, List
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
@@ -367,26 +368,7 @@ def qwen2_moe_causal_lm_forward(


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2MoeAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Qwen2MoeAttention)


 def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
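Note that this call site passes the Qwen2MoeAttention class object, while InternLM and Phi pass a class-name string, so merge_qkv_base evidently accepts both forms. These per-model merge_qkv wrappers are normally applied by walking a loaded model module by module; a hedged illustration of that wiring follows (the import path and the fuse_qkv helper are illustrative, not taken from this diff).

import torch

# Illustrative import path; the actual module name is not shown in this view.
from ipex_llm.transformers.models.qwen2_moe import merge_qkv


def fuse_qkv(model: torch.nn.Module) -> torch.nn.Module:
    # Module.apply visits every submodule, so each Qwen2MoeAttention instance
    # gets its q/k/v projections merged into a single qkv_proj linear.
    model.apply(merge_qkv)
    return model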
File 4 of 6 — StableLM module:

@@ -45,8 +45,8 @@ from transformers.cache_utils import Cache
 from transformers.models.stablelm.modeling_stablelm import repeat_kv
 from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel

-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -54,29 +54,7 @@ from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, StableLmAttention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-
-        if module.q_proj.bias is not None:
-            qkv_proj = torch.nn.Linear(0, 0, bias=True)
-            new_bias = torch.cat([
-                module.q_proj.bias.data,
-                module.k_proj.bias.data,
-                module.v_proj.bias.data,
-            ], dim=0)
-            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        else:
-            qkv_proj = torch.nn.Linear(0, 0, bias=False)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, StableLmAttention)


 def stablelm_model_forward(
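StableLM was the only copy that handled a missing projection bias, which the shared helper now presumably covers for every model. The fusion itself is lossless: stacking the three weight matrices along the output dimension yields a single linear whose output is the concatenation of q, k and v. A quick, self-contained check of that claim with a stand-in hidden size (not tied to any real attention class):

import torch

# Stand-in for an attention block with separate q/k/v projections (hidden size 8).
hidden = 8
q_proj = torch.nn.Linear(hidden, hidden, bias=True)
k_proj = torch.nn.Linear(hidden, hidden, bias=True)
v_proj = torch.nn.Linear(hidden, hidden, bias=True)

# Fused projection built the same way as the removed merge_qkv bodies.
qkv_proj = torch.nn.Linear(hidden, 3 * hidden, bias=True)
qkv_proj.weight.data = torch.cat(
    [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data], dim=0)
qkv_proj.bias.data = torch.cat(
    [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data], dim=0)

x = torch.randn(2, hidden)
fused = qkv_proj(x)
separate = torch.cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1)
assert torch.allclose(fused, separate, atol=1e-6)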
@@ -197,8 +175,7 @@ def stablelm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = self.attention_dropout(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)
File 5 of 6 — Starcoder2 module:

@@ -40,6 +40,7 @@ import math
 import torch
 import warnings

+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     should_use_fuse_rope, use_sdp, use_sdp_causal
@@ -54,26 +55,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,


 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Starcoder2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
+    merge_qkv_base(module, Starcoder2Attention)


 def attention_forward(
@@ -152,8 +134,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(query_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
File 6 of 6 — Yuan module:

@@ -26,6 +26,7 @@ from typing import Optional, Tuple
 import torch

 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -239,8 +240,7 @@ def yuan_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
     # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
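Every softmax change in this commit makes the same behavioral claim: attention_softmax(attn_weights, self.training) should match the fp32-upcast softmax it replaces. A small check of that equivalence against the removed pattern, using the sketch's semantics rather than the library's actual kernel (which may run fused or in place on XPU):

import torch

attn_weights = torch.randn(1, 4, 16, 16, dtype=torch.float16)

# The pattern removed by this commit: upcast, softmax over the last dim, cast back.
reference = torch.nn.functional.softmax(
    attn_weights, dim=-1, dtype=torch.float32
).to(attn_weights.dtype)

# What attention_softmax is expected to return for the same input.
candidate = torch.nn.functional.softmax(
    attn_weights.float(), dim=-1
).to(attn_weights.dtype)

assert torch.allclose(reference.float(), candidate.float())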