# Source path: ipex-llm/python/llm/src/ipex_llm/transformers/npu_model.py
# (415 lines, 17 KiB, Python)

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import copy
import types
import warnings
import torch
import transformers
from typing import List
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from transformers.configuration_utils import PretrainedConfig
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
def patch_flash_attn_import(filename: str) -> List[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    Returns the result of ``get_imports(filename)`` with one ``"flash_attn"``
    entry removed when present.
    """
    detected = get_imports(filename)
    if "flash_attn" not in detected:
        return detected
    detected.remove("flash_attn")
    return detected
def ignore_argument(kwargs: dict, key: 'str'):
    """Remove ``key`` from ``kwargs`` in place, warning if it held a value.

    A missing key, or a key whose value is ``None``, is dropped silently.
    """
    dropped = kwargs.pop(key, None)
    if dropped is None:
        return
    warnings.warn(f"argument `{key}={dropped}` will be ignored")
def save_low_bit(self, model_dir: str, *args, **kwargs):
    """Save the model to ``model_dir`` along with an index of its state-dict keys.

    The model is written through ``self.save_pretrained`` with
    ``safe_serialization`` forced to ``False``; remaining positional and keyword
    arguments are forwarded unchanged. A ``load_keys.json`` file listing every
    state-dict key is written next to it.
    """
    import json

    device_before_save = self.device
    kwargs.update(safe_serialization=False)
    self.save_pretrained(model_dir, *args, **kwargs)

    # We conveniently save all the keys of the model to have them on hand,
    # so that when using 'low_cpumem load',
    # it's not necessary to load the entire model to extract its keys
    # and we can avoid gc not triggered potentially.
    keys_path = os.path.join(model_dir, "load_keys.json")
    key_index = {"all_checkpoint_keys": list(self.state_dict())}
    with open(keys_path, "w") as json_file:
        json.dump(key_index, json_file)

    # Put the model back on the device it was on when the call started.
    if device_before_save != 'cpu':
        self.to(device_before_save)
class _BaseAutoModelClass:
    """Base for the NPU ``AutoModel*`` wrappers defined in this module.

    Subclasses set ``HF_Model`` to the Hugging Face auto class that does the
    actual loading. The classmethods here wrap that class with low-bit
    conversion and NPU kernel creation.
    """

    # Hugging Face auto class used for loading; every subclass overrides it.
    HF_Model = None
    # Original spelling of the placeholder, kept so existing references resolve.
    HF_MODEL = None

    @classmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def from_pretrained(cls,
                        *args,
                        **kwargs):
        """
        Load a model from a directory or the HF Hub. Use load_in_low_bit parameter to convert
        model to low-bit format, like int4 and int8.
        The loaded model will run supported OPs on NPU, then run other OPs on CPU.

        Positional and remaining keyword arguments are forwarded to
        ``cls.HF_Model.from_pretrained``. The following arguments extend Hugging Face's
        from_pretrained method:

        :param load_in_low_bit: str value, options are ``'sym_int4'`` and ``'sym_int8'``
                                (default ``'sym_int4'``). Any other value is rejected through
                                ``invalidInputError``.
        :param pipeline_parallel_stages: int, default ``1``. When greater than 1 it must
                                equal ``torch.distributed.get_world_size()``.
        :return: a model instance with a ``save_low_bit`` method attached
        """
        # The model is always loaded on CPU, so a user-supplied `device_map`
        # (such as the CUDA-style `device_map={"":0}`) is dropped. It is removed
        # before the kwargs are snapshotted so that the fallback load below does
        # not pass `device_map` together with `low_cpu_mem_usage=False`, a
        # combination transformers' `from_pretrained` rejects.
        device_map = kwargs.pop('device_map', None)
        if device_map not in [None, 'cpu', 'auto']:
            warnings.warn("`device_map` will be ignored")

        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float, torch.float16]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
            kwargs['torch_dtype'] = torch.float

        low_bit = kwargs.pop('load_in_low_bit', 'sym_int4')
        qtype_map = {
            'sym_int4': "sym_int4_rtn",
            'sym_int8': "sym_int8_rtn",
        }

        invalidInputError(low_bit in qtype_map.keys(),
                          f"unsupported low_bit: {low_bit}, "
                          f"only {list(qtype_map.keys())} are supported")
        qtype = qtype_map[low_bit]

        kwargs["low_cpu_mem_usage"] = True

        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "load_in_4bit")
        ignore_argument(kwargs, "load_in_8bit")
        ignore_argument(kwargs, "imatrix")
        ignore_argument(kwargs, "mixed_precision")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "optimize_model")
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "quantization_config")
        ignore_argument(kwargs, "speculative")
        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)

        # Snapshot the arguments so the fallback load starts from untouched copies.
        _args = copy.deepcopy(args)
        _kwargs = copy.deepcopy(kwargs)

        try:
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
        except NotImplementedError:
            logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
                        "will fall to traditional load method with higher memory consumption.")
            _kwargs["low_cpu_mem_usage"] = False
            model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
            # Recorded in the config so that load_low_bit also skips the
            # low-cpu-memory path for this model.
            model.config.update({"bigdl_lcmu_enabled": False})

        logger.info("Converting model, it may takes up to several minutes ...")

        if pipeline_parallel_stages > 1:
            invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
                              "Please make sure world size is same as `pipeline_parallel_stages`")
            kwargs['torch_dtype'] = torch.float16
            from .npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
            model = pipeline_parallel(model, pipeline_parallel_stages,
                                      kwargs["torch_dtype"], device="cpu")

            # add pipeline_parallel_generate to pretrained model dynamically
            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
                                                                model)

        from intel_npu_acceleration_library.compiler import create_npu_kernels
        with torch.no_grad():
            optimize_llm(model)
            if pipeline_parallel_stages == 1:
                cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
                create_npu_kernels(model)
            else:
                # In pipeline mode only the inner `model.model` is converted.
                cls.load_convert(qtype, model.model, 'cpu', *args, **kwargs)
                create_npu_kernels(model.model)
                # NOTE(review): placed in the pipeline-parallel branch on the
                # assumption that it completes the parts `model.model` leaves
                # out - confirm against optimize_llm_post.
                optimize_llm_post(model)
        model = model.eval()

        logger.info("Finish to convert model")

        model.config.update({"bigdl_transformers_low_bit": qtype})

        # add save_low_bit to pretrained model dynamically
        model.save_low_bit = types.MethodType(save_low_bit, model)

        return model

    @classmethod
    def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg):
        """Replace the linear layers of ``optimize_model`` with quantized ones.

        :param q_k: quantization type name passed to ``replace_with_QuantizedLinear``
        :param optimize_model: the model to convert
        :param device: device passed to ``replace_with_QuantizedLinear``

        Extra positional and keyword arguments are accepted and ignored.
        """
        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
        replace_with_QuantizedLinear(optimize_model, q_k, device=device)

    @classmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """
        Load a low-bit model that was previously written by ``save_low_bit``.

        The saved config must carry ``bigdl_transformers_low_bit`` set to
        ``"sym_int4_rtn"`` or ``"sym_int8_rtn"``. For a non-sharded checkpoint the
        list of weight names is read from ``load_keys.json`` inside
        ``pretrained_model_name_or_path``.

        :param pretrained_model_name_or_path: location of the saved low-bit model
        :param model_args: extra positional arguments for the model constructor
        :param kwargs: forwarded to ``AutoConfig.from_pretrained``; what it does not
                       consume goes to the model constructor. ``torch_dtype`` is
                       discarded (with a warning unless it is ``None``, ``'auto'`` or
                       ``torch.float``).
        :return: the loaded model, in eval mode, with gradients disabled on all parameters
        """
        if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")

        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "optimize_model")
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "speculative")
        ignore_argument(kwargs, "pipeline_parallel_stages")

        # `os` is deliberately NOT imported here: a function-local `import os`
        # would make the name local to this whole function and break the
        # `os.path.isdir` call in the remote-code branch below.
        import json
        from transformers.models.auto.configuration_auto import AutoConfig
        from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
        from transformers.dynamic_module_utils import resolve_trust_remote_code, \
            get_class_from_dynamic_module
        from transformers.models.auto.auto_factory import _get_model_class
        from transformers.utils.generic import ContextManagers
        from transformers.generation.configuration_utils import GenerationConfig
        from ipex_llm.transformers.utils import extract_local_archive_file, \
            get_local_shard_files, load_state_dict
        from accelerate.big_modeling import init_empty_weights

        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs_orig = copy.deepcopy(kwargs)

        config, kwargs = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            return_unused_kwargs=True,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        # if torch_dtype=auto was passed here, ensure to pass it on
        if kwargs_orig.get("torch_dtype", None) == "auto":
            kwargs["torch_dtype"] = "auto"

        # Maybe needed when extract_local_archive_file
        subfolder = kwargs.get("subfolder", "")
        variant = kwargs.get("variant", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        torch_dtype = kwargs.pop("torch_dtype", "auto")
        sharded_metadata = None

        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
        qtype = config_dict.pop("bigdl_transformers_low_bit", False)
        bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)

        invalidInputError(qtype,
                          "Detect this model is not a low-bit model, Please use from_pretrained"
                          " with load_in_4bit or load_in_low_bit to get a low-bit model , and "
                          " serialize the model using save_low_bit first.")

        invalidInputError(qtype in ["sym_int8_rtn", "sym_int4_rtn"],
                          f"Unknown bigdl_transformers_low_bit value: {qtype},"
                          f" expected: sym_int4_rtn or sym_int8_rtn.")

        has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
        has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )
        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.HF_Model.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, **kwargs
            )
            if os.path.isdir(pretrained_model_name_or_path):
                model_class.register_for_auto_class(cls.HF_Model.__name__)
            else:
                cls.HF_Model.register(config.__class__, model_class, exist_ok=True)
        elif has_local_code:
            model_class = _get_model_class(config, cls.HF_Model._model_mapping)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `model_class` further down.
            invalidInputError(False,
                              f"Unrecognized configuration class {config.__class__} "
                              f"for {cls.HF_Model.__name__}.")

        resolved_archive_file, is_sharded = extract_local_archive_file(
            pretrained_model_name_or_path,
            subfolder,
            variant)

        if is_sharded:
            resolved_archive_file, sharded_metadata = \
                get_local_shard_files(pretrained_model_name_or_path,
                                      resolved_archive_file,
                                      subfolder=subfolder)

        # set dtype to instantiate the model under:
        # 1. If torch_dtype is not None, we use that dtype
        # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict,
        #    by checking its first weights entry that is of a floating type
        #    - we assume all floating dtype weights are of the same dtype
        # we also may have config.torch_dtype available, but we won't rely on it till v5
        dtype_orig = None

        if torch_dtype is not None:
            if isinstance(torch_dtype, str):
                if torch_dtype == "auto":
                    if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
                        torch_dtype = config.torch_dtype
                    else:
                        if is_sharded and "dtype" in sharded_metadata:
                            torch_dtype = sharded_metadata["dtype"]
                        else:
                            one_state_dict = load_state_dict(resolved_archive_file[0])
                            torch_dtype = get_state_dict_dtype(one_state_dict)
                            del one_state_dict  # free CPU memory
                else:
                    invalidInputError(False,
                                      f'`torch_dtype` can be either `torch.dtype` or `"auto"`, '
                                      f'but received {torch_dtype}')
            dtype_orig = model_class._set_default_torch_dtype(torch_dtype)

        # Pretrained Model
        _fast_init = kwargs.pop("_fast_init", True)
        init_contexts = [no_init_weights(_enable=_fast_init)]
        init_contexts.append(init_empty_weights())

        if bigdl_lcmu_enabled:
            with ContextManagers(init_contexts):
                if config.architectures is not None and config.architectures[0] in \
                        ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
                    # ChatGLMModel uses skip_init by default, which will force modules
                    # placed on cpu if the device is not specified. This will further
                    # cause replaced linear allocating memory on cpu.
                    kwargs["device"] = "meta"
                model = model_class(config, *model_args, **kwargs)
        else:
            model = model_class(config, *model_args, **kwargs)

        # Loading args may differ based on their usage
        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
        logger.info("Converting model, it may takes up to several minutes ...")
        from intel_npu_acceleration_library.compiler import create_npu_kernels
        with torch.no_grad():
            optimize_llm(model)
            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
            create_npu_kernels(model)
        model = model.eval()

        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        else:
            # Key list written by save_low_bit next to the checkpoint.
            with open(os.path.join(pretrained_model_name_or_path,
                                   "load_keys.json"), "r") as json_file:
                loaded_data = json.load(json_file)
            loaded_state_dict_keys = loaded_data["all_checkpoint_keys"]

        # restore default dtype
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            offload_index,
            error_msgs,
        ) = model_class._load_pretrained_model(
            model,
            None,
            loaded_state_dict_keys,  # XXX: rename?
            resolved_archive_file,
            pretrained_model_name_or_path,
            sharded_metadata=sharded_metadata,
            _fast_init=False,  # always false to avoid pre-init behaviors
            low_cpu_mem_usage=bigdl_lcmu_enabled,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            dtype=torch_dtype,
            keep_in_fp32_modules=[],
        )

        # make sure token embedding weights are still tied if needed
        model.tie_weights()

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        # If it is a model with generation capabilities, attempt to load the generation config
        if model.can_generate():
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    subfolder=subfolder,
                    **kwargs,
                )
            except (OSError, TypeError):
                # Best effort: the model keeps whatever generation config it already has.
                pass

        for param in model.parameters():
            param.requires_grad_(False)

        return model
class AutoModelForCausalLM(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForCausalLM``."""

    HF_Model = transformers.AutoModelForCausalLM
class AutoModel(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModel``."""

    HF_Model = transformers.AutoModel
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForSpeechSeq2Seq``."""

    HF_Model = transformers.AutoModelForSpeechSeq2Seq
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForSeq2SeqLM``."""

    HF_Model = transformers.AutoModelForSeq2SeqLM
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForSequenceClassification``."""

    HF_Model = transformers.AutoModelForSequenceClassification
class AutoModelForMaskedLM(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForMaskedLM``."""

    HF_Model = transformers.AutoModelForMaskedLM
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForQuestionAnswering``."""

    HF_Model = transformers.AutoModelForQuestionAnswering
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForNextSentencePrediction``."""

    HF_Model = transformers.AutoModelForNextSentencePrediction
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForMultipleChoice``."""

    HF_Model = transformers.AutoModelForMultipleChoice
class AutoModelForTokenClassification(_BaseAutoModelClass):
    """NPU loader that delegates to ``transformers.AutoModelForTokenClassification``."""

    HF_Model = transformers.AutoModelForTokenClassification