Add initial save/load low bit support for NPU (now only fp16 is supported) (#11359)
parent ed4c439497
commit a5e7d93242
1 changed file with 36 additions and 0 deletions
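The change adds two methods to the NPU _BaseAutoModelClass wrapper: save_low_bit, which writes the already-compiled model into pytorch_npu_model.pt inside a target directory via torch.save, and load_low_bit, which restores it with torch.load. A minimal usage sketch, assuming the AutoModelForCausalLM subclass from this diff is importable as shown (the import path, model id, and save directory below are placeholders, not part of this commit):

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

    # Load through the NPU wrapper (only fp16 is supported at this point, per the commit title),
    # then persist the compiled model so it does not need to be converted again.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    model.save_low_bit("./llama2-npu")    # writes ./llama2-npu/pytorch_npu_model.pt

    # Later: restore the compiled model directly from disk.
    model = AutoModelForCausalLM.load_low_bit("./llama2-npu")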
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import os
+import types
 import warnings
 import torch
 import transformers
@@ -24,6 +26,7 @@ from transformers.dynamic_module_utils import get_imports
 import intel_npu_acceleration_library as npu_lib
 
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.utils import logger
 
 
 def patch_flash_attn_import(filename: str) -> List[str]:
@@ -107,10 +110,43 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "pipeline_parallel_stages")
 
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
+        logger.info(f"Converting model, it may takes up to several minutes ...")
         model = npu_lib.compile(model, qtype, False)
+
+        # add save_low_bit to pretrained model dynamically
+        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
+
         return model
 
+    @staticmethod
+    def save_low_bit(self, model_dir: str, *args, **kwargs):
+        os.makedirs(model_dir, exist_ok=True)
+        model_name = "pytorch_npu_model.pt"
+        model_path = os.path.join(model_dir, model_name)
+        del self.save_low_bit # workaround a bug
+        torch.save(self, model_path)
+
+    @staticmethod
+    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+    def load_low_bit(model_dir: str, *args, **kwargs):
+        if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
+            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
+
+        # ignore following arguments
+        ignore_argument(kwargs, "model_hub")
+        ignore_argument(kwargs, "lightweight_bmm")
+        ignore_argument(kwargs, "cpu_embedding")
+        ignore_argument(kwargs, "embedding_qtype")
+        ignore_argument(kwargs, "optimize_model")
+        ignore_argument(kwargs, "modules_to_not_convert")
+        ignore_argument(kwargs, "speculative")
+        ignore_argument(kwargs, "pipeline_parallel_stages")
+
+        model_name = "pytorch_npu_model.pt"
+        model_path = os.path.join(model_dir, model_name)
+        return torch.load(model_path)
 
 
 class AutoModelForCausalLM(_BaseAutoModelClass):
     HF_Model = transformers.AutoModelForCausalLM
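One consequence worth noting: save_low_bit deletes itself from the instance (the in-line "workaround a bug" del) before calling torch.save, and load_low_bit simply returns the torch.load result, so a model returned by load_low_bit, as well as the in-memory model after it has been saved once, no longer has save_low_bit attached. A small sketch of re-binding it, using the same types.MethodType pattern as from_pretrained in this diff (the helper name and save directory below are hypothetical, not part of this commit):

    import types

    def reattach_save_low_bit(model, cls=AutoModelForCausalLM):
        # Hypothetical helper: re-bind the class's save_low_bit staticmethod to a
        # loaded instance, mirroring the types.MethodType call in from_pretrained.
        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
        return model

    model = AutoModelForCausalLM.load_low_bit("./llama2-npu")
    model = reattach_save_low_bit(model)  # model.save_low_bit(...) is available again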