diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 78b83f88..771dff62 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import os
+import types
 import warnings
 import torch
 import transformers
@@ -24,6 +26,7 @@ from transformers.dynamic_module_utils import get_imports
 import intel_npu_acceleration_library as npu_lib
 
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.utils import logger
 
 
 def patch_flash_attn_import(filename: str) -> List[str]:
@@ -107,10 +110,43 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "pipeline_parallel_stages")
 
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
+
+        logger.info("Converting model, it may take up to several minutes ...")
         model = npu_lib.compile(model, qtype, False)
 
+        # add save_low_bit to the pretrained model dynamically
+        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
+
         return model
 
+    @staticmethod
+    def save_low_bit(self, model_dir: str, *args, **kwargs):
+        os.makedirs(model_dir, exist_ok=True)
+        model_name = "pytorch_npu_model.pt"
+        model_path = os.path.join(model_dir, model_name)
+        del self.save_low_bit  # the dynamically bound method cannot be pickled; drop it before saving
+        torch.save(self, model_path)
+
+    @staticmethod
+    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+    def load_low_bit(model_dir: str, *args, **kwargs):
+        if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
+            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
+
+        # ignore the following arguments
+        ignore_argument(kwargs, "model_hub")
+        ignore_argument(kwargs, "lightweight_bmm")
+        ignore_argument(kwargs, "cpu_embedding")
+        ignore_argument(kwargs, "embedding_qtype")
+        ignore_argument(kwargs, "optimize_model")
+        ignore_argument(kwargs, "modules_to_not_convert")
+        ignore_argument(kwargs, "speculative")
+        ignore_argument(kwargs, "pipeline_parallel_stages")
+
+        model_name = "pytorch_npu_model.pt"
+        model_path = os.path.join(model_dir, model_name)
+        return torch.load(model_path)
+
 
 class AutoModelForCausalLM(_BaseAutoModelClass):
     HF_Model = transformers.AutoModelForCausalLM
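
Usage sketch (reviewer note, not part of the diff): a minimal round trip of the
new save_low_bit/load_low_bit pair. The model id, save directory, and the
load_in_low_bit kwarg below are illustrative assumptions; only the
save_low_bit/load_low_bit calls themselves come from this change.

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # from_pretrained compiles the model for the NPU and attaches save_low_bit
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # hypothetical model id
        load_in_low_bit="sym_int4",        # assumed quantization kwarg
        trust_remote_code=True,
    )

    # writes the whole compiled model to ./llama2-npu/pytorch_npu_model.pt
    model.save_low_bit("./llama2-npu")

    # later: restore the compiled model directly, skipping re-conversion
    model = AutoModelForCausalLM.load_low_bit("./llama2-npu")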