optimize npu llama perf again (#11431)
This commit is contained in:
parent
9f6e5b4fba
commit
ca0e69c3a7
4 changed files with 64 additions and 25 deletions
|
|
@ -116,12 +116,9 @@ class _BaseAutoModelClass:
|
||||||
try:
|
try:
|
||||||
# for intel_npu_acceleration_library >= 1.1.0
|
# for intel_npu_acceleration_library >= 1.1.0
|
||||||
from intel_npu_acceleration_library.quantization import quantize_model
|
from intel_npu_acceleration_library.quantization import quantize_model
|
||||||
from intel_npu_acceleration_library.compiler import (
|
from intel_npu_acceleration_library.compiler import create_npu_kernels
|
||||||
apply_horizontal_fusion, create_npu_kernels
|
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
optimize_llm(model)
|
optimize_llm(model)
|
||||||
apply_horizontal_fusion(model)
|
|
||||||
if not qtype.is_floating_point:
|
if not qtype.is_floating_point:
|
||||||
model = quantize_model(model, qtype)
|
model = quantize_model(model, qtype)
|
||||||
create_npu_kernels(model)
|
create_npu_kernels(model)
|
||||||
|
|
|
||||||
32
python/llm/src/ipex_llm/transformers/npu_models/common.py
Normal file
32
python/llm/src/ipex_llm/transformers/npu_models/common.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
|
||||||
|
new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
|
||||||
|
if linears[0].bias is not None:
|
||||||
|
new_linear = torch.nn.Linear(0, 0, bias=True)
|
||||||
|
new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
|
||||||
|
new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
else:
|
||||||
|
new_linear = torch.nn.Linear(0, 0, bias=False)
|
||||||
|
new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||||
|
new_linear.in_features = new_weight.size(1)
|
||||||
|
new_linear.out_features = new_weight.size(0)
|
||||||
|
return new_linear
|
||||||
|
|
@ -29,6 +29,11 @@ def optimize_llm(model: torch.nn.Module):
|
||||||
if model.config.model_type == "llama":
|
if model.config.model_type == "llama":
|
||||||
from ipex_llm.transformers.npu_models.llama import merge_qkv
|
from ipex_llm.transformers.npu_models.llama import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
from ipex_llm.transformers.npu_models.llama import merge_mlp
|
||||||
|
model.apply(merge_mlp)
|
||||||
from ipex_llm.transformers.npu_models.llama import llama_attention_forward
|
from ipex_llm.transformers.npu_models.llama import llama_attention_forward
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
convert_forward(model, LlamaAttention, llama_attention_forward)
|
convert_forward(model, LlamaAttention, llama_attention_forward)
|
||||||
|
from ipex_llm.transformers.npu_models.llama import llama_mlp_forward
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||||
|
convert_forward(model, LlamaMLP, llama_mlp_forward)
|
||||||
|
|
|
||||||
|
|
@ -36,35 +36,33 @@ from typing import Optional, Tuple
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
|
||||||
|
|
||||||
|
from ipex_llm.transformers.npu_models.common import merge_linear
|
||||||
|
|
||||||
|
|
||||||
def merge_qkv(module: torch.nn.Module):
|
def merge_qkv(module: torch.nn.Module):
|
||||||
if isinstance(module, LlamaAttention):
|
if isinstance(module, LlamaAttention):
|
||||||
new_weight = torch.cat([
|
qkv_proj = merge_linear([
|
||||||
module.q_proj.weight.data,
|
module.q_proj,
|
||||||
module.k_proj.weight.data,
|
module.k_proj,
|
||||||
module.v_proj.weight.data,
|
module.v_proj,
|
||||||
], dim=0)
|
])
|
||||||
|
|
||||||
if module.q_proj.bias is not None:
|
|
||||||
qkv_proj = torch.nn.Linear(0, 0, bias=True)
|
|
||||||
new_bias = torch.cat([
|
|
||||||
module.q_proj.bias.data,
|
|
||||||
module.k_proj.bias.data,
|
|
||||||
module.v_proj.bias.data,
|
|
||||||
], dim=0)
|
|
||||||
qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
|
||||||
else:
|
|
||||||
qkv_proj = torch.nn.Linear(0, 0, bias=False)
|
|
||||||
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
|
||||||
qkv_proj.in_features = new_weight.size(1)
|
|
||||||
qkv_proj.out_features = new_weight.size(0)
|
|
||||||
module.qkv_proj = qkv_proj
|
module.qkv_proj = qkv_proj
|
||||||
|
|
||||||
del module.q_proj, module.k_proj, module.v_proj
|
del module.q_proj, module.k_proj, module.v_proj
|
||||||
|
|
||||||
|
|
||||||
|
def merge_mlp(module: torch.nn.Module):
|
||||||
|
if isinstance(module, LlamaMLP):
|
||||||
|
gate_up_proj = merge_linear([
|
||||||
|
module.gate_proj,
|
||||||
|
module.up_proj,
|
||||||
|
])
|
||||||
|
module.gate_up_proj = gate_up_proj
|
||||||
|
del module.gate_proj, module.up_proj
|
||||||
|
|
||||||
|
|
||||||
def llama_attention_forward(
|
def llama_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
@ -121,3 +119,10 @@ def llama_attention_forward(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def llama_mlp_forward(self, x):
|
||||||
|
gate_up_proj = self.gate_up_proj(x)
|
||||||
|
gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1)
|
||||||
|
down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj)
|
||||||
|
return down_proj
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue