parent f07937945f
commit 57b8adb189
4 changed files with 440 additions and 28 deletions
@@ -225,6 +225,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
         dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
                                  device=device)
+        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
     else:
@@ -239,7 +240,6 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
             ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
                                           k, hist, enable_scale_search)
-            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
             return dst_tensor, scale.type(torch.float16)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
@@ -252,7 +252,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
             ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
                                                    n // in_features, in_features,
                                                    hist, imatrix)
-    return dst_tensor
+    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+        return dst_tensor, scale.type(torch.float16)
+    else:
+        return dst_tensor


 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
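Note on the return convention: with this change, ggml_convert_qtype returns a (weights, scale) pair on every code path for the RTN qtypes and a bare tensor for all other qtypes, so callers must unpack conditionally. A minimal calling sketch, assuming a float32 weight tensor named weight_fp32 (illustrative, not part of this commit):

    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
    from ipex_llm.ggml.quantize import ggml_tensor_qtype

    iqtype = ggml_tensor_qtype["sym_int4_rtn"]
    result = ggml_convert_qtype(weight_fp32, iqtype, device='cpu')
    if isinstance(result, tuple):
        qweights, scale = result   # RTN path: packed weights plus float16 scales
    else:
        qweights = result          # non-RTN path: packed ggml tensor only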
@@ -15,6 +15,7 @@
 #

 import os
+import copy
 import types
 import warnings
 import torch
@@ -22,6 +23,7 @@ import transformers
 from typing import List
 from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
+from transformers.configuration_utils import PretrainedConfig

 import intel_npu_acceleration_library as npu_lib

@@ -44,6 +46,23 @@ def ignore_argument(kwargs: dict, key: 'str'):
         warnings.warn(f"argument `{key}={arg}` will be ignored")


+def save_low_bit(self, model_dir: str, *args, **kwargs):
+    origin_device = self.device
+    kwargs['safe_serialization'] = False
+    self.save_pretrained(model_dir, *args, **kwargs)
+    import json
+    import os
+    # We conveniently save all the keys of the model to have them on hand,
+    # so that when using 'low_cpumem load',
+    # it's not necessary to load the entire model to extract its keys
+    # and we can avoid gc not triggered potentially.
+    load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
+    with open(os.path.join(model_dir, "load_keys.json"), "w") as json_file:
+        json.dump(load_keys, json_file)
+    if origin_device != 'cpu':
+        self.to(origin_device)
+
+
 class _BaseAutoModelClass:
     HF_MODEL = None

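The new module-level save_low_bit wraps save_pretrained (with safe_serialization disabled) and additionally writes a load_keys.json manifest of checkpoint keys, which load_low_bit further below reads back so it does not have to load the whole model just to enumerate keys. A usage sketch, with an illustrative directory name:

    # model obtained via the NPU from_pretrained path, e.g. load_in_low_bit="sym_int4_rtn"
    model.save_low_bit("./npu-low-bit-model")
    # the directory now holds the usual save_pretrained output plus load_keys.json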
@@ -110,7 +129,18 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")

-        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        _args = copy.deepcopy(args)
+        _kwargs = copy.deepcopy(kwargs)
+        try:
+            # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it
+            kwargs.pop('device_map', None)
+            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        except NotImplementedError:
+            logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
+                        "will fall to traditional load method with higher memory consumption.")
+            _kwargs["low_cpu_mem_usage"] = False
+            model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
+            model.config.update({"bigdl_lcmu_enabled": False})

         logger.info(f"Converting model, it may takes up to several minutes ...")
         try:
@@ -120,7 +150,7 @@
             with torch.no_grad():
                 optimize_llm(model)
                 if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, *args, **kwargs)
+                    cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
                 else:
                     if not qtype.is_floating_point:
                         model = quantize_model(model, qtype)
@@ -131,27 +161,21 @@
             model = npu_lib.compile(model, qtype, False)
         logger.info(f"Finish to convert model")

+        model.config.update({"bigdl_transformers_low_bit": qtype})
+
         # add save_low_bit to pretrained model dynamically
-        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
+        model.save_low_bit = types.MethodType(save_low_bit, model)

         return model

     @classmethod
-    def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
+    def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
-        replace_with_QuantizedLinear(optimize_model, q_k)
+        replace_with_QuantizedLinear(optimize_model, q_k, device=device)

-    @staticmethod
-    def save_low_bit(self, model_dir: str, *args, **kwargs):
-        os.makedirs(model_dir, exist_ok=True)
-        model_name = "pytorch_npu_model.pt"
-        model_path = os.path.join(model_dir, model_name)
-        del self.save_low_bit  # workaround a bug
-        torch.save(self, model_path)
-
-    @staticmethod
+    @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    def load_low_bit(model_dir: str, *args, **kwargs):
+    def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
         if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
             warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
@@ -165,9 +189,203 @@
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")

-        model_name = "pytorch_npu_model.pt"
-        model_path = os.path.join(model_dir, model_name)
-        return torch.load(model_path)
+        from transformers.models.auto.configuration_auto import AutoConfig
+        from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
+        from transformers.dynamic_module_utils import resolve_trust_remote_code, \
+            get_class_from_dynamic_module
+        from transformers.models.auto.auto_factory import _get_model_class
+        from transformers.utils.generic import ContextManagers
+        from transformers.generation.configuration_utils import GenerationConfig
+        from ipex_llm.transformers.utils import extract_local_archive_file, get_local_shard_files, \
+            load_state_dict
+        from accelerate.big_modeling import init_empty_weights
+
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        kwargs_orig = copy.deepcopy(kwargs)
+
+        config, kwargs = AutoConfig.from_pretrained(
+            pretrained_model_name_or_path,
+            return_unused_kwargs=True,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        # if torch_dtype=auto was passed here, ensure to pass it on
+        if kwargs_orig.get("torch_dtype", None) == "auto":
+            kwargs["torch_dtype"] = "auto"
+
+        # Maybe needed when extract_local_archive_file
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+        offload_folder = kwargs.pop("offload_folder", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        torch_dtype = kwargs.pop("torch_dtype", "auto")
+        sharded_metadata = None
+
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        qtype = config_dict.pop("bigdl_transformers_low_bit", False)
+        bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
+
+        invalidInputError(qtype,
+                          "Detect this model is not a low-bit model, Please use from_pretrained"
+                          " with load_in_4bit or load_in_low_bit to get a low-bit model , and "
+                          " serialize the model using save_low_bit first.")
+
+        invalidInputError(qtype in ["sym_int8_rtn", "sym_int4_rtn"],
+                          f"Unknown bigdl_transformers_low_bit value: {qtype},"
+                          f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+
+        has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
+        has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
+        )
+        if has_remote_code and trust_remote_code:
+            class_ref = config.auto_map[cls.HF_Model.__name__]
+            model_class = get_class_from_dynamic_module(
+                class_ref, pretrained_model_name_or_path, **kwargs
+            )
+            if os.path.isdir(pretrained_model_name_or_path):
+                model_class.register_for_auto_class(cls.HF_Model.__name__)
+            else:
+                cls.HF_Model.register(config.__class__, model_class, exist_ok=True)
+        elif type(config) in cls.HF_Model._model_mapping.keys():
+            model_class = _get_model_class(config, cls.HF_Model._model_mapping)
+
+        resolved_archive_file, is_sharded = extract_local_archive_file(
+            pretrained_model_name_or_path,
+            subfolder,
+            variant)
+
+        if is_sharded:
+            resolved_archive_file, sharded_metadata = \
+                get_local_shard_files(pretrained_model_name_or_path,
+                                      resolved_archive_file,
+                                      subfolder=subfolder)
+
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict,
+        #    by checking its first weights entry that is of a floating type
+        #    - we assume all floating dtype weights are of the same dtype
+        # we also may have config.torch_dtype available, but we won't rely on it till v5
+        dtype_orig = None
+
+        if torch_dtype is not None:
+            if isinstance(torch_dtype, str):
+                if torch_dtype == "auto":
+                    if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
+                        torch_dtype = config.torch_dtype
+                    else:
+                        if is_sharded and "dtype" in sharded_metadata:
+                            torch_dtype = sharded_metadata["dtype"]
+                        else:
+                            one_state_dict = load_state_dict(resolved_archive_file[0])
+                            torch_dtype = get_state_dict_dtype(one_state_dict)
+                            del one_state_dict  # free CPU memory
+                else:
+                    invalidInputError(False,
+                                      f'`torch_dtype` can be either `torch.dtype` or `"auto"`,'
+                                      'but received {torch_dtype}')
+            dtype_orig = model_class._set_default_torch_dtype(torch_dtype)
+
+        # Pretrained Model
+        _fast_init = kwargs.pop("_fast_init", True)
+        init_contexts = [no_init_weights(_enable=_fast_init)]
+        init_contexts.append(init_empty_weights())
+
+        if bigdl_lcmu_enabled:
+            with ContextManagers(init_contexts):
+                if config.architectures is not None and config.architectures[0] in \
+                   ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+                    """
+                    ChatGLMModel uses skip_init by default, which will force modules placed on cpu
+                    if the device is not specified. This will further cause replaced linear
+                    allocating memory on cpu.
+                    """
+                    kwargs["device"] = "meta"
+                    model = model_class(config, *model_args, **kwargs)
+                else:
+                    model = model_class(config, *model_args, **kwargs)
+
+        # Loading args may differ based on their usage
+        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
+        logger.info(f"Converting model, it may takes up to several minutes ...")
+        try:
+            # for intel_npu_acceleration_library >= 1.1.0
+            from intel_npu_acceleration_library.quantization import quantize_model
+            from intel_npu_acceleration_library.compiler import create_npu_kernels
+            with torch.no_grad():
+                optimize_llm(model)
+                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
+                    cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
+                else:
+                    if not qtype.is_floating_point:
+                        model = quantize_model(model, qtype)
+                    create_npu_kernels(model)
+            model = model.eval()
+        except ImportError as _e:
+            # for intel_npu_acceleration_library < 1.1.0
+            model = npu_lib.compile(model, qtype, False)
+
+        if is_sharded:
+            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+        else:
+            import os
+            import json
+            with open(os.path.join(pretrained_model_name_or_path,
+                                   "load_keys.json"), "r") as json_file:
+                loaded_data = json.load(json_file)
+            loaded_state_dict_keys = loaded_data["all_checkpoint_keys"]
+
+        # restore default dtype
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = model_class._load_pretrained_model(
+            model,
+            None,
+            loaded_state_dict_keys,  # XXX: rename?
+            resolved_archive_file,
+            pretrained_model_name_or_path,
+            sharded_metadata=sharded_metadata,
+            _fast_init=False,  # always false to avoid pre-init behaviors
+            low_cpu_mem_usage=bigdl_lcmu_enabled,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            dtype=torch_dtype,
+            keep_in_fp32_modules=[],
+        )
+
+        # make sure token embedding weights are still tied if needed
+        model.tie_weights()
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    subfolder=subfolder,
+                    **kwargs,
+                )
+            except (OSError, TypeError):
+                pass
+        for param in model.parameters():
+            param.requires_grad_(False)
+
+        return model


 class AutoModelForCausalLM(_BaseAutoModelClass):
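Paired with save_low_bit above, the rewritten load_low_bit rebuilds the model class from its config, instantiates it under init_empty_weights/no_init_weights (on the meta device when low-CPU-memory loading is enabled), re-applies the NPU quantization, and only then loads the serialized low-bit weights. A usage sketch; the module path below is an assumption based on the imports in this diff, not something the commit states:

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.load_low_bit("./npu-low-bit-model",
                                              trust_remote_code=True)
    # parameters are already frozen and the model is already in eval mode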
@@ -15,8 +15,7 @@


 import torch
-import importlib
-from intel_npu_acceleration_library.nn import QuantizedLinear
+from ipex_llm.transformers.npu_models.linear import QuantizedLinear


 def module_optimization(func) -> torch.nn.Module:
@@ -31,7 +30,7 @@ def module_optimization(func) -> torch.nn.Module:
         torch.nn.Module: optimized module
     """

-    def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
+    def wrapper(model: torch.nn.Module, qtype, device, *args, **kwargs):
         """Recursively apply the optimization function.

         Args:
@@ -41,23 +40,23 @@ def module_optimization(func) -> torch.nn.Module:
         """
         for name, layer in model.named_children():
-            new_layer = func(layer, qtype, *args, **kwargs)
+            new_layer = func(layer, qtype, device, *args, **kwargs)
             if new_layer:
                 model.add_module(name, new_layer)
-                wrapper(new_layer, qtype, *args, **kwargs)
+                wrapper(new_layer, qtype, device, *args, **kwargs)
             else:
-                wrapper(layer, qtype, *args, **kwargs)
+                wrapper(layer, qtype, device, *args, **kwargs)

     return wrapper


 @module_optimization
-def replace_with_QuantizedLinear(layer, qtype):
+def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear):
-        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
+        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
         return QuantizedLinear(qweights, scale, layer.bias)
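module_optimization now threads an explicit device argument through the recursive wrapper, so every decorated replacement function is called as func(layer, qtype, device, ...) for each child module; replace_with_QuantizedLinear forwards that device into ggml_convert_qtype, which is what lets load_low_bit quantize onto the meta device. A toy sketch of the decorator contract; the decorated function below is hypothetical and only illustrates the call shape:

    @module_optimization
    def replace_linear_with_identity(layer, qtype, device):
        # return a new module to have it swapped in, or None to leave the layer unchanged
        if isinstance(layer, torch.nn.Linear):
            return torch.nn.Identity()

    # model: any torch.nn.Module; device ('cpu' or 'meta' in this commit) is
    # forwarded to every recursive call
    replace_linear_with_identity(model, "sym_int4_rtn", "meta")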
192  python/llm/src/ipex_llm/transformers/npu_models/linear.py  Normal file
@@ -0,0 +1,192 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py
+
+#
+# Copyright © 2024 Intel Corporation
+# SPDX-License-Identifier: Apache 2.0
+#
+
+from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
+from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
+from intel_npu_acceleration_library.backend import run_matmul
+from intel_npu_acceleration_library.dtypes import NPUDtype
+from typing import Optional, Union
+import torch
+from torch.nn import Parameter
+import uuid
+import math
+
+from ipex_llm.utils.common import invalidInputError
+
+
+class Linear(torch.nn.Module):
+    """Torch Linear operation NPU backend."""
+
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        """Initialize the Linear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation weight
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                                                     Defaults to None.
+        """
+        super().__init__()
+
+        self.weight = torch.nn.Parameter(weight)
+        self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
+        self.outC, self.inC = self.weight.shape
+        self.op_id = str(uuid.uuid4())
+        self._mm = AutogradMatMul.apply
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            out = self._mm(x, self.weight, None)
+        else:
+            out = run_matmul(x, self.weight, None, self.op_id)
+
+        if self.bias is None:
+            return out
+        return out + self.bias
+
+    @staticmethod
+    def fromTorch(
+        layer: torch.nn.Linear, dtype: torch.dtype = torch.float16
+    ) -> Union["Linear", "QuantizedLinear"]:
+        """Generate a NPU Linear layer from a torch one.
+
+        Args:
+            layer (torch.nn.Linear): the original torch.nn.Linear model to run on the NPU
+            dtype (torch.dtype): the desired datatype
+
+        Returns:
+            Union[Linear, QuantizedLinear]: A NPU linear layer
+        """
+        if any(dim > 2**17 for dim in layer.weight.shape):
+            return layer
+        return Linear.fromTensor(layer.weight, getattr(layer, "bias", None), dtype)
+
+    @staticmethod
+    def fromTensor(
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        dtype: torch.dtype = torch.float16,
+    ) -> Union["Linear", "QuantizedLinear"]:
+        """Generate a NPU Linear layer from a torch one.
+
+        Args:
+            weight (torch.Tensor): the original weight tensor
+            bias (Optional[torch.Tensor]): the original bias tensor
+            dtype (torch.dtype): the desired datatype
+
+        Raises:
+            RuntimeError: dtype not supported
+
+        Returns:
+            Union[Linear, QuantizedLinear]: A NPU linear layer
+        """
+        if dtype.is_floating_point:
+            if bias is None:
+                return Linear(weight.to(dtype), None)
+            return Linear(weight.to(dtype), bias.to(dtype))
+        elif isinstance(dtype, NPUDtype):
+            weights_quant, scale = quantize_tensor(weight, (dtype.min, dtype.max))
+            if dtype.bits == 4:
+                weights_quant = compress_to_i4(weights_quant)
+            return QuantizedLinear(weights_quant, scale, bias)
+        elif dtype == torch.int8:
+            weights_quant, scale = quantize_tensor(weight)
+            return QuantizedLinear(weights_quant, scale, bias)
+        else:
+            invalidInputError(False,
+                              f"NPU do not support yet the requeste datatype: {dtype}")
+
+
+class QuantizedLinear(torch.nn.Module):
+    """Torch Quantized Linear operation NPU backend."""
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """Initialize the QuantizedLinear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation weight
+            scale (torch.Tensor): Quantization scale
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                                                     Defaults to None.
+
+        Raises:
+            RuntimeError: Quantized weight must be in torch.int8 format
+        """
+        super().__init__()
+
+        self.weight = Parameter(weight, requires_grad=False)
+        if self.weight.dtype not in (torch.int8, torch.uint8):
+            invalidInputError(
+                False,
+                (
+                    f"Quantized weight must be in torch.(u)int8"
+                    " dtype instead of {self.weight.dtype}"
+                )
+            )
+        self.outC, self.inC = self.weight.shape
+        if self.weight.dtype == torch.uint8:
+            # In case is Int4 we need to double the input channels because weights are compressed
+            self.inC *= 2
+        self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
+        self.bias = bias
+        self.op_id = str(uuid.uuid4())
+        self._mm = AutogradMatMul.apply
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Raises:
+            RuntimeError: Training is not supported for QuantizedLinear layer.
+                          Use `.eval()` to do inference only
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for QuantizedLinear layer."
+                    "Use `.eval()` to do inference only"
+                )
+            )
+        out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
+
+        if self.bias is None:
+            return out
+        return out + self.bias
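The new QuantizedLinear stores int8 weights (or int4 pairs packed into uint8, hence the doubled inC), keeps a scale with a sqrt(inC) factor folded in for the NPU matmul kernel, and refuses to run in training mode. A construction-only sketch using a hand-rolled symmetric int8 quantizer; the scale convention actually expected by the NPU kernel comes from the library's quantize_tensor (or from ggml_convert_qtype on the RTN path), so treat the numbers here as illustrative, and note that forward() still needs the intel_npu_acceleration_library backend:

    import torch
    from ipex_llm.transformers.npu_models.linear import QuantizedLinear

    lin = torch.nn.Linear(4096, 4096, bias=False)
    w = lin.weight.data
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0   # per-row scale
    qweight = torch.round(w / scale).clamp(-127, 127).to(torch.int8)     # round-to-nearest int8
    qlin = QuantizedLinear(qweight, scale.squeeze(-1), lin.bias).eval()  # eval() is required before forward()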