add fp16 NPU Linear support and fix intel_npu_acceleration_library version 1.0 support (#11352)
parent c44b1942ed
commit ae7b662ed2

1 changed file with 28 additions and 17 deletions
@@ -22,7 +22,6 @@ from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
 
 import intel_npu_acceleration_library as npu_lib
-from intel_npu_acceleration_library.dtypes import int8, int4
 
 from ipex_llm.utils.common.log4Error import invalidInputError
 
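Note: the top-level int8/int4 import is removed here because it is exactly what breaks on intel_npu_acceleration_library 1.0.x; the same import reappears inside a try block in a later hunk, so that an ImportError selects the version-1.0 fallback mapping.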
@@ -55,7 +54,8 @@ class _BaseAutoModelClass:
         The loaded model will run supported OPs on NPU, then run other OPs on CPU.
 
         Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
-        :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``, ``'fp32'``.
+        :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
+                                ``'fp16'``, ``'fp32'``.
                                 Relevant low bit optimizations will be applied to the model.
         :return: a model instance
         """
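The docstring hunk above documents the new 'fp16' option. A minimal usage sketch, assuming the NPU AutoModel classes are re-exported from ipex_llm.transformers.npu_model (an import path this diff does not show):

    # Assumed import path; the diff only shows _BaseAutoModelClass itself.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # 'fp16' is the newly documented option alongside 'sym_int4',
    # 'sym_int8' and 'fp32'; supported OPs run on NPU, the rest on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder checkpoint
        load_in_low_bit="fp16",
    )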
@@ -63,20 +63,31 @@ class _BaseAutoModelClass:
             warnings.warn("`device_map` will be ignored")
         kwargs['device_map'] = 'cpu'
 
-        low_bit = kwargs.pop('load_in_low_bit', None)
-        low_bit_to_dtype_map = {
-            'sym_int4': int4,
-            'sym_int8': int8,
-            'fp32': torch.float,
-        }
-        if low_bit is not None:
-            dtype = low_bit_to_dtype_map[low_bit]
-        else:
-            dtype = kwargs.get('torch_dtype', torch.float)
-            dtype = torch.float if dtype == 'auto' else dtype
-        invalidInputError(dtype in low_bit_to_dtype_map.values(),
-                          f"unsupported dtype: {dtype}, "
-                          "only `sym_int4`, `sym_int8`, `fp32` are supported")
+        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float]:
+            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
+        kwargs['torch_dtype'] = torch.float
+
+        low_bit = kwargs.pop('load_in_low_bit', 'fp32')
+        try:
+            # for intel_npu_acceleration_library >= 1.1.0
+            from intel_npu_acceleration_library.dtypes import int8, int4
+            qtype_map = {
+                'sym_int4': int4,
+                'sym_int8': int8,
+                'fp16': torch.half,
+                'fp32': torch.float,
+            }
+        except ImportError as _e:
+            # for intel_npu_acceleration_library < 1.1.0
+            qtype_map = {
+                'sym_int8': torch.int8,
+                'fp16': torch.half,
+                'fp32': torch.float,
+            }
+        invalidInputError(low_bit in qtype_map.keys(),
+                          f"unsupported low_bit: {low_bit}, "
+                          f"only {list(qtype_map.keys())} are supported")
+        qtype = qtype_map[low_bit]
 
         kwargs["low_cpu_mem_usage"] = True
 
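The fix gates on a feature probe rather than a version string: intel_npu_acceleration_library >= 1.1.0 provides the dtypes module with dedicated int8/int4 types, while 1.0.x raises ImportError, which selects a fallback mapping that consequently offers no 'sym_int4'. A self-contained sketch of the same pattern:

    import torch

    try:
        # dtypes module exists in intel_npu_acceleration_library >= 1.1.0
        from intel_npu_acceleration_library.dtypes import int8, int4
        qtype_map = {'sym_int4': int4, 'sym_int8': int8,
                     'fp16': torch.half, 'fp32': torch.float}
    except ImportError:
        # 1.0.x fallback: no dtypes module, hence no int4 option;
        # sym_int8 degrades to plain torch.int8
        qtype_map = {'sym_int8': torch.int8,
                     'fp16': torch.half, 'fp32': torch.float}

    print(sorted(qtype_map))  # low-bit options this installation supports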
@@ -96,7 +107,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "pipeline_parallel_stages")
 
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
-        model = npu_lib.compile(model, dtype, False)
+        model = npu_lib.compile(model, qtype, False)
 
         return model
 
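End to end, the wrapper loads the checkpoint on CPU in torch.float and hands it to npu_lib.compile with the resolved qtype. A sketch of the equivalent flow; the checkpoint name is a placeholder, and reading compile's third positional argument as a training flag is an assumption (the diff only shows it passed as False):

    import torch
    import intel_npu_acceleration_library as npu_lib
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",                    # placeholder checkpoint
        torch_dtype=torch.float,   # the wrapper forces torch.float before compiling
        low_cpu_mem_usage=True,
    )
    # torch.half corresponds to load_in_low_bit='fp16' via qtype_map;
    # False is assumed to be the library's training flag.
    model = npu_lib.compile(model, torch.half, False)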