Add initial save/load low bit support for NPU (now only fp16 is supported) (#11359)
parent ed4c439497
commit a5e7d93242
1 changed file with 36 additions and 0 deletions
@@ -14,6 +14,8 @@
# limitations under the License.
#

import os
import types
import warnings
import torch
import transformers
@@ -24,6 +26,7 @@ from transformers.dynamic_module_utils import get_imports
import intel_npu_acceleration_library as npu_lib

from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger


def patch_flash_attn_import(filename: str) -> List[str]:
@@ -107,10 +110,43 @@ class _BaseAutoModelClass:
        ignore_argument(kwargs, "pipeline_parallel_stages")

        model = cls.HF_Model.from_pretrained(*args, **kwargs)

        logger.info(f"Converting model, it may takes up to several minutes ...")
        model = npu_lib.compile(model, qtype, False)

        # add save_low_bit to pretrained model dynamically
        model.save_low_bit = types.MethodType(cls.save_low_bit, model)

        return model

    @staticmethod
    def save_low_bit(self, model_dir: str, *args, **kwargs):
        os.makedirs(model_dir, exist_ok=True)
        model_name = "pytorch_npu_model.pt"
        model_path = os.path.join(model_dir, model_name)
        del self.save_low_bit   # workaround a bug
        torch.save(self, model_path)

    @staticmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def load_low_bit(model_dir: str, *args, **kwargs):
        if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")

        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "optimize_model")
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "speculative")
        ignore_argument(kwargs, "pipeline_parallel_stages")

        model_name = "pytorch_npu_model.pt"
        model_path = os.path.join(model_dir, model_name)
        return torch.load(model_path)


class AutoModelForCausalLM(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForCausalLM
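
A note on the `types.MethodType` line above: it binds a plain function to one specific object, so the model returned by `from_pretrained` gains an instance-level `save_low_bit` method without modifying the underlying transformers class. A minimal, self-contained sketch of the same binding pattern (the `Greeter`/`greet` names are illustrative, not part of ipex-llm):

import types


class Greeter:
    pass


def greet(self, name):
    # `self` here is the instance the function was bound to
    return f"hello, {name}"


g = Greeter()
# attach `greet` to this single instance, mirroring how from_pretrained
# attaches save_low_bit to the returned model object
g.greet = types.MethodType(greet, g)
print(g.greet("npu"))    # -> hello, npu

Because the bound method lives in the instance `__dict__`, it would travel with the object into `torch.save`, which is presumably why `save_low_bit` deletes itself before pickling; the diff itself only labels that line as a bug workaround.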
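
Taken together, the new methods give a save-once / load-fast round trip for NPU models (fp16 only at this stage). A usage sketch, assuming these classes are exposed as `ipex_llm.transformers.npu_model.AutoModelForCausalLM`; the import path and model id below are illustrative, not shown in this diff:

from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

save_dir = "./npu_model_fp16"

# first run: download, convert and compile the model for the NPU, then
# persist it with the dynamically attached save_low_bit
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # illustrative id
model.save_low_bit(save_dir)   # writes pytorch_npu_model.pt into save_dir

# later runs: skip the conversion and restore the compiled model directly
model = AutoModelForCausalLM.load_low_bit(save_dir)

Note that `save_low_bit` deletes itself from the model as part of the workaround above, so it can only be called once on a given instance returned by `from_pretrained`.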