[LLM] Unify Transformers and Native API (#8713)
* re-open PR to run on the latest runner
* re-add examples and UT
* rename UT and demote the deprecation from raising an error to logging a warning
* UT fix
parent 1cb8f5abbd
commit aceea4dc29

5 changed files with 133 additions and 8 deletions
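The unified surface this PR introduces: each model family gets a *ForCausalLM class whose shared from_pretrained serves both backends, with native=True loading a converted ggml int4 checkpoint and native=False delegating to the matching Hugging Face auto class. A minimal usage sketch (checkpoint paths are hypothetical):

from bigdl.llm.transformers import LlamaForCausalLM

# Native (ggml) int4: the checkpoint must already have been converted
# by bigdl.llm.llm_convert (hypothetical file name).
llm = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path="./bigdl_llm_llama_q4_0.bin",
    native=True, dtype="int4", n_threads=2)

# Transformers int4: the path is a huggingface repo id or checkpoint
# folder; extra kwargs pass through to the underlying from_pretrained.
model = LlamaForCausalLM.from_pretrained(
    "path/to/hf/llama/checkpoint",  # hypothetical
    native=False, load_in_4bit=True)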
				
			
Example script (convert / load / inference pipeline):

@@ -16,6 +16,7 @@
 
 import time
 import argparse
+from bigdl.llm.transformers import *
 
 
 def convert(repo_id_or_model_path, model_family, tmp_path):
@@ -31,17 +32,30 @@ def convert(repo_id_or_model_path, model_family, tmp_path):
     return bigdl_llm_path
 
 def load(model_path, model_family, n_threads):
-    from bigdl.llm.transformers import BigdlNativeForCausalLM
-    llm = BigdlNativeForCausalLM.from_pretrained(
+    model_family_to_class = {
+        "llama": LlamaForCausalLM,
+        "gptneox": GptneoxForCausalLM,
+        "bloom": BloomForCausalLM,
+        "starcoder": StarcoderForCausalLM,
+        "chatglm": ChatGLMForCausalLM
+    }
+
+    if model_family in model_family_to_class:
+        llm_causal = model_family_to_class[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
+    llm = llm_causal.from_pretrained(
         pretrained_model_name_or_path=model_path,
-        model_family=model_family,
+        native=True,
+        dtype="int4",
         n_threads=n_threads)
 
     return llm
 
 def inference(llm, repo_id_or_model_path, model_family, prompt):
 
-    if model_family in ['llama', 'gptneox', 'bloom', 'starcoder']:
+    if model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm']:
         # ------ Option 1: Use bigdl-llm based tokenizer
         print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
         st = time.time()
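The dict dispatch above replaces string branching inside a single loader, and unknown families now fail fast. A quick sketch of calling the refactored load (checkpoint path hypothetical):

# Hypothetical converted-checkpoint path; model_family must be a key
# of model_family_to_class.
llm = load(model_path="./bigdl_llm_gptneox_q4_0.bin",
           model_family="gptneox", n_threads=2)

# Anything else raises immediately:
# load(model_path="...", model_family="falcon", n_threads=2)
# -> ValueError: Unknown model family: falcon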
@@ -95,9 +109,9 @@ def main():
     parser.add_argument('--thread-num', type=int, default=2, required=True,
                         help='Number of threads to use for inference')
     parser.add_argument('--model-family', type=str, default='llama', required=True,
-                        choices=["llama", "llama2", "bloom", "gptneox", "starcoder"],
+                        choices=["llama", "llama2", "bloom", "gptneox", "starcoder", "chatglm"],
                         help="The model family of the large language model (supported option: 'llama', 'llama2', "
-                             "'gptneox', 'bloom', 'starcoder')")
+                             "'gptneox', 'bloom', 'starcoder', 'chatglm')")
     parser.add_argument('--repo-id-or-model-path', type=str, required=True,
                         help='The path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default='Once upon a time, there existed a little girl who liked to have adventures. ',

@@ -117,7 +131,6 @@ def main():
                              model_family=args.model_family,
                              tmp_path=args.tmp_path)
 
-
     # Step 2: load int4 model
     llm = load(model_path=bigdl_llm_path,
                model_family=args.model_family,
The bigdl.llm.transformers package __init__:

@@ -16,4 +16,4 @@
 
 from .convert import ggml_convert_quant
 from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
-from .modelling_bigdl import BigdlNativeForCausalLM
+from .modelling_bigdl import *
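With modelling_bigdl now defining one class per family, the wildcard re-export makes them all importable from the package root. Assuming the module keeps these names public, the new import surface is:

from bigdl.llm.transformers import (
    BigdlNativeForCausalLM,   # kept for now, deprecated below
    LlamaForCausalLM, GptneoxForCausalLM, BloomForCausalLM,
    StarcoderForCausalLM, ChatGLMForCausalLM,
)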
modelling_bigdl.py (the module re-exported above):

@@ -19,7 +19,9 @@
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.
 
+import logging
 from bigdl.llm.utils.common import invalidInputError
+from .model import *
 
 
 class BigdlNativeForCausalLM:
@@ -52,6 +54,8 @@ class BigdlNativeForCausalLM:
 
         :return: a model instance
         """
+        logging.warning("BigdlNativeForCausalLM has been deprecated, "
+                        "please switch to the new CausalLM API for specific models.")
         invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
                           "Now we only support model family: 'llama', 'gptneox', 'bloom',"
                           " 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
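Because the old entry point now only logs instead of raising, existing callers keep working while they migrate. A before/after sketch (checkpoint path hypothetical):

from bigdl.llm.transformers import BigdlNativeForCausalLM, LlamaForCausalLM

# Before: the family travels as a string argument (now triggers the
# deprecation warning above).
llm = BigdlNativeForCausalLM.from_pretrained(
    pretrained_model_name_or_path="./bigdl_llm_llama_q4_0.bin",
    model_family="llama", n_threads=2)

# After: the family is encoded in the class itself.
llm = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path="./bigdl_llm_llama_q4_0.bin",
    native=True, dtype="int4", n_threads=2)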
@@ -75,3 +79,70 @@ class BigdlNativeForCausalLM:
         elif model_family == 'chatglm':
             from bigdl.llm.ggml.model.chatglm import ChatGLM
             return ChatGLM(model_path=ggml_model_path, **kwargs)
+
+
+class _BaseGGMLClass:
+
+    GGML_Model = None
+    HF_Class = None
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: str,
+                        native: bool = True,
+                        dtype: str = "int4",
+                        *args,
+                        **kwargs):
+        """
+        :param pretrained_model_name_or_path: Path to the model checkpoint.
+               If running with ``native int4``, the path should be a BigDL-LLM optimized
+               ggml binary checkpoint, produced by ``bigdl.llm.llm_convert``.
+               If running with ``transformers int4``, the path should be a huggingface repo id
+               to be downloaded or a huggingface checkpoint folder.
+        :param native: Load the model as either Native (ggml) int4 or BigDL-LLM optimized Transformers int4.
+        :param dtype: Quantized precision the weights will be converted to.
+               Now only ``int4`` and ``int8`` are supported, and ``int8`` only works for
+               ``llama``, ``gptneox`` and ``starcoder``.
+        :param kwargs: Keyword arguments which will be passed to the model instance.
+
+        :return: a model instance
+        """
+        if native:
+            invalidInputError(dtype.lower() in ['int4', 'int8'],
+                              "Now we only support int4 and int8 as data type for weight")
+            ggml_model_path = pretrained_model_name_or_path
+            return cls.GGML_Model(model_path=ggml_model_path,
+                                  **kwargs)
+        else:
+            return cls.HF_Class.from_pretrained(pretrained_model_name_or_path,
+                                                *args, **kwargs)
+
+
+class LlamaForCausalLM(_BaseGGMLClass):
+    from bigdl.llm.ggml.model.llama import Llama
+    GGML_Model = Llama
+    HF_Class = AutoModelForCausalLM
+
+
+class ChatGLMForCausalLM(_BaseGGMLClass):
+    from bigdl.llm.ggml.model.chatglm import ChatGLM
+    GGML_Model = ChatGLM
+    HF_Class = AutoModel
+
+
+class GptneoxForCausalLM(_BaseGGMLClass):
+    from bigdl.llm.ggml.model.gptneox import Gptneox
+    GGML_Model = Gptneox
+    HF_Class = AutoModelForCausalLM
+
+
+class BloomForCausalLM(_BaseGGMLClass):
+    from bigdl.llm.ggml.model.bloom import Bloom
+    GGML_Model = Bloom
+    HF_Class = AutoModelForCausalLM
+
+
+class StarcoderForCausalLM(_BaseGGMLClass):
+    from bigdl.llm.ggml.model.starcoder import Starcoder
+    GGML_Model = Starcoder
+    HF_Class = AutoModelForCausalLM
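Each subclass pins only GGML_Model and HF_Class, so native=False routes to the right Hugging Face auto class per family: AutoModel for ChatGLM, AutoModelForCausalLM for the rest. Mirroring the unit test added later in this commit (repo id shown for illustration):

from bigdl.llm.transformers import ChatGLMForCausalLM

# native=False delegates to HF_Class.from_pretrained (AutoModel for
# ChatGLM); remaining kwargs pass through unchanged.
model = ChatGLMForCausalLM.from_pretrained(
    "THUDM/chatglm2-6b",  # repo id for illustration
    native=False, trust_remote_code=True, load_in_4bit=True)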
Unit tests for the native models:

@@ -16,6 +16,8 @@
 
 
 from bigdl.llm.models import Llama, Bloom, Gptneox, Starcoder
+from bigdl.llm.transformers import LlamaForCausalLM, BloomForCausalLM, \
+    GptneoxForCausalLM, StarcoderForCausalLM
 import pytest
 from unittest import TestCase
 import os

@@ -43,6 +45,11 @@ class Test_Models_Basics(TestCase):
         llm = Llama(self.llama_model_path, n_threads=self.n_threads)
         output = llm("What is the capital of France?", max_tokens=32, stream=True)
 
+    def test_llama_for_causallm(self):
+        llm = LlamaForCausalLM.from_pretrained(self.llama_model_path, native=True,
+                                               n_threads=self.n_threads)
+        output = llm("What is the capital of France?", max_tokens=32, stream=False)
+
     def test_bloom_completion_success(self):
         llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
         output = llm("What is the capital of France?", max_tokens=32, stream=False)

@@ -55,6 +62,11 @@ class Test_Models_Basics(TestCase):
         llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
         output = llm("What is the capital of France?", max_tokens=32, stream=True)
 
+    def test_bloom_for_causallm(self):
+        llm = BloomForCausalLM.from_pretrained(self.bloom_model_path, native=True,
+                                               n_threads=self.n_threads)
+        output = llm("What is the capital of France?", max_tokens=32, stream=False)
+
     def test_gptneox_completion_success(self):
         llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
         output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)

@@ -64,6 +76,11 @@ class Test_Models_Basics(TestCase):
         llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
         output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=True)
 
+    def test_gptneox_for_causallm(self):
+        llm = GptneoxForCausalLM.from_pretrained(self.gptneox_model_path, native=True,
+                                                 n_threads=self.n_threads)
+        output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
+
     def test_starcoder_completion_success(self):
         llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
         output = llm("def print_hello_world(", max_tokens=32, stream=False)

@@ -73,6 +90,11 @@ class Test_Models_Basics(TestCase):
         llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
         output = llm("def print_hello_world(", max_tokens=32, stream=True)
 
+    def test_starcoder_for_causallm(self):
+        llm = StarcoderForCausalLM.from_pretrained(self.starcoder_model_path, native=True,
+                                                   n_threads=self.n_threads)
+        output = llm("def print_hello_world(", max_tokens=32, stream=False)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
Unit tests for the Transformers API:

@@ -90,5 +90,24 @@ class TestTransformersAPI(unittest.TestCase):
         res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
         self.assertTrue(res)
 
+    def test_transformers_chatglm_for_causallm(self):
+        from bigdl.llm.transformers import ChatGLMForCausalLM
+        model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
+        model = ChatGLMForCausalLM.from_pretrained(model_path, native=False, trust_remote_code=True, load_in_4bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_str = "Tell me the capital of France.\n\n"
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+        print('Prompt:', input_str)
+        print('Output:', output_str)
+        print(f'Inference time: {end-st} s')
+        res = 'Paris' in output_str
+        self.assertTrue(res)
+
 if __name__ == '__main__':
     pytest.main([__file__])