[LLM] Unify Transformers and Native API (#8713)
* re-open pr to run on latest runner * re-add examples and ut * rename ut and move deprecate to warning instead of raising an error info * ut fix
This commit is contained in:
parent
1cb8f5abbd
commit
aceea4dc29
5 changed files with 133 additions and 8 deletions
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
|
from bigdl.llm.transformers import *
|
||||||
|
|
||||||
|
|
||||||
def convert(repo_id_or_model_path, model_family, tmp_path):
|
def convert(repo_id_or_model_path, model_family, tmp_path):
|
||||||
|
|
@ -31,17 +32,30 @@ def convert(repo_id_or_model_path, model_family, tmp_path):
|
||||||
return bigdl_llm_path
|
return bigdl_llm_path
|
||||||
|
|
||||||
def load(model_path, model_family, n_threads):
|
def load(model_path, model_family, n_threads):
|
||||||
from bigdl.llm.transformers import BigdlNativeForCausalLM
|
model_family_to_class = {
|
||||||
llm = BigdlNativeForCausalLM.from_pretrained(
|
"llama": LlamaForCausalLM,
|
||||||
|
"gptneox": GptneoxForCausalLM,
|
||||||
|
"bloom": BloomForCausalLM,
|
||||||
|
"starcoder": StarcoderForCausalLM,
|
||||||
|
"chatglm": ChatGLMForCausalLM
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_family in model_family_to_class:
|
||||||
|
llm_causal = model_family_to_class[model_family]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model family: {model_family}")
|
||||||
|
|
||||||
|
llm = llm_causal.from_pretrained(
|
||||||
pretrained_model_name_or_path=model_path,
|
pretrained_model_name_or_path=model_path,
|
||||||
model_family=model_family,
|
native=True,
|
||||||
|
dtype="int4",
|
||||||
n_threads=n_threads)
|
n_threads=n_threads)
|
||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
def inference(llm, repo_id_or_model_path, model_family, prompt):
|
def inference(llm, repo_id_or_model_path, model_family, prompt):
|
||||||
|
|
||||||
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder']:
|
if model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm']:
|
||||||
# ------ Option 1: Use bigdl-llm based tokenizer
|
# ------ Option 1: Use bigdl-llm based tokenizer
|
||||||
print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
|
print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
|
||||||
st = time.time()
|
st = time.time()
|
||||||
|
|
@ -95,9 +109,9 @@ def main():
|
||||||
parser.add_argument('--thread-num', type=int, default=2, required=True,
|
parser.add_argument('--thread-num', type=int, default=2, required=True,
|
||||||
help='Number of threads to use for inference')
|
help='Number of threads to use for inference')
|
||||||
parser.add_argument('--model-family', type=str, default='llama', required=True,
|
parser.add_argument('--model-family', type=str, default='llama', required=True,
|
||||||
choices=["llama", "llama2", "bloom", "gptneox", "starcoder"],
|
choices=["llama", "llama2", "bloom", "gptneox", "starcoder", "chatglm"],
|
||||||
help="The model family of the large language model (supported option: 'llama', 'llama2', "
|
help="The model family of the large language model (supported option: 'llama', 'llama2', "
|
||||||
"'gptneox', 'bloom', 'starcoder')")
|
"'gptneox', 'bloom', 'starcoder', 'chatglm')")
|
||||||
parser.add_argument('--repo-id-or-model-path', type=str, required=True,
|
parser.add_argument('--repo-id-or-model-path', type=str, required=True,
|
||||||
help='The path to the huggingface checkpoint folder')
|
help='The path to the huggingface checkpoint folder')
|
||||||
parser.add_argument('--prompt', type=str, default='Once upon a time, there existed a little girl who liked to have adventures. ',
|
parser.add_argument('--prompt', type=str, default='Once upon a time, there existed a little girl who liked to have adventures. ',
|
||||||
|
|
@ -117,7 +131,6 @@ def main():
|
||||||
model_family=args.model_family,
|
model_family=args.model_family,
|
||||||
tmp_path=args.tmp_path)
|
tmp_path=args.tmp_path)
|
||||||
|
|
||||||
|
|
||||||
# Step 2: load int4 model
|
# Step 2: load int4 model
|
||||||
llm = load(model_path=bigdl_llm_path,
|
llm = load(model_path=bigdl_llm_path,
|
||||||
model_family=args.model_family,
|
model_family=args.model_family,
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,4 @@
|
||||||
|
|
||||||
from .convert import ggml_convert_quant
|
from .convert import ggml_convert_quant
|
||||||
from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
|
from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
|
||||||
from .modelling_bigdl import BigdlNativeForCausalLM
|
from .modelling_bigdl import *
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,9 @@
|
||||||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||||
# only search the first bigdl package and end up finding only one sub-package.
|
# only search the first bigdl package and end up finding only one sub-package.
|
||||||
|
|
||||||
|
import logging
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
from .model import *
|
||||||
|
|
||||||
|
|
||||||
class BigdlNativeForCausalLM:
|
class BigdlNativeForCausalLM:
|
||||||
|
|
@ -52,6 +54,8 @@ class BigdlNativeForCausalLM:
|
||||||
|
|
||||||
:return: a model instance
|
:return: a model instance
|
||||||
"""
|
"""
|
||||||
|
logging.warning("BigdlNativeForCausalLM has been deprecated, "
|
||||||
|
"please switch to the new CausalLM API for sepcific models.")
|
||||||
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
|
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
|
||||||
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
|
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
|
||||||
" 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
|
" 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
|
||||||
|
|
@ -75,3 +79,70 @@ class BigdlNativeForCausalLM:
|
||||||
elif model_family == 'chatglm':
|
elif model_family == 'chatglm':
|
||||||
from bigdl.llm.ggml.model.chatglm import ChatGLM
|
from bigdl.llm.ggml.model.chatglm import ChatGLM
|
||||||
return ChatGLM(model_path=ggml_model_path, **kwargs)
|
return ChatGLM(model_path=ggml_model_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseGGMLClass:
|
||||||
|
|
||||||
|
GGML_Model = None
|
||||||
|
HF_Class = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls,
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
native: bool = True,
|
||||||
|
dtype: str = "int4",
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
:param pretrained_model_name_or_path: Path for model checkpoint.
|
||||||
|
If running with ``native int4``, the path should be converted BigDL-LLM optimized
|
||||||
|
ggml binary checkpoint, which should be converted by ``bigdl.llm.llm_convert``.
|
||||||
|
If running with ``transformers int4``, the path should be the huggingface repo id
|
||||||
|
to be downloaded or the huggingface checkpoint folder.
|
||||||
|
:param native: Load model to either BigDL-LLM optimized Transformer or Native (ggml) int4.
|
||||||
|
:param dtype: Which quantized precision will be converted.
|
||||||
|
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
|
||||||
|
, `gptneox` and `starcoder`.
|
||||||
|
:param kwargs: keyword arguments which will be passed to the model instance.
|
||||||
|
|
||||||
|
:return: a model instance
|
||||||
|
"""
|
||||||
|
if native:
|
||||||
|
invalidInputError(dtype.lower() in ['int4', 'int8'],
|
||||||
|
"Now we only support int4 and int8 as date type for weight")
|
||||||
|
ggml_model_path = pretrained_model_name_or_path
|
||||||
|
return cls.GGML_Model(model_path=ggml_model_path,
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
return cls.HF_Class.from_pretrained(pretrained_model_name_or_path,
|
||||||
|
*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForCausalLM(_BaseGGMLClass):
|
||||||
|
from bigdl.llm.ggml.model.llama import Llama
|
||||||
|
GGML_Model = Llama
|
||||||
|
HF_Class = AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMForCausalLM(_BaseGGMLClass):
|
||||||
|
from bigdl.llm.ggml.model.chatglm import ChatGLM
|
||||||
|
GGML_Model = ChatGLM
|
||||||
|
HF_Class = AutoModel
|
||||||
|
|
||||||
|
|
||||||
|
class GptneoxForCausalLM(_BaseGGMLClass):
|
||||||
|
from bigdl.llm.ggml.model.gptneox import Gptneox
|
||||||
|
GGML_Model = Gptneox
|
||||||
|
HF_Class = AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class BloomForCausalLM(_BaseGGMLClass):
|
||||||
|
from bigdl.llm.ggml.model.bloom import Bloom
|
||||||
|
GGML_Model = Bloom
|
||||||
|
HF_Class = AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class StarcoderForCausalLM(_BaseGGMLClass):
|
||||||
|
from bigdl.llm.ggml.model.starcoder import Starcoder
|
||||||
|
GGML_Model = Starcoder
|
||||||
|
HF_Class = AutoModelForCausalLM
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
|
|
||||||
from bigdl.llm.models import Llama, Bloom, Gptneox, Starcoder
|
from bigdl.llm.models import Llama, Bloom, Gptneox, Starcoder
|
||||||
|
from bigdl.llm.transformers import LlamaForCausalLM, BloomForCausalLM, \
|
||||||
|
GptneoxForCausalLM, StarcoderForCausalLM
|
||||||
import pytest
|
import pytest
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
import os
|
import os
|
||||||
|
|
@ -43,6 +45,11 @@ class Test_Models_Basics(TestCase):
|
||||||
llm = Llama(self.llama_model_path, n_threads=self.n_threads)
|
llm = Llama(self.llama_model_path, n_threads=self.n_threads)
|
||||||
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
||||||
|
|
||||||
|
def test_llama_for_causallm(self):
|
||||||
|
llm = LlamaForCausalLM.from_pretrained(self.llama_model_path, native=True,
|
||||||
|
n_threads=self.n_threads)
|
||||||
|
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||||
|
|
||||||
def test_bloom_completion_success(self):
|
def test_bloom_completion_success(self):
|
||||||
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
||||||
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||||
|
|
@ -55,6 +62,11 @@ class Test_Models_Basics(TestCase):
|
||||||
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
llm = Bloom(self.bloom_model_path, n_threads=self.n_threads)
|
||||||
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
output = llm("What is the capital of France?", max_tokens=32, stream=True)
|
||||||
|
|
||||||
|
def test_bloom_for_causallm(self):
|
||||||
|
llm = BloomForCausalLM.from_pretrained(self.bloom_model_path, native=True,
|
||||||
|
n_threads=self.n_threads)
|
||||||
|
output = llm("What is the capital of France?", max_tokens=32, stream=False)
|
||||||
|
|
||||||
def test_gptneox_completion_success(self):
|
def test_gptneox_completion_success(self):
|
||||||
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
||||||
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
|
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
|
||||||
|
|
@ -64,6 +76,11 @@ class Test_Models_Basics(TestCase):
|
||||||
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
llm = Gptneox(self.gptneox_model_path, n_threads=self.n_threads)
|
||||||
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=True)
|
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=True)
|
||||||
|
|
||||||
|
def test_getneox_for_causallm(self):
|
||||||
|
llm = GptneoxForCausalLM.from_pretrained(self.gptneox_model_path, native=True,
|
||||||
|
n_threads=self.n_threads)
|
||||||
|
output = llm("Q: What is the capital of France? A:", max_tokens=32, stream=False)
|
||||||
|
|
||||||
def test_starcoder_completion_success(self):
|
def test_starcoder_completion_success(self):
|
||||||
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
||||||
output = llm("def print_hello_world(", max_tokens=32, stream=False)
|
output = llm("def print_hello_world(", max_tokens=32, stream=False)
|
||||||
|
|
@ -73,6 +90,11 @@ class Test_Models_Basics(TestCase):
|
||||||
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
llm = Starcoder(self.starcoder_model_path, n_threads=self.n_threads)
|
||||||
output = llm("def print_hello_world(", max_tokens=32, stream=True)
|
output = llm("def print_hello_world(", max_tokens=32, stream=True)
|
||||||
|
|
||||||
|
def test_starcoder_for_causallm(self):
|
||||||
|
llm = StarcoderForCausalLM.from_pretrained(self.starcoder_model_path, native=True,
|
||||||
|
n_threads=self.n_threads)
|
||||||
|
output = llm("def print_hello_world(", max_tokens=32, stream=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
|
|
@ -90,5 +90,24 @@ class TestTransformersAPI(unittest.TestCase):
|
||||||
res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
|
res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
|
||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
|
|
||||||
|
def test_transformers_chatglm_for_causallm(self):
|
||||||
|
from bigdl.llm.transformers import ChatGLMForCausalLM
|
||||||
|
model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
|
||||||
|
model = ChatGLMForCausalLM.from_pretrained(model_path, native=False, trust_remote_code=True, load_in_4bit=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
input_str = "Tell me the capital of France.\n\n"
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
st = time.time()
|
||||||
|
input_ids = tokenizer.encode(input_str, return_tensors="pt")
|
||||||
|
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
|
||||||
|
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
end = time.time()
|
||||||
|
print('Prompt:', input_str)
|
||||||
|
print('Output:', output_str)
|
||||||
|
print(f'Inference time: {end-st} s')
|
||||||
|
res = 'Paris' in output_str
|
||||||
|
self.assertTrue(res)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue