LLM: Optimize Langchain Pipeline (#8561)

* LLM: Optimize Langchain Pipeline

* load in low bit
Zhao Changmin 2023-07-19 17:43:13 +08:00 committed by GitHub
parent 616b7cb0a2
commit e680af45ea
2 changed files with 73 additions and 9 deletions


@@ -33,6 +33,23 @@ import pyttsx3
import argparse
import time
english_template = """
{history}
Q: {human_input}
A:"""
chinese_template = """{history}\n\n问:{human_input}\n\n答:"""
template_dict = {
"english": english_template,
"chinese": chinese_template
}
llm_load_methods = (
TransformersLLM.from_model_id,
TransformersLLM.from_model_id_low_bit,
)
def prepare_chain(args):
@@ -41,16 +58,13 @@ def prepare_chain(args):
# Using a simple prompt can produce good-enough results
# For a Chinese prompt:
# template = """{history}\n\n问{human_input}\n\n答"""
template = """
{history}
Q: {human_input}
A:"""
template = template_dict[args.language]
prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
llm = TransformersLLM.from_model_id(
method_index = 1 if args.directly else 0
llm = llm_load_methods[method_index](
model_id=llm_model_path,
model_kwargs={"temperature": 0,
"max_length": args.max_length,
"trust_remote_code": True},
)
@@ -59,6 +73,7 @@ def prepare_chain(args):
llm=llm,
prompt=prompt,
verbose=True,
llm_kwargs={"max_new_tokens": args.max_new_tokens},
memory=ConversationBufferWindowMemory(k=2),
)
@@ -126,10 +141,12 @@ if __name__ == '__main__':
help="the path to the huggingface speech recognition model")
parser.add_argument('-m','--llm-model-path', type=str, required=True,
help='the path to the huggingface llm model')
parser.add_argument('-x','--max-length', type=int, default=256,
help='the max length of model tokens input')
parser.add_argument('-x','--max-new-tokens', type=int, default=32,
help='the maximum number of new tokens to generate')
parser.add_argument('-l', '--language', type=str, default="english",
help='language to be transcribed')
help='the language to be transcribed')
parser.add_argument('-d', '--directly', action='store_true',
help='whether to directly load a saved low-bit model')
args = parser.parse_args()
main(args)
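
For reference, a minimal sketch of how the low-bit checkpoint that the new --directly flag consumes could be produced beforehand. The model paths are hypothetical, and save_low_bit/load_in_4bit are assumed from bigdl-llm's transformers-style API rather than shown in this diff:

from bigdl.llm.transformers import AutoModelForCausalLM

# Load the original checkpoint and quantize it to a low-bit format on the fly
# (load_in_4bit is one option; bigdl-llm exposes other low-bit dtypes as well)
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-7b-hf",  # hypothetical original model path
    load_in_4bit=True,
    trust_remote_code=True,
)
# Persist the quantized weights; pass this directory as --llm-model-path
# together with --directly so the pipeline calls from_model_id_low_bit
model.save_low_bit("path/to/llama-7b-low-bit")  # hypothetical output path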


@@ -130,6 +130,53 @@ class TransformersLLM(LLM):
**kwargs,
)
@classmethod
def from_model_id_low_bit(
cls,
model_id: str,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> LLM:
"""Construct object from model_id"""
try:
from bigdl.llm.transformers import (
AutoModel,
AutoModelForCausalLM,
)
from transformers import AutoTokenizer, LlamaTokenizer
except ImportError:
raise ValueError(
"Could not import the transformers or bigdl-llm python packages. "
"Please install them with `pip install transformers` "
"and `pip install bigdl-llm`."
)
_model_kwargs = model_kwargs or {}
# TODO: may refactor this code in the future
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
except Exception:
# some models (e.g. LLaMA) are not resolved by AutoTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
# TODO: may refactor this code in the future
try:
model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
except Exception:
# fall back to AutoModel for architectures without a causal-LM head
model = AutoModel.load_low_bit(model_id, **_model_kwargs)
if "trust_remote_code" in _model_kwargs:
_model_kwargs = {
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
}
return cls(
model_id=model_id,
model=model,
tokenizer=tokenizer,
model_kwargs=_model_kwargs,
**kwargs,
)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""