[NPU] Support save npu quantized model without npu dependency (#12647)
* support save awq * load quantized model & save npu compiled model * fix style * update * fix dll load issue * update error message * fix style
This commit is contained in:
parent
502461d836
commit
fae73eee79
5 changed files with 203 additions and 144 deletions
|
|
@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
from ipex_llm.transformers.utils import logger, load_imatrix_data
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm
|
||||
|
||||
|
||||
def patch_flash_attn_import(filename: str) -> List[str]:
|
||||
|
|
@ -207,8 +207,6 @@ class _BaseAutoModelClass:
|
|||
model = model.eval()
|
||||
logger.info(f"Finish to convert model")
|
||||
else:
|
||||
from intel_npu_acceleration_library.compiler import create_npu_kernels
|
||||
|
||||
if optimize_model:
|
||||
invalidInputError(
|
||||
max_prompt_len < max_context_len,
|
||||
|
|
@ -232,11 +230,14 @@ class _BaseAutoModelClass:
|
|||
"convert_model": convert_model,
|
||||
"save_directory": save_directory,
|
||||
"fuse_layers": fuse_layers,
|
||||
"imatrix_data": imatrix_data
|
||||
"imatrix_data": imatrix_data,
|
||||
"skip_npu_logic": mock_device == "dummy",
|
||||
}
|
||||
# Dummy will skip npu related logic and save the quantized model
|
||||
if mock_device == "dummy":
|
||||
model.save_low_bit = types.MethodType(save_low_bit, model)
|
||||
model = cls.optimize_npu_model(*args, **optimize_kwargs)
|
||||
else:
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm
|
||||
optimize_llm(model)
|
||||
with torch.no_grad():
|
||||
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
|
||||
|
|
@ -258,7 +259,6 @@ class _BaseAutoModelClass:
|
|||
def optimize_npu_model(cls, *args, **kwargs):
|
||||
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre, optimize_llm
|
||||
from intel_npu_acceleration_library.compiler import create_npu_kernels
|
||||
|
||||
model = kwargs.pop("model")
|
||||
qtype = kwargs.pop("qtype", "sym_int4_rtn")
|
||||
|
|
@ -275,6 +275,7 @@ class _BaseAutoModelClass:
|
|||
save_directory = kwargs.pop('save_directory', None)
|
||||
fuse_layers = kwargs.pop('fuse_layers', None)
|
||||
imatrix_data = kwargs.pop('imatrix_data', None)
|
||||
skip_npu_logic = kwargs.pop("skip_npu_logic", False)
|
||||
invalidInputError(save_directory is not None,
|
||||
"Please provide the path to save converted model "
|
||||
"through `save_directory`.")
|
||||
|
|
@ -294,51 +295,58 @@ class _BaseAutoModelClass:
|
|||
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
|
||||
quantization_group_size, imatrix_data,
|
||||
*args, **kwargs)
|
||||
create_npu_kernels(llm)
|
||||
if not skip_npu_logic:
|
||||
from intel_npu_acceleration_library.compiler import create_npu_kernels
|
||||
create_npu_kernels(llm)
|
||||
model = model.eval()
|
||||
logger.info(f"Finish to convert model")
|
||||
model.config.update({"bigdl_transformers_low_bit": qtype})
|
||||
model.share_memory()
|
||||
|
||||
if not pipeline:
|
||||
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
|
||||
optimize_llm_single_process(
|
||||
llm,
|
||||
kv_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size,
|
||||
qtype=qtype,
|
||||
save_directory=save_directory,
|
||||
fuse_layers=fuse_layers,
|
||||
has_llm=hasattr(model, "llm")
|
||||
)
|
||||
else:
|
||||
optimize_llm(
|
||||
llm,
|
||||
max_context_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
inter_pp=inter_pp,
|
||||
intra_pp=intra_pp,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size
|
||||
)
|
||||
if skip_npu_logic:
|
||||
model.save_low_bit(model_dir=save_directory)
|
||||
else:
|
||||
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
|
||||
import convert_llm
|
||||
convert_llm(llm,
|
||||
model.share_memory()
|
||||
|
||||
if not pipeline:
|
||||
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
|
||||
optimize_llm_single_process(
|
||||
llm,
|
||||
kv_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size,
|
||||
qtype=qtype,
|
||||
convert_model=convert_model,
|
||||
save_directory=save_directory,
|
||||
fuse_layers=fuse_layers)
|
||||
model.save_low_bit = types.MethodType(save_low_bit, model)
|
||||
model.save_low_bit(save_directory)
|
||||
logger.info(f"Converted model has already saved to {save_directory}.")
|
||||
fuse_layers=fuse_layers,
|
||||
has_llm=hasattr(model, "llm")
|
||||
)
|
||||
else:
|
||||
optimize_llm(
|
||||
llm,
|
||||
max_context_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
inter_pp=inter_pp,
|
||||
intra_pp=intra_pp,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size
|
||||
)
|
||||
else:
|
||||
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
|
||||
import convert_llm
|
||||
convert_llm(llm,
|
||||
kv_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size,
|
||||
qtype=qtype,
|
||||
convert_model=convert_model,
|
||||
save_directory=save_directory,
|
||||
fuse_layers=fuse_layers)
|
||||
model.save_low_bit = types.MethodType(save_low_bit, model)
|
||||
model.save_low_bit(save_directory)
|
||||
logger.info(f"Converted model has already saved to {save_directory}.")
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
|
|
@ -379,6 +387,7 @@ class _BaseAutoModelClass:
|
|||
intra_pp = kwargs.pop("intra_pp", None)
|
||||
transpose_value_cache = kwargs.pop("transpose_value_cache", True)
|
||||
modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
|
||||
save_directory = kwargs.pop('save_directory', None)
|
||||
|
||||
from transformers.models.auto.configuration_auto import AutoConfig
|
||||
from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
|
||||
|
|
@ -650,16 +659,37 @@ class _BaseAutoModelClass:
|
|||
param.requires_grad_(False)
|
||||
|
||||
if optimize_model and not pipeline:
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
|
||||
optimize_llm(
|
||||
llm,
|
||||
max_context_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
inter_pp=inter_pp,
|
||||
intra_pp=intra_pp,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size
|
||||
)
|
||||
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
|
||||
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
|
||||
if save_directory is None:
|
||||
invalidInputError(False,
|
||||
"Please specify the save_directory, the path of folder " +
|
||||
"to save the compiled NPU model. If path not exists, " +
|
||||
"the compiled NPU model will be saved there. " +
|
||||
"Else, program will exit.")
|
||||
|
||||
optimize_llm_single_process(
|
||||
llm,
|
||||
kv_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size,
|
||||
qtype=qtype,
|
||||
save_directory=save_directory,
|
||||
fuse_layers=None,
|
||||
has_llm=hasattr(model, "llm")
|
||||
)
|
||||
else:
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
|
||||
optimize_llm(
|
||||
llm,
|
||||
max_context_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
inter_pp=inter_pp,
|
||||
intra_pp=intra_pp,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
group_size=quantization_group_size
|
||||
)
|
||||
elif optimize_model and pipeline:
|
||||
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
|
||||
import convert_llm
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import torch
|
|||
import importlib
|
||||
import numpy as np
|
||||
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
|
||||
from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
|
||||
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,16 +21,25 @@
|
|||
# SPDX-License-Identifier: Apache 2.0
|
||||
#
|
||||
|
||||
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
|
||||
from intel_npu_acceleration_library.dtypes import NPUDtype
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
import uuid
|
||||
import math
|
||||
from intel_npu_acceleration_library.backend import run_matmul
|
||||
from typing import Optional, Union
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
import importlib
|
||||
|
||||
|
||||
def is_acclib_available():
|
||||
return importlib.util.find_spec("intel_npu_acceleration_library") is not None
|
||||
|
||||
|
||||
if is_acclib_available():
|
||||
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
|
||||
from intel_npu_acceleration_library.dtypes import NPUDtype
|
||||
from intel_npu_acceleration_library.backend import run_matmul
|
||||
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
|
|
@ -63,6 +72,7 @@ class Linear(torch.nn.Module):
|
|||
if self.training:
|
||||
out = self._mm(x, self.weight, None)
|
||||
else:
|
||||
from intel_npu_acceleration_library.backend import run_matmul
|
||||
out = run_matmul(x, self.weight, None, self.op_id)
|
||||
|
||||
if self.bias is None:
|
||||
|
|
@ -105,6 +115,8 @@ class Linear(torch.nn.Module):
|
|||
Returns:
|
||||
Union[Linear, QuantizedLinear]: A NPU linear layer
|
||||
"""
|
||||
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
|
||||
from intel_npu_acceleration_library.dtypes import NPUDtype
|
||||
if dtype.is_floating_point:
|
||||
if bias is None:
|
||||
return Linear(weight.to(dtype), None)
|
||||
|
|
|
|||
|
|
@ -16,96 +16,6 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from filelock import FileLock
|
||||
from intel_npu_acceleration_library.backend import NNFactory
|
||||
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
|
||||
|
||||
|
||||
class LMHeadLinear(NNFactory):
|
||||
"""Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
|
||||
with weights prefetching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inC: int,
|
||||
outC: int,
|
||||
batch: int,
|
||||
split_num: int = 2,
|
||||
profile: bool = False,
|
||||
device: str = "NPU",
|
||||
dtype: np.dtype = np.int8,
|
||||
use_split: bool = False,
|
||||
group_size: int = 0,
|
||||
asym: bool = False,
|
||||
):
|
||||
"""Initialize the LMHeadLinear class.
|
||||
|
||||
Args:
|
||||
inC (int): input channels
|
||||
outC (int): output channels
|
||||
batch (int): batch
|
||||
split_num (int): split in_features of lm_head to how many parts
|
||||
profile (bool): Enable/Disable profiling. Defaults to False.
|
||||
device (str): Target device, default to "NPU".
|
||||
dtype (np.dtype): weights datatype. Defaults to np.int8.
|
||||
|
||||
"""
|
||||
super().__init__(profile, device)
|
||||
self.inC, self.outC = inC, outC
|
||||
self.batch = batch
|
||||
|
||||
self.split_num = split_num
|
||||
if use_split:
|
||||
input = self.parameter((1, self.batch, self.inC))
|
||||
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
|
||||
scale_factor=(group_size == 0), asym=asym)
|
||||
else:
|
||||
input = self.parameter((self.batch, self.inC))
|
||||
split_size = self.inC // split_num // 2 * 2
|
||||
|
||||
for i in range(self.split_num):
|
||||
start_idx = i * split_size
|
||||
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
|
||||
input_slice = self.slice(input, begin=[0, start_idx],
|
||||
end=[self.batch, end_idx])
|
||||
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
|
||||
wt_dtype=dtype, asym=asym)
|
||||
if i == 0:
|
||||
res = linear_slice
|
||||
else:
|
||||
res += linear_slice
|
||||
|
||||
print("start compiling lm_head")
|
||||
self.compile()
|
||||
print("end compiling lm_head")
|
||||
|
||||
def set_weights(self, op_id, weights):
|
||||
self.set_weights_async(op_id, weights)
|
||||
with FileLock(f"lmhead_run.lock"):
|
||||
backend_lib.run(self._mm)
|
||||
|
||||
def set_weights_async(self, op_id, weights):
|
||||
self.setWeights(1, op_id, *weights)
|
||||
|
||||
def run(
|
||||
self, X: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""Run the layer: $X * (W * S)^T$ .
|
||||
|
||||
Args:
|
||||
X (np.ndarray): activation
|
||||
|
||||
Raises:
|
||||
RuntimeError: Input, weights or scale shape mismatch
|
||||
|
||||
Returns:
|
||||
np.ndarray: result
|
||||
"""
|
||||
self.set_input_tensor(X, 0)
|
||||
self.elapsed = backend_lib.run(self._mm)
|
||||
if len(self.out) == 1:
|
||||
return self.out[0]
|
||||
return self.out
|
||||
|
||||
|
||||
class SlicedLMHead(nn.Module):
|
||||
|
|
@ -160,6 +70,7 @@ class SlicedLMHead(nn.Module):
|
|||
return self.lm_heads[0].weight.dtype
|
||||
|
||||
def get_fused_lm_head(self):
|
||||
from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear
|
||||
np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
|
||||
self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
|
||||
False, "NPU", dtype=np_dtype, use_split=self.use_split,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from filelock import FileLock
|
||||
from intel_npu_acceleration_library.backend import NNFactory
|
||||
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
|
||||
|
||||
|
||||
class LMHeadLinear(NNFactory):
|
||||
"""Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
|
||||
with weights prefetching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inC: int,
|
||||
outC: int,
|
||||
batch: int,
|
||||
split_num: int = 2,
|
||||
profile: bool = False,
|
||||
device: str = "NPU",
|
||||
dtype: np.dtype = np.int8,
|
||||
use_split: bool = False,
|
||||
group_size: int = 0,
|
||||
asym: bool = False,
|
||||
):
|
||||
"""Initialize the LMHeadLinear class.
|
||||
|
||||
Args:
|
||||
inC (int): input channels
|
||||
outC (int): output channels
|
||||
batch (int): batch
|
||||
split_num (int): split in_features of lm_head to how many parts
|
||||
profile (bool): Enable/Disable profiling. Defaults to False.
|
||||
device (str): Target device, default to "NPU".
|
||||
dtype (np.dtype): weights datatype. Defaults to np.int8.
|
||||
|
||||
"""
|
||||
super().__init__(profile, device)
|
||||
self.inC, self.outC = inC, outC
|
||||
self.batch = batch
|
||||
|
||||
self.split_num = split_num
|
||||
if use_split:
|
||||
input = self.parameter((1, self.batch, self.inC))
|
||||
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
|
||||
scale_factor=(group_size == 0), asym=asym)
|
||||
else:
|
||||
input = self.parameter((self.batch, self.inC))
|
||||
split_size = self.inC // split_num // 2 * 2
|
||||
|
||||
for i in range(self.split_num):
|
||||
start_idx = i * split_size
|
||||
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
|
||||
input_slice = self.slice(input, begin=[0, start_idx],
|
||||
end=[self.batch, end_idx])
|
||||
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
|
||||
wt_dtype=dtype, asym=asym)
|
||||
if i == 0:
|
||||
res = linear_slice
|
||||
else:
|
||||
res += linear_slice
|
||||
|
||||
print("start compiling lm_head")
|
||||
self.compile()
|
||||
print("end compiling lm_head")
|
||||
|
||||
def set_weights(self, op_id, weights):
|
||||
self.set_weights_async(op_id, weights)
|
||||
with FileLock(f"lmhead_run.lock"):
|
||||
backend_lib.run(self._mm)
|
||||
|
||||
def set_weights_async(self, op_id, weights):
|
||||
self.setWeights(1, op_id, *weights)
|
||||
|
||||
def run(
|
||||
self, X: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""Run the layer: $X * (W * S)^T$ .
|
||||
|
||||
Args:
|
||||
X (np.ndarray): activation
|
||||
|
||||
Raises:
|
||||
RuntimeError: Input, weights or scale shape mismatch
|
||||
|
||||
Returns:
|
||||
np.ndarray: result
|
||||
"""
|
||||
self.set_input_tensor(X, 0)
|
||||
self.elapsed = backend_lib.run(self._mm)
|
||||
if len(self.out) == 1:
|
||||
return self.out[0]
|
||||
return self.out
|
||||
Loading…
Reference in a new issue