Add initial save/load low bit support for NPU (now only fp16 is supported) (#11359)

Yishuo Wang 2024-06-20 10:49:39 +08:00 committed by GitHub
parent ed4c439497
commit a5e7d93242


@@ -14,6 +14,8 @@
# limitations under the License.
#
import os
import types
import warnings
import torch
import transformers
@@ -24,6 +26,7 @@ from transformers.dynamic_module_utils import get_imports
import intel_npu_acceleration_library as npu_lib
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger


def patch_flash_attn_import(filename: str) -> List[str]:
@@ -107,10 +110,43 @@ class _BaseAutoModelClass:
ignore_argument(kwargs, "pipeline_parallel_stages")
model = cls.HF_Model.from_pretrained(*args, **kwargs)
logger.info(f"Converting model, it may takes up to several minutes ...")
model = npu_lib.compile(model, qtype, False)
# add save_low_bit to pretrained model dynamically
model.save_low_bit = types.MethodType(cls.save_low_bit, model)
return model

    @staticmethod
    def save_low_bit(self, model_dir: str, *args, **kwargs):
        os.makedirs(model_dir, exist_ok=True)
        model_name = "pytorch_npu_model.pt"
        model_path = os.path.join(model_dir, model_name)
        # workaround: drop the dynamically bound method so torch.save can pickle the model
        del self.save_low_bit
        torch.save(self, model_path)

    @staticmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def load_low_bit(model_dir: str, *args, **kwargs):
        if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "optimize_model")
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "speculative")
        ignore_argument(kwargs, "pipeline_parallel_stages")
        model_name = "pytorch_npu_model.pt"
        model_path = os.path.join(model_dir, model_name)
        return torch.load(model_path)


class AutoModelForCausalLM(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForCausalLM
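
Below is a minimal usage sketch of the new API, not part of the commit. It assumes the patched class is exposed as ipex_llm.transformers.npu_model.AutoModelForCausalLM and that from_pretrained picks the fp16 qtype through a load_in_low_bit="fp16" keyword (that argument handling is not shown in this hunk); the model id and save directory are illustrative only.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Load the HF model and compile it for the NPU (only fp16 is supported at this stage).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
    load_in_low_bit="fp16",            # assumed qtype selector, see note above
    trust_remote_code=True,
)

# Persist the compiled model; the whole module is pickled into pytorch_npu_model.pt.
model.save_low_bit("./llama2-7b-npu-fp16")

# Later, restore the compiled model without repeating the conversion step.
model = AutoModelForCausalLM.load_low_bit("./llama2-7b-npu-fp16")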