parent 14ce058004
commit f7e957aaf9
2 changed files with 27 additions and 51 deletions

@@ -332,6 +332,11 @@ class _BaseAutoModelClass:
             else:
                 kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
+
+            invalidInputError(q_k not in ["sym_int4_rtn", "sym_int8_rtn"],
+                              f"The dtype {q_k} is specified for NPU"
+                              "and cannot be used on CPU and GPU")
+
             imatrix_file = kwargs.pop("imatrix", None)
             if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s"]:
                 invalidInputError(imatrix_file is not None,
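
Note: `invalidInputError(cond, msg)` raises when `cond` is falsy, so the hunk above makes the CPU/GPU `from_pretrained` path reject the NPU-only RTN dtypes. A minimal usage sketch under that assumption (the model id is a placeholder):

    from ipex_llm.transformers import AutoModelForCausalLM

    # Accepted: "sym_int4" is a CPU/GPU low-bit type.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # placeholder model id
        load_in_low_bit="sym_int4")

    # Raises after this change: "*_rtn" dtypes are reserved for the NPU path.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        load_in_low_bit="sym_int4_rtn")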

@@ -25,8 +25,6 @@ from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
 from transformers.configuration_utils import PretrainedConfig
 
-import intel_npu_acceleration_library as npu_lib
-
 from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.utils import logger
 from ipex_llm.transformers.npu_models.convert import optimize_llm

@@ -90,23 +88,12 @@ class _BaseAutoModelClass:
             warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
         kwargs['torch_dtype'] = torch.float
 
-        low_bit = kwargs.pop('load_in_low_bit', 'fp32')
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.dtypes import int8, int4
-            qtype_map = {
-                'sym_int4': "sym_int4_rtn",
-                'sym_int8': "sym_int8_rtn",
-                'fp16': torch.half,
-                'fp32': torch.float,
-            }
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            qtype_map = {
-                'sym_int8': torch.int8,
-                'fp16': torch.half,
-                'fp32': torch.float,
-            }
+        low_bit = kwargs.pop('load_in_low_bit', 'sym_int4')
+        qtype_map = {
+            'sym_int4': "sym_int4_rtn",
+            'sym_int8': "sym_int8_rtn",
+        }
+
         invalidInputError(low_bit in qtype_map.keys(),
                           f"unsupported low_bit: {low_bit}, "
                           f"only {list(qtype_map.keys())} are supported")
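
The version-probing try/except is gone: `load_in_low_bit` now defaults to 'sym_int4' instead of 'fp32', and only the two RTN types remain valid on the NPU path. A standalone sketch of the resulting resolution logic (the helper name is illustrative):

    qtype_map = {
        'sym_int4': "sym_int4_rtn",
        'sym_int8': "sym_int8_rtn",
    }

    def resolve_qtype(low_bit='sym_int4'):
        # 'fp16' and 'fp32' are no longer accepted after this change.
        if low_bit not in qtype_map:
            raise ValueError(f"unsupported low_bit: {low_bit}, "
                             f"only {list(qtype_map.keys())} are supported")
        return qtype_map[low_bit]

    assert resolve_qtype() == "sym_int4_rtn"
    assert resolve_qtype('sym_int8') == "sym_int8_rtn"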

@@ -143,22 +130,15 @@ class _BaseAutoModelClass:
         model.config.update({"bigdl_lcmu_enabled": False})
 
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.quantization import quantize_model
-            from intel_npu_acceleration_library.compiler import create_npu_kernels
-            with torch.no_grad():
-                optimize_llm(model)
-                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
-                else:
-                    if not qtype.is_floating_point:
-                        model = quantize_model(model, qtype)
-                    create_npu_kernels(model)
-                model = model.eval()
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            model = npu_lib.compile(model, qtype, False)
+        from intel_npu_acceleration_library.compiler import create_npu_kernels
+
+        with torch.no_grad():
+            optimize_llm(model)
+            cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
+            create_npu_kernels(model)
+
+        model = model.eval()
+
         logger.info(f"Finish to convert model")
 
         model.config.update({"bigdl_transformers_low_bit": qtype})
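
With the try/except fallback and the qtype branching removed, conversion is one unconditional sequence: `optimize_llm` applies the NPU-oriented model rewrites, `cls.load_convert` performs the RTN weight quantization, and `create_npu_kernels` attaches the NPU kernels. A sketch of that shared sequence as a free function (the helper name and factoring are hypothetical; the diff keeps the logic inline at both call sites, as the next hunk shows):

    import torch

    def convert_for_npu(cls, model, qtype, device, *args, **kwargs):
        # Hypothetical helper mirroring the now-identical logic in
        # from_pretrained (device='cpu') and load_low_bit (quant_device).
        from intel_npu_acceleration_library.compiler import create_npu_kernels
        from ipex_llm.transformers.npu_models.convert import optimize_llm

        with torch.no_grad():
            optimize_llm(model)                                      # NPU-oriented model rewrites
            cls.load_convert(qtype, model, device, *args, **kwargs)  # RTN weight quantization
            create_npu_kernels(model)                                # bind NPU kernels
        return model.eval()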

@@ -313,22 +293,13 @@ class _BaseAutoModelClass:
         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.quantization import quantize_model
-            from intel_npu_acceleration_library.compiler import create_npu_kernels
-            with torch.no_grad():
-                optimize_llm(model)
-                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
-                else:
-                    if not qtype.is_floating_point:
-                        model = quantize_model(model, qtype)
-                    create_npu_kernels(model)
-                model = model.eval()
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            model = npu_lib.compile(model, qtype, False)
+        from intel_npu_acceleration_library.compiler import create_npu_kernels
+        with torch.no_grad():
+            optimize_llm(model)
+            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
+            create_npu_kernels(model)
+
+        model = model.eval()
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
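
End to end, loading a model on the NPU after this change looks roughly like this (a sketch; the model id is a placeholder and generation arguments are omitted):

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # load_in_low_bit now defaults to "sym_int4" (previously "fp32") and is
    # mapped to "sym_int4_rtn" internally before conversion.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # placeholder model id
        load_in_low_bit="sym_int4",
        trust_remote_code=True)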