[LLM] Unify Transformers and Native API (#8713)

* re-open PR to run on the latest runner

* re-add examples and UT

* rename UT and change the deprecation from raising an error to logging a warning

* UT fix
SONG Ge 2023-08-11 19:45:47 +08:00 committed by GitHub
parent 1cb8f5abbd
commit aceea4dc29
5 changed files with 133 additions and 8 deletions
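
In short, each supported model family now gets its own CausalLM class whose from_pretrained loads either a converted native ggml int4 checkpoint or a huggingface checkpoint, switched by the native flag. A minimal sketch of the unified call; the checkpoint path and repo id below are placeholders, and load_in_4bit mirrors the ChatGLM UT added in this commit:

from bigdl.llm.transformers import LlamaForCausalLM

# Native (ggml) int4 path: expects a checkpoint already converted by
# bigdl.llm.llm_convert ("./bigdl_llm_llama_q4_0.bin" is a placeholder path).
llm = LlamaForCausalLM.from_pretrained("./bigdl_llm_llama_q4_0.bin",
                                       native=True, dtype="int4", n_threads=2)
print(llm("What is the capital of France?", max_tokens=32))

# Transformers int4 path: expects a huggingface repo id or checkpoint folder
# ("meta-llama/Llama-2-7b-hf" is a placeholder repo id).
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                         native=False, load_in_4bit=True)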


@@ -16,6 +16,7 @@
import time
import argparse
from bigdl.llm.transformers import *
def convert(repo_id_or_model_path, model_family, tmp_path):
@@ -31,17 +32,30 @@ def convert(repo_id_or_model_path, model_family, tmp_path):
return bigdl_llm_path
def load(model_path, model_family, n_threads):
from bigdl.llm.transformers import BigdlNativeForCausalLM
llm = BigdlNativeForCausalLM.from_pretrained(
model_family_to_class = {
"llama": LlamaForCausalLM,
"gptneox": GptneoxForCausalLM,
"bloom": BloomForCausalLM,
"starcoder": StarcoderForCausalLM,
"chatglm": ChatGLMForCausalLM
}
if model_family in model_family_to_class:
llm_causal = model_family_to_class[model_family]
else:
raise ValueError(f"Unknown model family: {model_family}")
llm = llm_causal.from_pretrained(
pretrained_model_name_or_path=model_path,
model_family=model_family,
native=True,
dtype="int4",
n_threads=n_threads)
return llm
def inference(llm, repo_id_or_model_path, model_family, prompt):
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder']:
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm']:
# ------ Option 1: Use bigdl-llm based tokenizer
print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
st = time.time()
@@ -95,9 +109,9 @@ def main():
parser.add_argument('--thread-num', type=int, default=2, required=True,
help='Number of threads to use for inference')
parser.add_argument('--model-family', type=str, default='llama', required=True,
choices=["llama", "llama2", "bloom", "gptneox", "starcoder"],
choices=["llama", "llama2", "bloom", "gptneox", "starcoder", "chatglm"],
help="The model family of the large language model (supported option: 'llama', 'llama2', "
"'gptneox', 'bloom', 'starcoder')")
"'gptneox', 'bloom', 'starcoder', 'chatglm')")
parser.add_argument('--repo-id-or-model-path', type=str, required=True,
help='The path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default='Once upon a time, there existed a little girl who liked to have adventures. ',
@@ -117,7 +131,6 @@ def main():
model_family=args.model_family,
tmp_path=args.tmp_path)
# Step 2: load int4 model
llm = load(model_path=bigdl_llm_path,
model_family=args.model_family,


@@ -16,4 +16,4 @@
from .convert import ggml_convert_quant
from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
from .modelling_bigdl import BigdlNativeForCausalLM
from .modelling_bigdl import *
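
With the wildcard import, the per-family classes defined in modelling_bigdl are re-exported at package level, which the wildcard import in the updated example above relies on; a brief sketch of the names this makes available:

# Importable directly from the package after this change.
from bigdl.llm.transformers import (
    LlamaForCausalLM,
    GptneoxForCausalLM,
    BloomForCausalLM,
    StarcoderForCausalLM,
    ChatGLMForCausalLM,
)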


@@ -19,7 +19,9 @@
# Otherwise there would be module not found error in non-pip's setting as Python would
# only search the first bigdl package and end up finding only one sub-package.
import logging
from bigdl.llm.utils.common import invalidInputError
from .model import *
class BigdlNativeForCausalLM:
@@ -52,6 +54,8 @@ class BigdlNativeForCausalLM:
:return: a model instance
"""
logging.warning("BigdlNativeForCausalLM has been deprecated, "
"please switch to the new CausalLM API for sepcific models.")
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
" 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
@@ -75,3 +79,70 @@ class BigdlNativeForCausalLM:
elif model_family == 'chatglm':
from bigdl.llm.ggml.model.chatglm import ChatGLM
return ChatGLM(model_path=ggml_model_path, **kwargs)
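# Migration note (illustrative; the ggml checkpoint path below is a placeholder):
# the deprecated entry point above now only logs a warning, and the equivalent
# call with the new per-family API drops the model_family argument.
#
# Before (deprecated):
#     llm = BigdlNativeForCausalLM.from_pretrained(
#         pretrained_model_name_or_path="./bigdl_llm_chatglm_q4_0.bin",
#         model_family="chatglm", n_threads=2)
#
# After:
#     from bigdl.llm.transformers import ChatGLMForCausalLM
#     llm = ChatGLMForCausalLM.from_pretrained("./bigdl_llm_chatglm_q4_0.bin",
#                                              native=True, dtype="int4", n_threads=2)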
class _BaseGGMLClass:
GGML_Model = None
HF_Class = None
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
native: bool = True,
dtype: str = "int4",
*args,
**kwargs):
"""
:param pretrained_model_name_or_path: Path to the model checkpoint.
If running with ``native int4``, the path should point to a BigDL-LLM optimized
ggml binary checkpoint converted by ``bigdl.llm.llm_convert``.
If running with ``transformers int4``, the path should be the huggingface repo id
to be downloaded or the huggingface checkpoint folder.
:param native: Whether to load the model as a native (ggml) int4 model (``True``)
or as a BigDL-LLM optimized Transformers int4 model (``False``).
:param dtype: The quantized precision to load with.
Now only `int4` and `int8` are supported, and `int8` only works for `llama`,
`gptneox` and `starcoder`.
:param kwargs: keyword arguments which will be passed to the model instance.
:return: a model instance
"""
if native:
invalidInputError(dtype.lower() in ['int4', 'int8'],
"Now we only support int4 and int8 as date type for weight")
ggml_model_path = pretrained_model_name_or_path
return cls.GGML_Model(model_path=ggml_model_path,
**kwargs)
else:
return cls.HF_Class.from_pretrained(pretrained_model_name_or_path,
*args, **kwargs)
class LlamaForCausalLM(_BaseGGMLClass):
from bigdl.llm.ggml.model.llama import Llama
GGML_Model = Llama
HF_Class = AutoModelForCausalLM
class ChatGLMForCausalLM(_BaseGGMLClass):
from bigdl.llm.ggml.model.chatglm import ChatGLM
GGML_Model = ChatGLM
HF_Class = AutoModel
class GptneoxForCausalLM(_BaseGGMLClass):
from bigdl.llm.ggml.model.gptneox import Gptneox
GGML_Model = Gptneox
HF_Class = AutoModelForCausalLM
class BloomForCausalLM(_BaseGGMLClass):
from bigdl.llm.ggml.model.bloom import Bloom
GGML_Model = Bloom
HF_Class = AutoModelForCausalLM
class StarcoderForCausalLM(_BaseGGMLClass):
from bigdl.llm.ggml.model.starcoder import Starcoder
GGML_Model = Starcoder
HF_Class = AutoModelForCausalLM
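
As the docstring notes, native int8 is only available for llama, gptneox and starcoder; a minimal usage sketch under that constraint, assuming the placeholder path already points to an int8 ggml checkpoint produced by bigdl.llm.llm_convert:

from bigdl.llm.transformers import StarcoderForCausalLM

# dtype is only validated here; the file itself must already be quantized to int8
# ("./bigdl_llm_starcoder_q8_0.bin" is a placeholder path).
llm = StarcoderForCausalLM.from_pretrained("./bigdl_llm_starcoder_q8_0.bin",
                                           native=True, dtype="int8", n_threads=2)
output = llm("def print_hello_world(", max_tokens=32)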


@@ -16,6 +16,8 @@
from bigdl.llm.models import Llama, Bloom, Gptneox, Starcoder
from bigdl.llm.transformers import LlamaForCausalLM, BloomForCausalLM, \
GptneoxForCausalLM, StarcoderForCausalLM
import pytest
from unittest import TestCase
import os
@@ -43,6 +45,11 @@ class Test_Models_Basics(TestCase):
llm = Llama(self.llama_model_path, n_threads=self.n_threads)
output = llm("What is the capital of France?", max_tokens=32, stream=True)
def test_llama_for_causallm(self):
llm = LlamaForCausalLM.from_pretrained(self.llama_model_path, native=True,
n_threads=self.n_threads)
output = llm("What is the capital of France?", max_tokens=32, stream=False)
def test_bloom_completion_success(self):
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
output = llm("What is the capital of France?", max_tokens=32, stream=False)
@@ -55,6 +62,11 @@ class Test_Models_Basics(TestCase):
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
output = llm("What is the capital of France?", max_tokens=32, stream=True)
def test_bloom_for_causallm(self):
llm = BloomForCausalLM.from_pretrained(self.bloom_model_path, native=True,
n_threads=self.n_threads)
output = llm("What is the capital of France?", max_tokens=32, stream=False)
def test_gptneox_completion_success(self):
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
@@ -63,6 +75,11 @@
def test_gptneox_completion_with_stream_success(self):
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=True)
def test_gptneox_for_causallm(self):
llm = GptneoxForCausalLM.from_pretrained(self.gptneox_model_path, native=True,
n_threads=self.n_threads)
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
def test_starcoder_completion_success(self):
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
@@ -73,6 +90,11 @@
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
output = llm("def print_hello_world(", max_tokens=32, stream=True)
def test_starcoder_for_causallm(self):
llm = StarcoderForCausalLM.from_pretrained(self.starcoder_model_path, native=True,
n_threads=self.n_threads)
output = llm("def print_hello_world(", max_tokens=32, stream=False)
if __name__ == '__main__':
pytest.main([__file__])


@@ -90,5 +90,24 @@ class TestTransformersAPI(unittest.TestCase):
res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
self.assertTrue(res)
def test_transformers_chatglm_for_causallm(self):
from bigdl.llm.transformers import ChatGLMForCausalLM
model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
model = ChatGLMForCausalLM.from_pretrained(model_path, native=False, trust_remote_code=True, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_str = "Tell me the capital of France.\n\n"
with torch.inference_mode():
st = time.time()
input_ids = tokenizer.encode(input_str, return_tensors="pt")
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
end = time.time()
print('Prompt:', input_str)
print('Output:', output_str)
print(f'Inference time: {end-st} s')
res = 'Paris' in output_str
self.assertTrue(res)
if __name__ == '__main__':
pytest.main([__file__])