[NPU] Support save npu quantized model without npu dependency (#12647)
* support save awq
* load quantized model & save npu compiled model
* fix style
* update
* fix dll load issue
* update error message
* fix style
Parent: 502461d836
Commit: fae73eee79

5 changed files with 203 additions and 144 deletions
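For context, a minimal usage sketch of the new flow (not part of the diff). It assumes `from_pretrained` on `ipex_llm.transformers.npu_model.AutoModelForCausalLM` forwards the `mock_device` keyword referenced in the diff below, and that `load_low_bit` accepts the new `save_directory` keyword; model ids and paths are placeholders:

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Step 1 (no NPU runtime needed): quantize and save only the low-bit weights.
# mock_device="dummy" sets skip_npu_logic, so create_npu_kernels and NPU
# compilation are skipped and only save_low_bit(save_directory) runs.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/original/model",            # placeholder path
    load_in_low_bit="sym_int4_rtn",
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
    mock_device="dummy",                 # assumed keyword, see diff below
    save_directory="./quantized-model",  # placeholder path
)

# Step 2 (on a machine with the NPU stack): load the saved low-bit weights and
# compile the NPU model into a second folder.
model = AutoModelForCausalLM.load_low_bit(
    "./quantized-model",
    save_directory="./npu-compiled-model",  # placeholder path
)
```

With the dummy path, the machine that produces the quantized checkpoint never has to import `intel_npu_acceleration_library` or load its DLLs.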
@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.utils import logger, load_imatrix_data
-from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
+from ipex_llm.transformers.npu_models.convert import optimize_llm


 def patch_flash_attn_import(filename: str) -> List[str]:

@@ -207,8 +207,6 @@ class _BaseAutoModelClass:
             model = model.eval()
             logger.info(f"Finish to convert model")
         else:
-            from intel_npu_acceleration_library.compiler import create_npu_kernels
-
             if optimize_model:
                 invalidInputError(
                     max_prompt_len < max_context_len,

@@ -232,11 +230,14 @@ class _BaseAutoModelClass:
                     "convert_model": convert_model,
                     "save_directory": save_directory,
                     "fuse_layers": fuse_layers,
-                    "imatrix_data": imatrix_data
+                    "imatrix_data": imatrix_data,
+                    "skip_npu_logic": mock_device == "dummy",
                 }
+                # Dummy will skip npu related logic and save the quantized model
+                if mock_device == "dummy":
+                    model.save_low_bit = types.MethodType(save_low_bit, model)
                 model = cls.optimize_npu_model(*args, **optimize_kwargs)
             else:
-                from ipex_llm.transformers.npu_models.convert import optimize_llm
                 optimize_llm(model)
                 with torch.no_grad():
                     cls.load_convert(qtype, model, "cpu", modules_to_not_convert,

@@ -258,7 +259,6 @@ class _BaseAutoModelClass:
     def optimize_npu_model(cls, *args, **kwargs):

         from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre, optimize_llm
-        from intel_npu_acceleration_library.compiler import create_npu_kernels

         model = kwargs.pop("model")
         qtype = kwargs.pop("qtype", "sym_int4_rtn")

@@ -275,6 +275,7 @@ class _BaseAutoModelClass:
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
         imatrix_data = kwargs.pop('imatrix_data', None)
+        skip_npu_logic = kwargs.pop("skip_npu_logic", False)
         invalidInputError(save_directory is not None,
                           "Please provide the path to save converted model "
                           "through `save_directory`.")

@@ -294,51 +295,58 @@ class _BaseAutoModelClass:
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
                              quantization_group_size, imatrix_data,
                              *args, **kwargs)
-            create_npu_kernels(llm)
+            if not skip_npu_logic:
+                from intel_npu_acceleration_library.compiler import create_npu_kernels
+                create_npu_kernels(llm)
         model = model.eval()
         logger.info(f"Finish to convert model")
         model.config.update({"bigdl_transformers_low_bit": qtype})
-        model.share_memory()

-        if not pipeline:
-            if model.config.model_type in ["qwen2", "llama", "minicpm"]:
-                from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
-                optimize_llm_single_process(
-                    llm,
-                    kv_len=max_context_len,
-                    max_prompt_len=max_prompt_len,
-                    transpose_value_cache=transpose_value_cache,
-                    group_size=quantization_group_size,
-                    qtype=qtype,
-                    save_directory=save_directory,
-                    fuse_layers=fuse_layers,
-                    has_llm=hasattr(model, "llm")
-                )
-            else:
-                optimize_llm(
-                    llm,
-                    max_context_len=max_context_len,
-                    max_prompt_len=max_prompt_len,
-                    inter_pp=inter_pp,
-                    intra_pp=intra_pp,
-                    transpose_value_cache=transpose_value_cache,
-                    group_size=quantization_group_size
-                )
-        else:
-            from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
-                import convert_llm
-            convert_llm(llm,
-                        kv_len=max_context_len,
-                        max_prompt_len=max_prompt_len,
-                        transpose_value_cache=transpose_value_cache,
-                        group_size=quantization_group_size,
-                        qtype=qtype,
-                        convert_model=convert_model,
-                        save_directory=save_directory,
-                        fuse_layers=fuse_layers)
-            model.save_low_bit = types.MethodType(save_low_bit, model)
-            model.save_low_bit(save_directory)
-            logger.info(f"Converted model has already saved to {save_directory}.")
+        if skip_npu_logic:
+            model.save_low_bit(model_dir=save_directory)
+        else:
+            model.share_memory()
+
+            if not pipeline:
+                if model.config.model_type in ["qwen2", "llama", "minicpm"]:
+                    from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
+                    optimize_llm_single_process(
+                        llm,
+                        kv_len=max_context_len,
+                        max_prompt_len=max_prompt_len,
+                        transpose_value_cache=transpose_value_cache,
+                        group_size=quantization_group_size,
+                        qtype=qtype,
+                        save_directory=save_directory,
+                        fuse_layers=fuse_layers,
+                        has_llm=hasattr(model, "llm")
+                    )
+                else:
+                    optimize_llm(
+                        llm,
+                        max_context_len=max_context_len,
+                        max_prompt_len=max_prompt_len,
+                        inter_pp=inter_pp,
+                        intra_pp=intra_pp,
+                        transpose_value_cache=transpose_value_cache,
+                        group_size=quantization_group_size
+                    )
+            else:
+                from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
+                    import convert_llm
+                convert_llm(llm,
+                            kv_len=max_context_len,
+                            max_prompt_len=max_prompt_len,
+                            transpose_value_cache=transpose_value_cache,
+                            group_size=quantization_group_size,
+                            qtype=qtype,
+                            convert_model=convert_model,
+                            save_directory=save_directory,
+                            fuse_layers=fuse_layers)
+                model.save_low_bit = types.MethodType(save_low_bit, model)
+                model.save_low_bit(save_directory)
+                logger.info(f"Converted model has already saved to {save_directory}.")

         return model

     @classmethod

@@ -379,6 +387,7 @@ class _BaseAutoModelClass:
         intra_pp = kwargs.pop("intra_pp", None)
         transpose_value_cache = kwargs.pop("transpose_value_cache", True)
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
+        save_directory = kwargs.pop('save_directory', None)

         from transformers.models.auto.configuration_auto import AutoConfig
         from transformers.modeling_utils import no_init_weights, get_state_dict_dtype

@@ -650,16 +659,37 @@ class _BaseAutoModelClass:
             param.requires_grad_(False)

         if optimize_model and not pipeline:
-            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
-            optimize_llm(
-                llm,
-                max_context_len=max_context_len,
-                max_prompt_len=max_prompt_len,
-                inter_pp=inter_pp,
-                intra_pp=intra_pp,
-                transpose_value_cache=transpose_value_cache,
-                group_size=quantization_group_size
-            )
+            if model.config.model_type in ["qwen2", "llama", "minicpm"]:
+                from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
+                if save_directory is None:
+                    invalidInputError(False,
+                                      "Please specify the save_directory, the path of folder " +
+                                      "to save the compiled NPU model. If path not exists, " +
+                                      "the compiled NPU model will be saved there. " +
+                                      "Else, program will exit.")
+
+                optimize_llm_single_process(
+                    llm,
+                    kv_len=max_context_len,
+                    max_prompt_len=max_prompt_len,
+                    transpose_value_cache=transpose_value_cache,
+                    group_size=quantization_group_size,
+                    qtype=qtype,
+                    save_directory=save_directory,
+                    fuse_layers=None,
+                    has_llm=hasattr(model, "llm")
+                )
+            else:
+                from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+                optimize_llm(
+                    llm,
+                    max_context_len=max_context_len,
+                    max_prompt_len=max_prompt_len,
+                    inter_pp=inter_pp,
+                    intra_pp=intra_pp,
+                    transpose_value_cache=transpose_value_cache,
+                    group_size=quantization_group_size
+                )
         elif optimize_model and pipeline:
             from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
                 import convert_llm
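Side note on the `types.MethodType` binding used in the dummy path above: `save_low_bit` is attached to the model instance before `optimize_npu_model` runs, so the quantized weights can be written without importing any NPU backend. A self-contained toy illustration of that binding pattern (class and print are illustrative, not from the repo):

```python
import types


class ToyModel:
    def __init__(self):
        self.config = {"bigdl_transformers_low_bit": "sym_int4_rtn"}


def save_low_bit(self, model_dir):
    # Stand-in for the real serializer: it only needs the model object itself,
    # no NPU runtime or compiled kernels.
    print(f"writing {self.config} to {model_dir}")


model = ToyModel()
# Bind the function as an instance method, mirroring
# `model.save_low_bit = types.MethodType(save_low_bit, model)` in the diff.
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit("./quantized-model")  # -> writing {...} to ./quantized-model
```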
@@ -18,7 +18,7 @@ import torch
 import importlib
 import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
-from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
+from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 from ipex_llm.utils.common.log4Error import invalidInputError

@@ -21,16 +21,25 @@
 # SPDX-License-Identifier: Apache 2.0
 #

-from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
-from intel_npu_acceleration_library.dtypes import NPUDtype
 import os
 import torch
 from torch.nn import Parameter
 import uuid
 import math
-from intel_npu_acceleration_library.backend import run_matmul
 from typing import Optional, Union
 from ipex_llm.utils.common import invalidInputError
+import importlib
+
+
+def is_acclib_available():
+    return importlib.util.find_spec("intel_npu_acceleration_library") is not None
+
+
+if is_acclib_available():
+    from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
+    from intel_npu_acceleration_library.dtypes import NPUDtype
+    from intel_npu_acceleration_library.backend import run_matmul


 class Linear(torch.nn.Module):

@@ -63,6 +72,7 @@ class Linear(torch.nn.Module):
         if self.training:
             out = self._mm(x, self.weight, None)
         else:
+            from intel_npu_acceleration_library.backend import run_matmul
             out = run_matmul(x, self.weight, None, self.op_id)

         if self.bias is None:

@@ -105,6 +115,8 @@ class Linear(torch.nn.Module):
         Returns:
             Union[Linear, QuantizedLinear]: A NPU linear layer
         """
+        from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
+        from intel_npu_acceleration_library.dtypes import NPUDtype
         if dtype.is_floating_point:
             if bias is None:
                 return Linear(weight.to(dtype), None)
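The linear-layer changes above follow one pattern: module-level imports of `intel_npu_acceleration_library` become conditional on the package being installed, and the remaining imports move inside the functions that actually need them, so importing the module (for example, just to save a quantized checkpoint) cannot fail on a missing NPU DLL. A generic sketch of that pattern, runnable without the library installed (the fallback function is illustrative, not from the repo):

```python
import importlib.util


def is_acclib_available() -> bool:
    # Same check as the diff: is the optional NPU package importable at all?
    return importlib.util.find_spec("intel_npu_acceleration_library") is not None


def matmul_or_fallback(x, weight):
    """Use the NPU backend when present; otherwise stay on a CPU-only path."""
    if is_acclib_available():
        # Deferred import: only evaluated when the package (and its DLLs) exist.
        from intel_npu_acceleration_library.backend import run_matmul
        return run_matmul(x, weight, None, "op_0")
    return x @ weight.T  # plain fallback keeps the module importable everywhere


if __name__ == "__main__":
    print("NPU acceleration library available:", is_acclib_available())
```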
@@ -16,96 +16,6 @@
 import torch
 from torch import nn
 import numpy as np
-from filelock import FileLock
-from intel_npu_acceleration_library.backend import NNFactory
-from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
-
-
-class LMHeadLinear(NNFactory):
-    """Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
-    with weights prefetching."""
-
-    def __init__(
-        self,
-        inC: int,
-        outC: int,
-        batch: int,
-        split_num: int = 2,
-        profile: bool = False,
-        device: str = "NPU",
-        dtype: np.dtype = np.int8,
-        use_split: bool = False,
-        group_size: int = 0,
-        asym: bool = False,
-    ):
-        """Initialize the LMHeadLinear class.
-
-        Args:
-            inC (int): input channels
-            outC (int): output channels
-            batch (int): batch
-            split_num (int): split in_features of lm_head to how many parts
-            profile (bool): Enable/Disable profiling. Defaults to False.
-            device (str): Target device, default to "NPU".
-            dtype (np.dtype): weights datatype. Defaults to np.int8.
-
-        """
-        super().__init__(profile, device)
-        self.inC, self.outC = inC, outC
-        self.batch = batch
-
-        self.split_num = split_num
-        if use_split:
-            input = self.parameter((1, self.batch, self.inC))
-            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
-                                       scale_factor=(group_size == 0), asym=asym)
-        else:
-            input = self.parameter((self.batch, self.inC))
-            split_size = self.inC // split_num // 2 * 2
-
-            for i in range(self.split_num):
-                start_idx = i * split_size
-                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
-                input_slice = self.slice(input, begin=[0, start_idx],
-                                         end=[self.batch, end_idx])
-                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
-                                           wt_dtype=dtype, asym=asym)
-                if i == 0:
-                    res = linear_slice
-                else:
-                    res += linear_slice
-
-        print("start compiling lm_head")
-        self.compile()
-        print("end compiling lm_head")
-
-    def set_weights(self, op_id, weights):
-        self.set_weights_async(op_id, weights)
-        with FileLock(f"lmhead_run.lock"):
-            backend_lib.run(self._mm)
-
-    def set_weights_async(self, op_id, weights):
-        self.setWeights(1, op_id, *weights)
-
-    def run(
-        self, X: np.ndarray
-    ) -> np.ndarray:
-        """Run the layer: $X * (W * S)^T$ .
-
-        Args:
-            X (np.ndarray): activation
-
-        Raises:
-            RuntimeError: Input, weights or scale shape mismatch
-
-        Returns:
-            np.ndarray: result
-        """
-        self.set_input_tensor(X, 0)
-        self.elapsed = backend_lib.run(self._mm)
-        if len(self.out) == 1:
-            return self.out[0]
-        return self.out


 class SlicedLMHead(nn.Module):

@@ -160,6 +70,7 @@ class SlicedLMHead(nn.Module):
         return self.lm_heads[0].weight.dtype

     def get_fused_lm_head(self):
+        from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear
        np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
         self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
                                           False, "NPU", dtype=np_dtype, use_split=self.use_split,
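With `LMHeadLinear` moved into its own `lm_head_linear` module (the new file below), `lm_head.py` only imports the NPU backend inside `get_fused_lm_head`, i.e. at fuse time rather than at module import time. A minimal sketch of that method-local import, using an illustrative class that is not from the repo:

```python
import numpy as np


class SlicedHeadSketch:
    """Illustrative stand-in: the NPU-backed class is imported only when fusing."""

    def __init__(self, in_features: int, out_features: int, split_num: int = 2):
        self.inC, self.outC, self.split_num = in_features, out_features, split_num
        self.fused_lm_head = None

    def get_fused_lm_head(self):
        # Deferred import keeps module loading (and thus saving a quantized
        # checkpoint) free of any NPU runtime dependency.
        from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear
        self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
                                          False, "NPU", dtype=np.int8)
```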
@@ -0,0 +1,106 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from filelock import FileLock
+from intel_npu_acceleration_library.backend import NNFactory
+from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
+
+
+class LMHeadLinear(NNFactory):
+    """Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
+    with weights prefetching."""
+
+    def __init__(
+        self,
+        inC: int,
+        outC: int,
+        batch: int,
+        split_num: int = 2,
+        profile: bool = False,
+        device: str = "NPU",
+        dtype: np.dtype = np.int8,
+        use_split: bool = False,
+        group_size: int = 0,
+        asym: bool = False,
+    ):
+        """Initialize the LMHeadLinear class.
+
+        Args:
+            inC (int): input channels
+            outC (int): output channels
+            batch (int): batch
+            split_num (int): split in_features of lm_head to how many parts
+            profile (bool): Enable/Disable profiling. Defaults to False.
+            device (str): Target device, default to "NPU".
+            dtype (np.dtype): weights datatype. Defaults to np.int8.
+
+        """
+        super().__init__(profile, device)
+        self.inC, self.outC = inC, outC
+        self.batch = batch
+
+        self.split_num = split_num
+        if use_split:
+            input = self.parameter((1, self.batch, self.inC))
+            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
+                                       scale_factor=(group_size == 0), asym=asym)
+        else:
+            input = self.parameter((self.batch, self.inC))
+            split_size = self.inC // split_num // 2 * 2
+
+            for i in range(self.split_num):
+                start_idx = i * split_size
+                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
+                input_slice = self.slice(input, begin=[0, start_idx],
+                                         end=[self.batch, end_idx])
+                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
+                                           wt_dtype=dtype, asym=asym)
+                if i == 0:
+                    res = linear_slice
+                else:
+                    res += linear_slice
+
+        print("start compiling lm_head")
+        self.compile()
+        print("end compiling lm_head")
+
+    def set_weights(self, op_id, weights):
+        self.set_weights_async(op_id, weights)
+        with FileLock(f"lmhead_run.lock"):
+            backend_lib.run(self._mm)
+
+    def set_weights_async(self, op_id, weights):
+        self.setWeights(1, op_id, *weights)
+
+    def run(
+        self, X: np.ndarray
+    ) -> np.ndarray:
+        """Run the layer: $X * (W * S)^T$ .
+
+        Args:
+            X (np.ndarray): activation
+
+        Raises:
+            RuntimeError: Input, weights or scale shape mismatch
+
+        Returns:
+            np.ndarray: result
+        """
+        self.set_input_tensor(X, 0)
+        self.elapsed = backend_lib.run(self._mm)
+        if len(self.out) == 1:
+            return self.out[0]
+        return self.out