[NPU] Support save npu quantized model without npu dependency (#12647)

* support save awq

* load quantized model & save npu compiled model

* fix style

* update

* fix dll load issue

* update error message

* fix style

Author: Yina Chen
Date: 2025-01-06 12:06:22 +02:00 (committed by GitHub)
Commit: fae73eee79
Parent: 502461d836
5 changed files with 203 additions and 144 deletions
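
In short, the change allows the quantized model to be produced and saved on a machine without the NPU stack installed, and compiled for the NPU in a separate step. Below is a minimal sketch of that two-stage flow; it assumes the `mock_device`, `save_directory`, and related keyword arguments surfaced in the diff are exposed through `from_pretrained` / `load_low_bit`, and the model id, paths, and sizes are illustrative, not the verified public API.

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Stage 1 -- no NPU runtime required: quantize and save the low-bit model only.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",        # hypothetical model id
    load_in_low_bit="sym_int4_rtn",
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
    mock_device="dummy",                    # skip NPU-specific logic, per this change
    save_directory="./llama2-7b-sym-int4",  # quantized weights are saved here
)

# Stage 2 -- on a machine with the NPU stack: load the quantized model and
# save the compiled NPU model to a new save_directory.
model = AutoModelForCausalLM.load_low_bit(
    "./llama2-7b-sym-int4",
    max_context_len=1024,
    max_prompt_len=512,
    save_directory="./llama2-7b-npu-compiled",
)
```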


@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger, load_imatrix_data
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
from ipex_llm.transformers.npu_models.convert import optimize_llm
def patch_flash_attn_import(filename: str) -> List[str]:
@@ -207,8 +207,6 @@ class _BaseAutoModelClass:
model = model.eval()
logger.info(f"Finish to convert model")
else:
from intel_npu_acceleration_library.compiler import create_npu_kernels
if optimize_model:
invalidInputError(
max_prompt_len < max_context_len,
@@ -232,11 +230,14 @@ class _BaseAutoModelClass:
"convert_model": convert_model,
"save_directory": save_directory,
"fuse_layers": fuse_layers,
"imatrix_data": imatrix_data
"imatrix_data": imatrix_data,
"skip_npu_logic": mock_device == "dummy",
}
# Dummy will skip npu related logic and save the quantized model
if mock_device == "dummy":
model.save_low_bit = types.MethodType(save_low_bit, model)
model = cls.optimize_npu_model(*args, **optimize_kwargs)
else:
from ipex_llm.transformers.npu_models.convert import optimize_llm
optimize_llm(model)
with torch.no_grad():
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
@@ -258,7 +259,6 @@ class _BaseAutoModelClass:
def optimize_npu_model(cls, *args, **kwargs):
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre, optimize_llm
from intel_npu_acceleration_library.compiler import create_npu_kernels
model = kwargs.pop("model")
qtype = kwargs.pop("qtype", "sym_int4_rtn")
@@ -275,6 +275,7 @@ class _BaseAutoModelClass:
save_directory = kwargs.pop('save_directory', None)
fuse_layers = kwargs.pop('fuse_layers', None)
imatrix_data = kwargs.pop('imatrix_data', None)
skip_npu_logic = kwargs.pop("skip_npu_logic", False)
invalidInputError(save_directory is not None,
"Please provide the path to save converted model "
"through `save_directory`.")
@@ -294,51 +295,58 @@ class _BaseAutoModelClass:
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
quantization_group_size, imatrix_data,
*args, **kwargs)
create_npu_kernels(llm)
if not skip_npu_logic:
from intel_npu_acceleration_library.compiler import create_npu_kernels
create_npu_kernels(llm)
model = model.eval()
logger.info(f"Finish to convert model")
model.config.update({"bigdl_transformers_low_bit": qtype})
model.share_memory()
if not pipeline:
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
optimize_llm_single_process(
llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
qtype=qtype,
save_directory=save_directory,
fuse_layers=fuse_layers,
has_llm=hasattr(model, "llm")
)
else:
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
if skip_npu_logic:
model.save_low_bit(model_dir=save_directory)
else:
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
import convert_llm
convert_llm(llm,
model.share_memory()
if not pipeline:
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
optimize_llm_single_process(
llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
qtype=qtype,
convert_model=convert_model,
save_directory=save_directory,
fuse_layers=fuse_layers)
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit(save_directory)
logger.info(f"Converted model has already saved to {save_directory}.")
fuse_layers=fuse_layers,
has_llm=hasattr(model, "llm")
)
else:
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
else:
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
import convert_llm
convert_llm(llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
qtype=qtype,
convert_model=convert_model,
save_directory=save_directory,
fuse_layers=fuse_layers)
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit(save_directory)
logger.info(f"Converted model has already saved to {save_directory}.")
return model
@classmethod
@@ -379,6 +387,7 @@ class _BaseAutoModelClass:
intra_pp = kwargs.pop("intra_pp", None)
transpose_value_cache = kwargs.pop("transpose_value_cache", True)
modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
save_directory = kwargs.pop('save_directory', None)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
@@ -650,16 +659,37 @@ class _BaseAutoModelClass:
param.requires_grad_(False)
if optimize_model and not pipeline:
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
if save_directory is None:
invalidInputError(False,
"Please specify the save_directory, the path of folder " +
"to save the compiled NPU model. If path not exists, " +
"the compiled NPU model will be saved there. " +
"Else, program will exit.")
optimize_llm_single_process(
llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
qtype=qtype,
save_directory=save_directory,
fuse_layers=None,
has_llm=hasattr(model, "llm")
)
else:
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
elif optimize_model and pipeline:
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
import convert_llm


@@ -18,7 +18,7 @@ import torch
import importlib
import numpy as np
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
from ipex_llm.utils.common.log4Error import invalidInputError


@@ -21,16 +21,25 @@
# SPDX-License-Identifier: Apache 2.0
#
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
from intel_npu_acceleration_library.dtypes import NPUDtype
import os
import torch
from torch.nn import Parameter
import uuid
import math
from intel_npu_acceleration_library.backend import run_matmul
from typing import Optional, Union
from ipex_llm.utils.common import invalidInputError
import importlib
def is_acclib_available():
return importlib.util.find_spec("intel_npu_acceleration_library") is not None
if is_acclib_available():
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
from intel_npu_acceleration_library.dtypes import NPUDtype
from intel_npu_acceleration_library.backend import run_matmul
class Linear(torch.nn.Module):
@@ -63,6 +72,7 @@ class Linear(torch.nn.Module):
if self.training:
out = self._mm(x, self.weight, None)
else:
from intel_npu_acceleration_library.backend import run_matmul
out = run_matmul(x, self.weight, None, self.op_id)
if self.bias is None:
@@ -105,6 +115,8 @@ class Linear(torch.nn.Module):
Returns:
Union[Linear, QuantizedLinear]: A NPU linear layer
"""
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
from intel_npu_acceleration_library.dtypes import NPUDtype
if dtype.is_floating_point:
if bias is None:
return Linear(weight.to(dtype), None)
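
The guarded imports above let this module be imported on machines where intel_npu_acceleration_library is not installed, deferring the NPU-only symbols to the call sites that need them. A minimal standalone sketch of the same optional-dependency pattern; the `matmul` helper and its CPU fallback are hypothetical illustrations, not part of the diff:

```python
import importlib.util

import torch


def is_acclib_available() -> bool:
    # True only if the optional NPU backend is installed.
    return importlib.util.find_spec("intel_npu_acceleration_library") is not None


def matmul(x: torch.Tensor, weight: torch.Tensor, op_id=None):
    """Hypothetical helper: use the NPU kernel when available, else plain torch."""
    if is_acclib_available():
        # Deferred import: only touched when the backend is actually present.
        from intel_npu_acceleration_library.backend import run_matmul
        return run_matmul(x, weight, None, op_id)
    return torch.matmul(x, weight.T)  # CPU fallback on machines without the NPU stack
```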


@@ -16,96 +16,6 @@
import torch
from torch import nn
import numpy as np
from filelock import FileLock
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
class LMHeadLinear(NNFactory):
"""Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
with weights prefetching."""
def __init__(
self,
inC: int,
outC: int,
batch: int,
split_num: int = 2,
profile: bool = False,
device: str = "NPU",
dtype: np.dtype = np.int8,
use_split: bool = False,
group_size: int = 0,
asym: bool = False,
):
"""Initialize the LMHeadLinear class.
Args:
inC (int): input channels
outC (int): output channels
batch (int): batch
split_num (int): split in_features of lm_head to how many parts
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
"""
super().__init__(profile, device)
self.inC, self.outC = inC, outC
self.batch = batch
self.split_num = split_num
if use_split:
input = self.parameter((1, self.batch, self.inC))
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
scale_factor=(group_size == 0), asym=asym)
else:
input = self.parameter((self.batch, self.inC))
split_size = self.inC // split_num // 2 * 2
for i in range(self.split_num):
start_idx = i * split_size
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
wt_dtype=dtype, asym=asym)
if i == 0:
res = linear_slice
else:
res += linear_slice
print("start compiling lm_head")
self.compile()
print("end compiling lm_head")
def set_weights(self, op_id, weights):
self.set_weights_async(op_id, weights)
with FileLock(f"lmhead_run.lock"):
backend_lib.run(self._mm)
def set_weights_async(self, op_id, weights):
self.setWeights(1, op_id, *weights)
def run(
self, X: np.ndarray
) -> np.ndarray:
"""Run the layer: $X * (W * S)^T$ .
Args:
X (np.ndarray): activation
Raises:
RuntimeError: Input, weights or scale shape mismatch
Returns:
np.ndarray: result
"""
self.set_input_tensor(X, 0)
self.elapsed = backend_lib.run(self._mm)
if len(self.out) == 1:
return self.out[0]
return self.out
class SlicedLMHead(nn.Module):
@@ -160,6 +70,7 @@ class SlicedLMHead(nn.Module):
return self.lm_heads[0].weight.dtype
def get_fused_lm_head(self):
from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear
np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
False, "NPU", dtype=np_dtype, use_split=self.use_split,


@@ -0,0 +1,106 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from filelock import FileLock
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
class LMHeadLinear(NNFactory):
"""Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
with weights prefetching."""
def __init__(
self,
inC: int,
outC: int,
batch: int,
split_num: int = 2,
profile: bool = False,
device: str = "NPU",
dtype: np.dtype = np.int8,
use_split: bool = False,
group_size: int = 0,
asym: bool = False,
):
"""Initialize the LMHeadLinear class.
Args:
inC (int): input channels
outC (int): output channels
batch (int): batch
split_num (int): split in_features of lm_head to how many parts
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
"""
super().__init__(profile, device)
self.inC, self.outC = inC, outC
self.batch = batch
self.split_num = split_num
if use_split:
input = self.parameter((1, self.batch, self.inC))
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
scale_factor=(group_size == 0), asym=asym)
else:
input = self.parameter((self.batch, self.inC))
split_size = self.inC // split_num // 2 * 2
for i in range(self.split_num):
start_idx = i * split_size
end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
wt_dtype=dtype, asym=asym)
if i == 0:
res = linear_slice
else:
res += linear_slice
print("start compiling lm_head")
self.compile()
print("end compiling lm_head")
def set_weights(self, op_id, weights):
self.set_weights_async(op_id, weights)
with FileLock(f"lmhead_run.lock"):
backend_lib.run(self._mm)
def set_weights_async(self, op_id, weights):
self.setWeights(1, op_id, *weights)
def run(
self, X: np.ndarray
) -> np.ndarray:
"""Run the layer: $X * (W * S)^T$ .
Args:
X (np.ndarray): activation
Raises:
RuntimeError: Input, weights or scale shape mismatch
Returns:
np.ndarray: result
"""
self.set_input_tensor(X, 0)
self.elapsed = backend_lib.run(self._mm)
if len(self.out) == 1:
return self.out[0]
return self.out
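
For context, a minimal sketch of driving the relocated class directly, mirroring the call in SlicedLMHead.get_fused_lm_head above. The shapes are illustrative assumptions, and construction compiles the kernel, so it requires an NPU device plus the intel_npu_acceleration_library runtime:

```python
import numpy as np

from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear

inC, outC, split_num = 4096, 32000, 2   # hypothetical hidden size / vocab size

# Building the layer constructs and compiles the sliced lm_head graph on the NPU
# (batch=1, profiling off, target device "NPU", int8 weights, non-split path).
fused_lm_head = LMHeadLinear(inC, outC, 1, split_num,
                             False, "NPU", dtype=np.int8, use_split=False)

# Quantized weight slices are pushed in afterwards and the kernel is run on a
# (batch, inC) activation, as SlicedLMHead does:
#   fused_lm_head.set_weights(op_id, weights)   # weight payload as prepared by SlicedLMHead
#   logits = fused_lm_head.run(x)               # x: np.ndarray of shape (batch, inC)
```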