LLM: Optimize Langchain Pipeline (#8561)
* LLM: Optimize Langchain Pipeline * load in low bit
This commit is contained in:
parent
616b7cb0a2
commit
e680af45ea
2 changed files with 73 additions and 9 deletions
|
|
@ -33,6 +33,23 @@ import pyttsx3
|
|||
import argparse
|
||||
import time
|
||||
|
||||
# Prompt skeletons: prior conversation history followed by the new user turn.
# A simple Q/A-style prompt is good enough for this voice-assistant demo.
english_template = """
{history}
Q: {human_input}
A:"""

# Chinese equivalent of the prompt above (问 = question, 答 = answer).
chinese_template = """{history}\n\n问:{human_input}\n\n答:"""

# Maps the --language CLI value to its prompt template.
template_dict = {
    "english": english_template,
    "chinese": chinese_template
}

# Loader dispatch table, indexed by the --directly flag:
#   index 0 — load a full-precision model and convert it;
#   index 1 — load an already-saved low-bit checkpoint directly.
llm_load_methods = (
    TransformersLLM.from_model_id,
    TransformersLLM.from_model_id_low_bit,
)
|
||||
|
||||
def prepare_chain(args):
|
||||
|
||||
|
|
@ -41,16 +58,13 @@ def prepare_chain(args):
|
|||
# Using an easy prompt can bring good-enough results
|
||||
# For Chinese Prompt
|
||||
# template = """{history}\n\n问:{human_input}\n\n答:"""
|
||||
template = """
|
||||
{history}
|
||||
Q: {human_input}
|
||||
A:"""
|
||||
template = template_dict[args.language]
|
||||
prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
|
||||
|
||||
llm = TransformersLLM.from_model_id(
|
||||
method_index = 1 if args.directly else 0
|
||||
llm = llm_load_methods[method_index](
|
||||
model_id=llm_model_path,
|
||||
model_kwargs={"temperature": 0,
|
||||
"max_length": args.max_length,
|
||||
"trust_remote_code": True},
|
||||
)
|
||||
|
||||
|
|
@ -59,6 +73,7 @@ def prepare_chain(args):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
llm_kwargs={"max_new_tokens":args.max_new_tokens},
|
||||
memory=ConversationBufferWindowMemory(k=2),
|
||||
)
|
||||
|
||||
|
|
@ -126,10 +141,12 @@ if __name__ == '__main__':
|
|||
help="the path to the huggingface speech recognition model")
|
||||
parser.add_argument('-m','--llm-model-path', type=str, required=True,
|
||||
help='the path to the huggingface llm model')
|
||||
parser.add_argument('-x','--max-length', type=int, default=256,
|
||||
help='the max length of model tokens input')
|
||||
parser.add_argument('-x','--max-new-tokens', type=int, default=32,
|
||||
help='the max new tokens of model tokens input')
|
||||
parser.add_argument('-l', '--language', type=str, default="english",
|
||||
help='language to be transcribed')
|
||||
help='the language to be transcribed')
|
||||
parser.add_argument('-d', '--directly', action='store_true',
|
||||
help='whether to load low bit model directly')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -130,6 +130,53 @@ class TransformersLLM(LLM):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
def from_model_id_low_bit(
    cls,
    model_id: str,
    model_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> LLM:
    """Construct a TransformersLLM from a previously saved low-bit model.

    Unlike ``from_model_id``, this loads weights that were already saved
    in low-bit format (via ``save_low_bit``), skipping the expensive
    full-precision load + convert step.

    Args:
        model_id: Path (or hub id) of the saved low-bit model directory.
        model_kwargs: Keyword arguments forwarded to the tokenizer and
            model loaders (e.g. ``trust_remote_code``).
        **kwargs: Extra fields forwarded to the LLM constructor.

    Returns:
        An LLM instance wrapping the low-bit model and its tokenizer.

    Raises:
        ValueError: If the required transformers packages are not installed.
    """
    try:
        from bigdl.llm.transformers import (
            AutoModel,
            AutoModelForCausalLM,
        )
        from transformers import AutoTokenizer, LlamaTokenizer

    except ImportError as e:
        # Chain the original ImportError so the root cause is visible.
        raise ValueError(
            "Could not import transformers python package. "
            "Please install it with `pip install transformers`."
        ) from e

    _model_kwargs = model_kwargs or {}
    # TODO: may refactor this code in the future
    # Some models (e.g. older LLaMA checkpoints) are not resolvable by
    # AutoTokenizer; fall back to LlamaTokenizer. Catch Exception rather
    # than a bare except so KeyboardInterrupt/SystemExit still propagate.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
    except Exception:
        tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)

    # TODO: may refactor this code in the future
    # Try the causal-LM loader first; fall back to the generic AutoModel
    # for architectures AutoModelForCausalLM cannot load.
    try:
        model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
    except Exception:
        model = AutoModel.load_low_bit(model_id, **_model_kwargs)

    # trust_remote_code is a load-time-only option; strip it before the
    # kwargs are stored on the instance for later generate() calls.
    if "trust_remote_code" in _model_kwargs:
        _model_kwargs = {
            k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
        }

    return cls(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        model_kwargs=_model_kwargs,
        **kwargs,
    )
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue