[NPU] Optimize Qwen2 lm_head to use INT4 (#12072)

* temp save

* update

* fix

* fix

* Split lm_head into 7 parts & remove int8 for lm_head when sym_int4

* Simplify and add condition to code

* Small fix

* refactor some code

* fix style

* fix style

* fix style

* fix

* fix

* temp save

* refactor

* fix style

* further refactor

* simplify code

* meet code review

* fix style

---------

Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
Ruonan Wang 2024-09-14 00:26:46 -07:00 committed by GitHub
parent 18714ceac7
commit 081af41def
3 changed files with 171 additions and 8 deletions
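
Note: the core of this change is to decompose the Qwen2-7B lm_head matmul (in_features=3584, out_features=152064) along its input dimension into 7 column blocks whose partial products are summed, presumably to keep each NPU linear op's weight buffer small. A quick sanity check of the split arithmetic used throughout the diff (plain Python; values taken from the conditions below):

# Worked check of the 7-way split; `// 2 * 2` keeps each slice size even,
# and the last slice absorbs any remainder.
inC, outC, split_num = 3584, 152064, 7
split_size = inC // split_num // 2 * 2            # 3584 // 7 // 2 * 2 = 512
last_size = inC - split_size * (split_num - 1)    # 512, so the split is exactly even
assert split_size * (split_num - 1) + last_size == inC

# x @ W.T decomposes into a sum over column blocks of W:
#   logits = sum_i x[:, s_i:e_i] @ W[:, s_i:e_i].T
# which is what SlicedLMHead.forward computes below.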


@@ -16,7 +16,9 @@
 import os
 import torch
 import importlib
+import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead


 def convert_forward(m, target_m, new_forward):
@@ -85,9 +87,16 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
         model.apply(split_mlp_down_proj)

+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                not cpu_lm_head:
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+                                       bias=model.lm_head.bias)
+            del model.lm_head
+            model.lm_head = new_lm_head
+
     # lm_head to cpu optimization
     if cpu_lm_head:
         # disable the optimization by default
@@ -182,6 +191,11 @@ def optimize_llm(
         from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
+
+        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                isinstance(model.lm_head, SlicedLMHead):
+            model.lm_head.get_fused_lm_head()
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
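
For orientation, the lm_head optimization runs in two phases: optimize_llm_pre swaps in SlicedLMHead while the weights are still floating point, the normal low-bit conversion then quantizes each slice (producing per-slice weights and scales), and only afterwards does optimize_llm call get_fused_lm_head() to build the single NPU graph. A minimal sketch of that ordering; the driver function and quantize_fn are illustrative stand-ins, not the library's API:

from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead

def optimize_lm_head_sketch(model, quantize_fn, cpu_lm_head=False):
    # Phase 1: replace lm_head before quantization (mirrors optimize_llm_pre).
    cfg = model.config
    if cfg.hidden_size == 3584 and cfg.vocab_size == 152064 and not cpu_lm_head:
        model.lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
                                     bias=model.lm_head.bias)

    # Phase 2: quantize, then fuse (mirrors optimize_llm). Fusing must come
    # after quantization because get_fused_lm_head reads each slice's
    # quantized weight and scale when calling setWeights.
    quantize_fn(model)  # hypothetical stand-in for the ipex-llm quantization pass
    if isinstance(model.lm_head, SlicedLMHead):
        model.lm_head.get_fused_lm_head()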


@@ -22,16 +22,14 @@
 #
 from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
-from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
-from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
-from typing import Optional, Union
 import os
 import torch
 from torch.nn import Parameter
 import uuid
 import math
+from intel_npu_acceleration_library.backend import run_matmul
+from typing import Optional, Union

 from ipex_llm.utils.common import invalidInputError
@@ -52,7 +50,6 @@ class Linear(torch.nn.Module):
         self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
         self.outC, self.inC = self.weight.shape
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -147,7 +144,7 @@ class QuantizedLinear(torch.nn.Module):
         """
         super().__init__()
-        self.weight = Parameter(weight, requires_grad=False)
+        self.weight = Parameter(weight, requires_grad=False).contiguous()
         if self.weight.dtype not in (torch.int8, torch.uint8):
             invalidInputError(
                 False,
@@ -163,7 +160,6 @@ class QuantizedLinear(torch.nn.Module):
         self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
         self.op_id = str(uuid.uuid4())
-        self._mm = AutogradMatMul.apply

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Torch module forward method.
@@ -194,6 +190,7 @@ class QuantizedLinear(torch.nn.Module):
                 "Use `.eval()` to do inference only"
             )
         )
+
         out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)

         if self.bias is None:
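
Two things worth noting in this file: the AutogradMatMul hooks are dropped because these layers are inference-only (forward already raises in training mode, per the "Use `.eval()`" error above), and the new .contiguous() likely guards against non-contiguous weights, which now appear because SlicedLMHead builds each part from a column slice of the original weight. A small illustration of why a column slice needs the copy (toy sizes, not from the diff):

import torch

w = torch.zeros(8, 8, dtype=torch.int8)
col_block = w[:, 0:4]                            # column slice -> non-contiguous view
assert not col_block.is_contiguous()
assert col_block.contiguous().is_contiguous()    # .contiguous() makes a dense copy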


@@ -0,0 +1,152 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import numpy as np
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib


class LMHeadLinear(NNFactory):
    """Quantized Linear class for the sliced lm_head, computing a matrix-matrix
    multiplication with weights prefetching."""

    def __init__(
        self,
        inC: int,
        outC: int,
        batch: int,
        split_num: int = 2,
        profile: bool = False,
        device: str = "NPU",
        dtype: np.dtype = np.int8,
    ):
        """Initialize the LMHeadLinear class.

        Args:
            inC (int): input channels
            outC (int): output channels
            batch (int): batch
            split_num (int): number of parts to split lm_head's in_features into
            profile (bool): Enable/Disable profiling. Defaults to False.
            device (str): Target device. Defaults to "NPU".
            dtype (np.dtype): weights datatype. Defaults to np.int8.
        """
        super().__init__(profile, device)
        self.inC, self.outC = inC, outC
        self.batch = batch

        input = self.parameter((self.batch, self.inC))

        self.split_num = split_num
        split_size = self.inC // split_num // 2 * 2

        for i in range(self.split_num):
            start_idx = i * split_size
            end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
            input_slice = self.slice(input, begin=[0, start_idx],
                                     end=[self.batch, end_idx])
            linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
            if i == 0:
                res = linear_slice
            else:
                res += linear_slice

        print("start compiling lm_head")
        self.compile()
        print("end compiling lm_head")

    def run(self, X: np.ndarray) -> np.ndarray:
        """Run the layer: $X * (W * S)^T$.

        Args:
            X (np.ndarray): activation

        Raises:
            RuntimeError: Input, weights or scale shape mismatch

        Returns:
            np.ndarray: result
        """
        self.prefetchWeights(1, verify_size=False)

        self.set_input_tensor(X, 0)
        self.elapsed = backend_lib.run(self._mm)
        if len(self.out) == 1:
            return self.out[0]
        return self.out
class SlicedLMHead(nn.Module):
    def __init__(self, weight, bias, split_num):
        super().__init__()
        self.split_num = split_num
        self.outC, self.inC = weight.shape
        split_size = weight.size(1) // split_num // 2 * 2
        self.lm_heads = nn.Sequential()
        for i in range(split_num):
            new_linear = torch.nn.Linear(0, 0, bias=False)
            start_idx = i * split_size
            end_idx = (i + 1) * split_size if i < split_num - 1 else weight.size(1)
            new_weight = torch.nn.Parameter(weight[:, start_idx:end_idx],
                                            requires_grad=False)
            new_linear.weight = new_weight
            new_linear.in_features = new_weight.size(1)
            new_linear.out_features = new_weight.size(0)
            self.lm_heads.append(new_linear)
        self.bias = bias

    def forward(self, hidden_states):
        if hidden_states.size(0) * hidden_states.size(1) == 1:
            original_shape = hidden_states.shape
            x_2d = hidden_states.view(-1, hidden_states.shape[-1])
            target_shape = tuple(list(original_shape[:-1]) + [self.outC])

            out = self.fused_lm_head.run(x_2d.numpy())
            logits = torch.from_numpy(out)
            logits = logits.view(target_shape)
        else:
            split_size = hidden_states.size(-1) // self.split_num // 2 * 2
            logits = None
            for i in range(self.split_num):
                start_idx = i * split_size
                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
                hidden_states_slice = hidden_states[:, :, start_idx:end_idx]
                logits_slice = self.lm_heads[i](hidden_states_slice)
                if logits is None:
                    logits = logits_slice
                else:
                    logits += logits_slice
        if self.bias is None:
            return logits
        return logits + self.bias

    def get_weight_dtype(self):
        return self.lm_heads[0].weight.dtype

    def get_fused_lm_head(self):
        np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
        self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
                                          False, "NPU", dtype=np_dtype)
        fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
                                  self.lm_heads[i].scale.data.numpy())
                                 for i in range(self.split_num)]
        self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
                                      *fused_lm_head_weights)
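
For reference, a minimal CPU-only sketch exercising SlicedLMHead's non-fused path (the fused LMHeadLinear path additionally needs the slices quantized, so that per-slice weights and scales exist, plus an NPU). Toy dimensions are assumptions; the real Qwen2-7B case is 3584 x 152064:

import torch
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead

inC, outC = 896, 64                     # toy sizes; 896 splits evenly into 7 x 128
ref = torch.nn.Linear(inC, outC, bias=False)
sliced = SlicedLMHead(ref.weight.data, bias=None, split_num=7)

x = torch.randn(1, 4, inC)              # (batch, seq, hidden); seq > 1 stays off the fused path
assert torch.allclose(sliced(x), ref(x), atol=1e-5)

The sum of the per-slice partial products matches the full matmul up to floating-point rounding, which is why the decode-time fused path and the prefill-time sliced path can share the same weights.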