add initial support for intel npu acceleration library (#11347)
This commit is contained in:
		
							parent
							
								
									694912698e
								
							
						
					
					
						commit
						83082e5cc7
					
				
					 2 changed files with 143 additions and 0 deletions
				
			
		| 
						 | 
					@ -27,7 +27,9 @@ import sys
 | 
				
			||||||
import types
 | 
					import types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Default is false, set to true to auto importing Intel Extension for PyTorch.
 | 
					# Default is false, set to true to auto importing Intel Extension for PyTorch.
 | 
				
			||||||
 | 
					USE_NPU = os.getenv("BIGDL_USE_NPU", 'False').lower() in ('true', '1', 't')
 | 
				
			||||||
BIGDL_IMPORT_IPEX = os.getenv("BIGDL_IMPORT_IPEX", 'True').lower() in ('true', '1', 't')
 | 
					BIGDL_IMPORT_IPEX = os.getenv("BIGDL_IMPORT_IPEX", 'True').lower() in ('true', '1', 't')
 | 
				
			||||||
 | 
					BIGDL_IMPORT_IPEX = not USE_NPU and BIGDL_IMPORT_IPEX
 | 
				
			||||||
if BIGDL_IMPORT_IPEX:
 | 
					if BIGDL_IMPORT_IPEX:
 | 
				
			||||||
    # Import Intel Extension for PyTorch as ipex if XPU version is installed
 | 
					    # Import Intel Extension for PyTorch as ipex if XPU version is installed
 | 
				
			||||||
    from .utils.ipex_importer import ipex_importer
 | 
					    from .utils.ipex_importer import ipex_importer
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										141
									
								
								python/llm/src/ipex_llm/transformers/npu_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								python/llm/src/ipex_llm/transformers/npu_model.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,141 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					from typing import List
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					from transformers.dynamic_module_utils import get_imports
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import intel_npu_acceleration_library as npu_lib
 | 
				
			||||||
 | 
					from intel_npu_acceleration_library.dtypes import int8, int4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ipex_llm.utils.common.log4Error import invalidInputError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def patch_flash_attn_import(filename: str) -> List[str]:
 | 
				
			||||||
 | 
					    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
 | 
				
			||||||
 | 
					    imports = get_imports(filename)
 | 
				
			||||||
 | 
					    if "flash_attn" in imports:
 | 
				
			||||||
 | 
					        imports.remove("flash_attn")
 | 
				
			||||||
 | 
					    return imports
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ignore_argument(kwargs: dict, key: 'str'):
 | 
				
			||||||
 | 
					    arg = kwargs.pop(key, None)
 | 
				
			||||||
 | 
					    if arg is not None:
 | 
				
			||||||
 | 
					        warnings.warn(f"argument `{key}={arg}` will be ignored")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _BaseAutoModelClass:
 | 
				
			||||||
 | 
					    HF_MODEL = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
 | 
				
			||||||
 | 
					    def from_pretrained(cls,
 | 
				
			||||||
 | 
					                        *args,
 | 
				
			||||||
 | 
					                        **kwargs):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Load a model from a directory or the HF Hub. Use load_in_low_bit parameter to convert
 | 
				
			||||||
 | 
					        model to low-bit format, like int4 and int8.
 | 
				
			||||||
 | 
					        The loaded model will run supported OPs on NPU, then run other OPs on CPU.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
 | 
				
			||||||
 | 
					        :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``, ``'fp32'``.
 | 
				
			||||||
 | 
					                                Relevant low bit optimizations will be applied to the model.
 | 
				
			||||||
 | 
					        :return: a model instance
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if kwargs.get('device_map', None) not in [None, 'cpu', 'auto']:
 | 
				
			||||||
 | 
					            warnings.warn("`device_map` will be ignored")
 | 
				
			||||||
 | 
					        kwargs['device_map'] = 'cpu'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        low_bit = kwargs.pop('load_in_low_bit', None)
 | 
				
			||||||
 | 
					        low_bit_to_dtype_map = {
 | 
				
			||||||
 | 
					            'sym_int4': int4,
 | 
				
			||||||
 | 
					            'sym_int8': int8,
 | 
				
			||||||
 | 
					            'fp32': torch.float,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if low_bit is not None:
 | 
				
			||||||
 | 
					            dtype = low_bit_to_dtype_map[low_bit]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            dtype = kwargs.get('torch_dtype', torch.float)
 | 
				
			||||||
 | 
					            dtype = torch.float if dtype == 'auto' else dtype
 | 
				
			||||||
 | 
					        invalidInputError(dtype in low_bit_to_dtype_map.values(),
 | 
				
			||||||
 | 
					                          f"unsupported dtype: {dtype}, "
 | 
				
			||||||
 | 
					                          "only `sym_int4`, `sym_int8`, `fp32` are supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        kwargs["low_cpu_mem_usage"] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # ignore following arguments
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "model_hub")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "lightweight_bmm")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "load_in_4bit")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "load_in_8bit")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "imatrix")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "mixed_precision")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "cpu_embedding")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "embedding_qtype")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "optimize_model")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "modules_to_not_convert")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "quantization_config")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "speculative")
 | 
				
			||||||
 | 
					        ignore_argument(kwargs, "pipeline_parallel_stages")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
				
			||||||
 | 
					        model = npu_lib.compile(model, dtype, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForCausalLM(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForCausalLM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModel(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForSpeechSeq2Seq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForSeq2SeqLM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForSequenceClassification(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForSequenceClassification
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForMaskedLM(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForMaskedLM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForQuestionAnswering(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForQuestionAnswering
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForNextSentencePrediction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForMultipleChoice(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForMultipleChoice
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AutoModelForTokenClassification(_BaseAutoModelClass):
 | 
				
			||||||
 | 
					    HF_Model = transformers.AutoModelForTokenClassification
 | 
				
			||||||
		Loading…
	
		Reference in a new issue