[LLM] Add transformers-like API from_pretrained (#8271)
* Init commit for bigdl.llm.transformers.AutoModelForCausalLM
* Temp change to avoid name conflicts with the external transformers lib
* Support downloading model from huggingface
* Small python style fix
* Change location of transformers to avoid library conflicts
* Add return value for converted ggml binary ckpt path for convert_model
* Avoid repeated loading of shared library and add some comments
* Small fix
* Path type fix and docstring fix
* Small fix
* Small fix
* Change cache dir to pwd
parent 2ed5842448
commit 64bc123dd3

4 changed files with 137 additions and 4 deletions
@@ -39,6 +39,8 @@ def convert_model(input_path: str,
     :param dtype: Which quantized precision will be converted.
             Now only int4 supported.
     :param tmp_path: Which path to store the intermediate model during the conversion process.
+
+    :return: the path str to the converted lower precision checkpoint
     """
 
     dtype = dtype.lower()
@@ -54,7 +56,7 @@ def convert_model(input_path: str,
 
     tmp_ggml_file_path = next(Path(tmp_ggml_file_path).iterdir())
 
-    quantize(input_path=tmp_ggml_file_path,
+    return quantize(input_path=tmp_ggml_file_path,
                     output_path=output_path,
                     model_family=model_family,
                     dtype=dtype)
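With the return added above, convert_model now hands back the path of the converted checkpoint, so callers no longer have to reconstruct the output file name themselves. A minimal sketch of the updated call site (the input and output paths here are placeholders, not part of this diff):

from bigdl.llm.ggml import convert_model

# hypothetical paths for illustration only
ggml_ckpt_path = convert_model(input_path='/path/to/llama-7b-hf',
                               output_path='./converted',
                               model_family='llama',
                               dtype='int4')
print(ggml_ckpt_path)  # path str to the converted int4 ggml binary checkpoint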
@@ -60,6 +60,8 @@ def quantize(input_path: str, output_path: str=None,
             llama : "q4_0", "q4_1", "q4_2"
             bloom : "q4_0", "q4_1"
             gptneox : "q4_0", "q4_1", "q4_2", "q5_0", "q5_1", "q8_0"
+
+    :return: the path str to the converted ggml binary checkpoint
     """
     invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
                       "Now we only support quantization of model \
@@ -92,3 +94,4 @@ def quantize(input_path: str, output_path: str=None,
     p.communicate()
     invalidInputError(not p.returncode,
                       "Fail to quantize {}.".format(str(input_path)))
+    return str(output_path)
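Since quantize now returns str(output_path) as well, conversion and quantization can be chained. A sketch, assuming a llama-family fp16 ggml file at a placeholder path; note the import location of quantize is not shown in this diff and is assumed here:

# assumed import path; the diff only shows the function itself
from bigdl.llm.ggml.quantize import quantize

out_path = quantize(input_path='/path/to/ggml-model-f16.bin',    # placeholder
                    output_path='/path/to/ggml-model-q4_0.bin',  # placeholder
                    model_family='llama',
                    dtype='q4_0')
print(out_path)  # same as output_path, now returned instead of discarded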
python/llm/src/bigdl/llm/ggml/transformers/__init__.py (new file, 22 additions)

@@ -0,0 +1,22 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

from .model import AutoModelForCausalLM
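The only functional line here is the final re-export, which makes the new class importable from the sub-package directly:

from bigdl.llm.ggml.transformers import AutoModelForCausalLM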
python/llm/src/bigdl/llm/ggml/transformers/model.py (new file, 106 additions)

@@ -0,0 +1,106 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

import os
import traceback
from huggingface_hub import snapshot_download
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.ggml import convert_model


class AutoModelForCausalLM:
    """
    A generic model class that mimics the behavior of the
    ``transformers.AutoModelForCausalLM.from_pretrained`` API.
    """

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        model_family: str = 'llama',
                        dtype: str = 'int4',
                        cache_dir: str = './',
                        **kwargs):
        """
        :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoints

               1. Path to a huggingface checkpoint that is directly pulled from the huggingface
                  hub. This should be a dir path that contains: weight bin, tokenizer config,
                  tokenizer.model (required for llama) and added_tokens.json (if applicable).
                  For a lora fine-tuned model, the path should point to a merged weight.
               2. Path to a converted ggml binary checkpoint. The checkpoint should be
                  converted by ``bigdl.llm.ggml.convert_model``.
               3. A str for a huggingface hub repo id.

        :param model_family: the model family of the pretrained checkpoint.
               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
        :param dtype: (optional) the data type for weight. Currently we only support ``"int4"``.
        :param cache_dir: (optional) this parameter will only be used when
               ``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
               It indicates the saving path for the converted low precision model.
        :param **kwargs: keyword arguments which will be passed to the model instance.

        :return: a model instance
        """
        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
                          "Now we only support model family: 'llama', 'gptneox', 'bloom', "
                          "'{}' is not in the list.".format(model_family))
        invalidInputError(dtype.lower() == 'int4',
                          "Now we only support int4 as data type for weight")

        # check whether pretrained_model_name_or_path exists.
        # if not, it is likely that the user wants to pass in the repo id.
        if not os.path.exists(pretrained_model_name_or_path):
            try:
                # download from huggingface based on repo id
                pretrained_model_name_or_path = snapshot_download(
                    repo_id=pretrained_model_name_or_path)
            except Exception as e:
                traceback.print_exc()
                # if downloading fails, it could be the case that the repo id is invalid,
                # or the user passed in the wrong path for the checkpoint
                invalidInputError(False,
                                  "Downloading from huggingface repo id {} failed. "
                                  "Please input a valid huggingface hub repo id, "
                                  "or provide the valid path to a huggingface / "
                                  "ggml binary checkpoint, for pretrained_model_name_or_path"
                                  .format(pretrained_model_name_or_path))

        ggml_model_path = pretrained_model_name_or_path
        # check whether pretrained_model_name_or_path is a file.
        # if not, it is likely that pretrained_model_name_or_path
        # points to a huggingface checkpoint
        if not os.path.isfile(pretrained_model_name_or_path):
            # huggingface checkpoint
            ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
                                            output_path=cache_dir,
                                            model_family=model_family,
                                            dtype=dtype)

        if model_family == 'llama':
            from bigdl.llm.ggml.model.llama import Llama
            return Llama(model_path=ggml_model_path, **kwargs)
        elif model_family == 'gptneox':
            from bigdl.llm.ggml.model.gptneox import Gptneox
            return Gptneox(model_path=ggml_model_path, **kwargs)
        elif model_family == 'bloom':
            from bigdl.llm.ggml.model.bloom import Bloom
            return Bloom(model_path=ggml_model_path, **kwargs)
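Taken together, a usage sketch of the new API covering the three accepted forms of pretrained_model_name_or_path described in the docstring above; the repo id and local paths below are placeholders:

from bigdl.llm.ggml.transformers import AutoModelForCausalLM

# 1. a huggingface hub repo id: downloaded via snapshot_download, then converted
model = AutoModelForCausalLM.from_pretrained('some-org/llama-7b-hf',  # placeholder repo id
                                             model_family='llama',
                                             dtype='int4',
                                             cache_dir='./')

# 2. a local huggingface checkpoint dir: converted to an int4 ggml binary under cache_dir
model = AutoModelForCausalLM.from_pretrained('/path/to/llama-7b-hf',
                                             model_family='llama')

# 3. an already-converted ggml binary file: loaded directly, no conversion step
model = AutoModelForCausalLM.from_pretrained('/path/to/ggml-model-q4_0.bin',
                                             model_family='llama')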