[NPU] Optimize Qwen2 lm_head to use INT4 (#12072)
* temp save
* update
* fix
* fix
* Split lm_head into 7 parts & remove int8 for lm_head when sym_int4
* Simplify and add condition to code
* Small fix
* refactor some code
* fix style
* fix style
* fix style
* fix
* fix
* temp save
* refactor
* fix style
* further refactor
* simplify code
* meet code review
* fix style

Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
parent 18714ceac7
commit 081af41def
3 changed files with 171 additions and 8 deletions
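The change relies on a simple identity: a linear lm_head is a sum of matmuls over column slices of its weight, so the 152064x3584 projection can be evaluated in 7 NPU-sized pieces. A minimal torch sketch of that identity (shrunk shapes, illustration only; the real Qwen2-7B lm_head uses inC=3584, outC=152064):

import torch

inC, outC, split_num = 64, 256, 7
x = torch.randn(1, inC)
w = torch.randn(outC, inC)

full = x @ w.T  # the full projection

# Sum of partial matmuls over column slices, mirroring SlicedLMHead below.
split_size = inC // split_num // 2 * 2  # keep slice widths even
partial = torch.zeros(1, outC)
for i in range(split_num):
    start = i * split_size
    end = (i + 1) * split_size if i < split_num - 1 else inC
    partial += x[:, start:end] @ w[:, start:end].T

assert torch.allclose(full, partial, atol=1e-5)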
@@ -16,7 +16,9 @@
 import os
 import torch
 import importlib
+import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
 
 
 def convert_forward(m, target_m, new_forward):
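The convert_forward helper whose signature appears above follows the usual module-tree rebinding pattern. A minimal sketch of that pattern (an illustration, not necessarily the exact ipex_llm implementation):

import types
import torch

def convert_forward(m: torch.nn.Module, target_m: type, new_forward) -> None:
    # Rebind `forward` on every instance of target_m in the module tree.
    if isinstance(m, target_m):
        m.forward = types.MethodType(new_forward, m)
    for _, sub_m in m.named_children():
        convert_forward(sub_m, target_m, new_forward)

# Example: swap in a logging forward for every nn.Linear.
def noisy_forward(self, x):
    print("linear", tuple(self.weight.shape))
    return torch.nn.functional.linear(x, self.weight, self.bias)

mlp = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
convert_forward(mlp, torch.nn.Linear, noisy_forward)
mlp(torch.randn(1, 4))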
@@ -85,9 +87,16 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
 
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
         model.apply(split_mlp_down_proj)
 
+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                not cpu_lm_head:
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+                                       bias=model.lm_head.bias)
+            del model.lm_head
+            model.lm_head = new_lm_head
+
     # lm_head to cpu optimization
     if cpu_lm_head:
         # disable the optimization by default
@@ -182,6 +191,11 @@ def optimize_llm(
         from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
+
+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                isinstance(model.lm_head, SlicedLMHead):
+            model.lm_head.get_fused_lm_head()
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
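The commit does not say why 7 parts specifically; the NPU-side size limit that motivated the number is an assumption, not stated in the source. The arithmetic below only illustrates the per-slice weight footprint the split produces for Qwen2-7B's lm_head:

# Per-slice footprint for Qwen2-7B (hidden_size=3584, vocab_size=152064).
inC, outC, split_num = 3584, 152064, 7
split_size = inC // split_num // 2 * 2           # 512 columns per slice, kept even

def mib(n_params, bits):
    return n_params * bits / 8 / 2**20

print(f"full lm_head, INT4: {mib(inC * outC, 4):.1f} MiB")         # ~259.9 MiB
print(f"one slice,    INT4: {mib(split_size * outC, 4):.1f} MiB")  # ~37.1 MiB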
@@ -22,16 +22,14 @@
 #
 
 from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
-from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
-from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
-from typing import Optional, Union
 import os
 import torch
 from torch.nn import Parameter
 import uuid
 import math
+from intel_npu_acceleration_library.backend import run_matmul
+from typing import Optional, Union
 from ipex_llm.utils.common import invalidInputError
 
 
-
@@ -52,7 +50,6 @@ class Linear(torch.nn.Module):
         self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
         self.outC, self.inC = self.weight.shape
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -147,7 +144,7 @@ class QuantizedLinear(torch.nn.Module):
         """
         super().__init__()
 
-        self.weight = Parameter(weight, requires_grad=False)
+        self.weight = Parameter(weight, requires_grad=False).contiguous()
         if self.weight.dtype not in (torch.int8, torch.uint8):
             invalidInputError(
                 False,
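One easy-to-miss change above: the quantized weight is made contiguous. A transposed or sliced tensor is a strided view over its base storage, and a backend that copies raw buffers (as the NPU runtime presumably does here) wants a dense row-major layout; calling .contiguous() materializes one. A torch-only illustration:

import torch

w = torch.randn(4, 8)
v = w.t()                  # a strided view: same storage, swapped strides
print(v.is_contiguous())   # False
c = v.contiguous()         # dense row-major copy, safe to hand to a C buffer
print(c.is_contiguous())   # True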
@@ -163,7 +160,6 @@ class QuantizedLinear(torch.nn.Module):
         self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -194,6 +190,7 @@ class QuantizedLinear(torch.nn.Module):
                     "Use `.eval()` to do inference only"
                 )
             )
+
         out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
 
         if self.bias is None:
python/llm/src/ipex_llm/transformers/npu_models/lm_head.py (new file, 152 additions)
@@ -0,0 +1,152 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import nn
+import numpy as np
+from intel_npu_acceleration_library.backend import NNFactory
+from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
+
+
+class LMHeadLinear(NNFactory):
+    """Quantized Linear class for the sliced lm_head, computing a matrix-matrix
+    multiplication with weight prefetching."""
+
+    def __init__(
+        self,
+        inC: int,
+        outC: int,
+        batch: int,
+        split_num: int = 2,
+        profile: bool = False,
+        device: str = "NPU",
+        dtype: np.dtype = np.int8,
+    ):
+        """Initialize the LMHeadLinear class.
+
+        Args:
+            inC (int): input channels
+            outC (int): output channels
+            batch (int): batch
+            split_num (int): number of parts to split lm_head's in_features into
+            profile (bool): Enable/Disable profiling. Defaults to False.
+            device (str): Target device, defaults to "NPU".
+            dtype (np.dtype): weights datatype. Defaults to np.int8.
+
+        """
+        super().__init__(profile, device)
+        self.inC, self.outC = inC, outC
+        self.batch = batch
+
+        input = self.parameter((self.batch, self.inC))
+
+        self.split_num = split_num
+        split_size = self.inC // split_num // 2 * 2
+
+        for i in range(self.split_num):
+            start_idx = i * split_size
+            end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
+            input_slice = self.slice(input, begin=[0, start_idx],
+                                     end=[self.batch, end_idx])
+            linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
+            if i == 0:
+                res = linear_slice
+            else:
+                res += linear_slice
+
+        print("start compiling lm_head")
+        self.compile()
+        print("end compiling lm_head")
+
+    def run(self, X: np.ndarray) -> np.ndarray:
+        """Run the layer: $X * (W * S)^T$ .
+
+        Args:
+            X (np.ndarray): activation
+
+        Raises:
+            RuntimeError: Input, weights or scale shape mismatch
+
+        Returns:
+            np.ndarray: result
+        """
+        self.prefetchWeights(1, verify_size=False)
+        self.set_input_tensor(X, 0)
+        self.elapsed = backend_lib.run(self._mm)
+        if len(self.out) == 1:
+            return self.out[0]
+        return self.out
+
+
+class SlicedLMHead(nn.Module):
+    def __init__(self, weight, bias, split_num):
+        super().__init__()
+        self.split_num = split_num
+        self.outC, self.inC = weight.shape
+        split_size = weight.size(1) // split_num // 2 * 2
+        self.lm_heads = nn.Sequential()
+        for i in range(split_num):
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+            start_idx = i * split_size
+            end_idx = (i + 1) * split_size if i < split_num - 1 else weight.size(1)
+            new_weight = torch.nn.Parameter(weight[:, start_idx:end_idx],
+                                            requires_grad=False)
+            new_linear.weight = new_weight
+            new_linear.in_features = new_weight.size(1)
+            new_linear.out_features = new_weight.size(0)
+            self.lm_heads.append(new_linear)
+        self.bias = bias
+
+    def forward(self, hidden_states):
+        if hidden_states.size(0) * hidden_states.size(1) == 1:
+            original_shape = hidden_states.shape
+            x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+            target_shape = tuple(list(original_shape[:-1]) + [self.outC])
+
+            out = self.fused_lm_head.run(x_2d.numpy())
+            logits = torch.from_numpy(out)
+            logits = logits.view(target_shape)
+        else:
+            split_size = hidden_states.size(-1) // self.split_num // 2 * 2
+            logits = None
+            for i in range(self.split_num):
+                start_idx = i * split_size
+                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
+                hidden_states_slice = hidden_states[:, :, start_idx:end_idx]
+                logits_slice = self.lm_heads[i](hidden_states_slice)
+                if logits is None:
+                    logits = logits_slice
+                else:
+                    logits += logits_slice
+
+        if self.bias is None:
+            return logits
+        return logits + self.bias
+
+    def get_weight_dtype(self):
+        return self.lm_heads[0].weight.dtype
+
+    def get_fused_lm_head(self):
+        np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
+        self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
+                                          False, "NPU", dtype=np_dtype)
+        fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
+                                  self.lm_heads[i].scale.data.numpy())
+                                 for i in range(self.split_num)]
+        self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
+                                      *fused_lm_head_weights)
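A hedged usage sketch of SlicedLMHead's CPU (prefill) path, with shrunk stand-in shapes for Qwen2-7B's (152064, 3584). The fused NPU path is left commented out because it needs the Intel NPU runtime and quantized slices that carry a scale, as QuantizedLinear's do:

import torch
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead

weight = torch.randn(1000, 70)                 # (vocab, hidden) stand-in
lm_head = SlicedLMHead(weight, bias=None, split_num=7)

hidden = torch.randn(1, 4, 70)                 # batch=1, seq=4 -> prefill path
logits = lm_head(hidden)                       # sum of 7 sliced matmuls
ref = hidden @ weight.T
assert torch.allclose(logits, ref, atol=1e-4)

# For single-token decode (batch * seq == 1), forward dispatches to the fused
# NPU graph instead; compile it once after the slices are quantized:
# lm_head.get_fused_lm_head()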