Add initial save/load low bit support for NPU(now only fp16 is supported) (#11359)
This commit is contained in:
parent
ed4c439497
commit
a5e7d93242
1 changed files with 36 additions and 0 deletions
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import types
|
||||
import warnings
|
||||
import torch
|
||||
import transformers
|
||||
|
|
@ -24,6 +26,7 @@ from transformers.dynamic_module_utils import get_imports
|
|||
import intel_npu_acceleration_library as npu_lib
|
||||
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
from ipex_llm.transformers.utils import logger
|
||||
|
||||
|
||||
def patch_flash_attn_import(filename: str) -> List[str]:
|
||||
|
|
@ -107,10 +110,43 @@ class _BaseAutoModelClass:
|
|||
ignore_argument(kwargs, "pipeline_parallel_stages")
|
||||
|
||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||
|
||||
logger.info(f"Converting model, it may takes up to several minutes ...")
|
||||
model = npu_lib.compile(model, qtype, False)
|
||||
|
||||
# add save_low_bit to pretrained model dynamically
|
||||
model.save_low_bit = types.MethodType(cls.save_low_bit, model)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def save_low_bit(self, model_dir: str, *args, **kwargs):
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
model_name = "pytorch_npu_model.pt"
|
||||
model_path = os.path.join(model_dir, model_name)
|
||||
del self.save_low_bit # workaround a bug
|
||||
torch.save(self, model_path)
|
||||
|
||||
@staticmethod
|
||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||
def load_low_bit(model_dir: str, *args, **kwargs):
|
||||
if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
|
||||
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
|
||||
|
||||
# ignore following arguments
|
||||
ignore_argument(kwargs, "model_hub")
|
||||
ignore_argument(kwargs, "lightweight_bmm")
|
||||
ignore_argument(kwargs, "cpu_embedding")
|
||||
ignore_argument(kwargs, "embedding_qtype")
|
||||
ignore_argument(kwargs, "optimize_model")
|
||||
ignore_argument(kwargs, "modules_to_not_convert")
|
||||
ignore_argument(kwargs, "speculative")
|
||||
ignore_argument(kwargs, "pipeline_parallel_stages")
|
||||
|
||||
model_name = "pytorch_npu_model.pt"
|
||||
model_path = os.path.join(model_dir, model_name)
|
||||
return torch.load(model_path)
|
||||
|
||||
|
||||
class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
HF_Model = transformers.AutoModelForCausalLM
|
||||
|
|
|
|||
Loading…
Reference in a new issue