[LLM] Unify Transformers and Native API (#8713)
* re-open pr to run on latest runner * re-add examples and ut * rename ut and move deprecate to warning instead of raising an error info * ut fix
This commit is contained in:
parent
1cb8f5abbd
commit
aceea4dc29
5 changed files with 133 additions and 8 deletions
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import time
|
||||
import argparse
|
||||
from bigdl.llm.transformers import *
|
||||
|
||||
|
||||
def convert(repo_id_or_model_path, model_family, tmp_path):
|
||||
|
|
@ -31,17 +32,30 @@ def convert(repo_id_or_model_path, model_family, tmp_path):
|
|||
return bigdl_llm_path
|
||||
|
||||
def load(model_path, model_family, n_threads):
|
||||
from bigdl.llm.transformers import BigdlNativeForCausalLM
|
||||
llm = BigdlNativeForCausalLM.from_pretrained(
|
||||
model_family_to_class = {
|
||||
"llama": LlamaForCausalLM,
|
||||
"gptneox": GptneoxForCausalLM,
|
||||
"bloom": BloomForCausalLM,
|
||||
"starcoder": StarcoderForCausalLM,
|
||||
"chatglm": ChatGLMForCausalLM
|
||||
}
|
||||
|
||||
if model_family in model_family_to_class:
|
||||
llm_causal = model_family_to_class[model_family]
|
||||
else:
|
||||
raise ValueError(f"Unknown model family: {model_family}")
|
||||
|
||||
llm = llm_causal.from_pretrained(
|
||||
pretrained_model_name_or_path=model_path,
|
||||
model_family=model_family,
|
||||
native=True,
|
||||
dtype="int4",
|
||||
n_threads=n_threads)
|
||||
|
||||
return llm
|
||||
|
||||
def inference(llm, repo_id_or_model_path, model_family, prompt):
|
||||
|
||||
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder']:
|
||||
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm']:
|
||||
# ------ Option 1: Use bigdl-llm based tokenizer
|
||||
print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
|
||||
st = time.time()
|
||||
|
|
@ -95,9 +109,9 @@ def main():
|
|||
parser.add_argument('--thread-num', type=int, default=2, required=True,
|
||||
help='Number of threads to use for inference')
|
||||
parser.add_argument('--model-family', type=str, default='llama', required=True,
|
||||
choices=["llama", "llama2", "bloom", "gptneox", "starcoder"],
|
||||
choices=["llama", "llama2", "bloom", "gptneox", "starcoder", "chatglm"],
|
||||
help="The model family of the large language model (supported option: 'llama', 'llama2', "
|
||||
"'gptneox', 'bloom', 'starcoder')")
|
||||
"'gptneox', 'bloom', 'starcoder', 'chatglm')")
|
||||
parser.add_argument('--repo-id-or-model-path', type=str, required=True,
|
||||
help='The path to the huggingface checkpoint folder')
|
||||
parser.add_argument('--prompt', type=str, default='Once upon a time, there existed a little girl who liked to have adventures. ',
|
||||
|
|
@ -117,7 +131,6 @@ def main():
|
|||
model_family=args.model_family,
|
||||
tmp_path=args.tmp_path)
|
||||
|
||||
|
||||
# Step 2: load int4 model
|
||||
llm = load(model_path=bigdl_llm_path,
|
||||
model_family=args.model_family,
|
||||
|
|
|
|||
|
|
@ -16,4 +16,4 @@
|
|||
|
||||
from .convert import ggml_convert_quant
|
||||
from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
|
||||
from .modelling_bigdl import BigdlNativeForCausalLM
|
||||
from .modelling_bigdl import *
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@
|
|||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||
# only search the first bigdl package and end up finding only one sub-package.
|
||||
|
||||
import logging
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from .model import *
|
||||
|
||||
|
||||
class BigdlNativeForCausalLM:
|
||||
|
|
@ -52,6 +54,8 @@ class BigdlNativeForCausalLM:
|
|||
|
||||
:return: a model instance
|
||||
"""
|
||||
logging.warning("BigdlNativeForCausalLM has been deprecated, "
|
||||
"please switch to the new CausalLM API for sepcific models.")
|
||||
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
|
||||
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
|
||||
" 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
|
||||
|
|
@ -75,3 +79,70 @@ class BigdlNativeForCausalLM:
|
|||
elif model_family == 'chatglm':
|
||||
from bigdl.llm.ggml.model.chatglm import ChatGLM
|
||||
return ChatGLM(model_path=ggml_model_path, **kwargs)
|
||||
|
||||
|
||||
class _BaseGGMLClass:
|
||||
|
||||
GGML_Model = None
|
||||
HF_Class = None
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
native: bool = True,
|
||||
dtype: str = "int4",
|
||||
*args,
|
||||
**kwargs):
|
||||
"""
|
||||
:param pretrained_model_name_or_path: Path for model checkpoint.
|
||||
If running with ``native int4``, the path should be converted BigDL-LLM optimized
|
||||
ggml binary checkpoint, which should be converted by ``bigdl.llm.llm_convert``.
|
||||
If running with ``transformers int4``, the path should be the huggingface repo id
|
||||
to be downloaded or the huggingface checkpoint folder.
|
||||
:param native: Load model to either BigDL-LLM optimized Transformer or Native (ggml) int4.
|
||||
:param dtype: Which quantized precision will be converted.
|
||||
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
|
||||
, `gptneox` and `starcoder`.
|
||||
:param kwargs: keyword arguments which will be passed to the model instance.
|
||||
|
||||
:return: a model instance
|
||||
"""
|
||||
if native:
|
||||
invalidInputError(dtype.lower() in ['int4', 'int8'],
|
||||
"Now we only support int4 and int8 as date type for weight")
|
||||
ggml_model_path = pretrained_model_name_or_path
|
||||
return cls.GGML_Model(model_path=ggml_model_path,
|
||||
**kwargs)
|
||||
else:
|
||||
return cls.HF_Class.from_pretrained(pretrained_model_name_or_path,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
class LlamaForCausalLM(_BaseGGMLClass):
|
||||
from bigdl.llm.ggml.model.llama import Llama
|
||||
GGML_Model = Llama
|
||||
HF_Class = AutoModelForCausalLM
|
||||
|
||||
|
||||
class ChatGLMForCausalLM(_BaseGGMLClass):
|
||||
from bigdl.llm.ggml.model.chatglm import ChatGLM
|
||||
GGML_Model = ChatGLM
|
||||
HF_Class = AutoModel
|
||||
|
||||
|
||||
class GptneoxForCausalLM(_BaseGGMLClass):
|
||||
from bigdl.llm.ggml.model.gptneox import Gptneox
|
||||
GGML_Model = Gptneox
|
||||
HF_Class = AutoModelForCausalLM
|
||||
|
||||
|
||||
class BloomForCausalLM(_BaseGGMLClass):
|
||||
from bigdl.llm.ggml.model.bloom import Bloom
|
||||
GGML_Model = Bloom
|
||||
HF_Class = AutoModelForCausalLM
|
||||
|
||||
|
||||
class StarcoderForCausalLM(_BaseGGMLClass):
|
||||
from bigdl.llm.ggml.model.starcoder import Starcoder
|
||||
GGML_Model = Starcoder
|
||||
HF_Class = AutoModelForCausalLM
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
|
||||
from bigdl.llm.models import Llama, Bloom, Gptneox, Starcoder
|
||||
from bigdl.llm.transformers import LlamaForCausalLM, BloomForCausalLM, \
|
||||
GptneoxForCausalLM, StarcoderForCausalLM
|
||||
import pytest
|
||||
from unittest import TestCase
|
||||
import os
|
||||
|
|
@ -43,6 +45,11 @@ class Test_Models_Basics(TestCase):
|
|||
llm = Llama(self.llama_model_path, n_threads=self.n_threads)
|
||||
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
||||
|
||||
def test_llama_for_causallm(self):
|
||||
llm = LlamaForCausalLM.from_pretrained(self.llama_model_path, native=True,
|
||||
n_threads=self.n_threads)
|
||||
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||
|
||||
def test_bloom_completion_success(self):
|
||||
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
||||
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||
|
|
@ -55,6 +62,11 @@ class Test_Models_Basics(TestCase):
|
|||
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
||||
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
||||
|
||||
def test_bloom_for_causallm(self):
|
||||
llm = BloomForCausalLM.from_pretrained(self.bloom_model_path, native=True,
|
||||
n_threads=self.n_threads)
|
||||
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||
|
||||
def test_gptneox_completion_success(self):
|
||||
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
||||
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
|
||||
|
|
@ -63,6 +75,11 @@ class Test_Models_Basics(TestCase):
|
|||
def test_gptneox_completion_with_stream_success(self):
|
||||
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
||||
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=True)
|
||||
|
||||
def test_getneox_for_causallm(self):
|
||||
llm = GptneoxForCausalLM.from_pretrained(self.gptneox_model_path, native=True,
|
||||
n_threads=self.n_threads)
|
||||
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
|
||||
|
||||
def test_starcoder_completion_success(self):
|
||||
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
||||
|
|
@ -73,6 +90,11 @@ class Test_Models_Basics(TestCase):
|
|||
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
||||
output = llm("def print_hello_world(", max_tokens=32, stream=True)
|
||||
|
||||
def test_starcoder_for_causallm(self):
|
||||
llm = StarcoderForCausalLM.from_pretrained(self.starcoder_model_path, native=True,
|
||||
n_threads=self.n_threads)
|
||||
output = llm("def print_hello_world(", max_tokens=32, stream=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
|
|
@ -90,5 +90,24 @@ class TestTransformersAPI(unittest.TestCase):
|
|||
res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
|
||||
self.assertTrue(res)
|
||||
|
||||
def test_transformers_chatglm_for_causallm(self):
|
||||
from bigdl.llm.transformers import ChatGLMForCausalLM
|
||||
model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
|
||||
model = ChatGLMForCausalLM.from_pretrained(model_path, native=False, trust_remote_code=True, load_in_4bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
input_str = "Tell me the capital of France.\n\n"
|
||||
|
||||
with torch.inference_mode():
|
||||
st = time.time()
|
||||
input_ids = tokenizer.encode(input_str, return_tensors="pt")
|
||||
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
end = time.time()
|
||||
print('Prompt:', input_str)
|
||||
print('Output:', output_str)
|
||||
print(f'Inference time: {end-st} s')
|
||||
res = 'Paris' in output_str
|
||||
self.assertTrue(res)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
Loading…
Reference in a new issue