LLM: Optimize Langchain Pipeline (#8561)

* LLM: Optimize Langchain Pipeline
* load in low bit

parent 616b7cb0a2
commit e680af45ea

2 changed files with 73 additions and 9 deletions
@@ -33,6 +33,23 @@ import pyttsx3
 import argparse
 import time
 
+english_template = """
+{history}
+Q: {human_input}
+A:"""
+
+chinese_template = """{history}\n\n问:{human_input}\n\n答:"""
+
+
+template_dict = {
+    "english": english_template,
+    "chinese": chinese_template
+}
+
+llm_load_methods = (
+    TransformersLLM.from_model_id,
+    TransformersLLM.from_model_id_low_bit,
+)
 
 def prepare_chain(args):
 
@@ -41,16 +58,13 @@ def prepare_chain(args):
     # Using an easy prompt can bring good-enough results
     # For Chinese Prompt
     # template = """{history}\n\n问:{human_input}\n\n答:"""
-    template = """
-{history}
-Q: {human_input}
-A:"""
+    template = template_dict[args.language]
     prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
 
-    llm = TransformersLLM.from_model_id(
+    method_index = 1 if args.directly else 0
+    llm = llm_load_methods[method_index](
         model_id=llm_model_path,
         model_kwargs={"temperature": 0,
-                      "max_length": args.max_length,
                       "trust_remote_code": True},
     )
 
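The dispatch added above is compact; as a sketch, here is an equivalent, more explicit form (names taken from this diff; the description of each loader follows this commit only):

    # Equivalent explicit form of the index-based dispatch (sketch only).
    if args.directly:
        loader = TransformersLLM.from_model_id_low_bit  # load an already-saved low-bit checkpoint
    else:
        loader = TransformersLLM.from_model_id          # load the original huggingface model
    llm = loader(
        model_id=llm_model_path,
        model_kwargs={"temperature": 0, "trust_remote_code": True},
    )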
@@ -59,6 +73,7 @@ def prepare_chain(args):
         llm=llm,
         prompt=prompt,
         verbose=True,
+        llm_kwargs={"max_new_tokens": args.max_new_tokens},
         memory=ConversationBufferWindowMemory(k=2),
     )
 
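For context, a hedged sketch of driving the chain this function builds; it assumes prepare_chain returns the LLMChain, which is not shown in this diff:

    # ConversationBufferWindowMemory(k=2) injects {history} automatically, so
    # only the new {human_input} is supplied; generation stops after
    # args.max_new_tokens new tokens via the llm_kwargs added above.
    chain = prepare_chain(args)  # assumed return value
    response = chain.predict(human_input="What is BigDL-LLM?")
    print(response)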
@@ -126,10 +141,12 @@ if __name__ == '__main__':
                         help="the path to the huggingface speech recognition model")
     parser.add_argument('-m','--llm-model-path', type=str, required=True,
                         help='the path to the huggingface llm model')
-    parser.add_argument('-x','--max-length', type=int, default=256,
-                        help='the max length of model tokens input')
+    parser.add_argument('-x','--max-new-tokens', type=int, default=32,
+                        help='the maximum number of new tokens the model may generate')
     parser.add_argument('-l', '--language', type=str, default="english",
-                        help='language to be transcribed')
+                        help='the language to be transcribed')
+    parser.add_argument('-d', '--directly', action='store_true',
+                        help='whether to load the low-bit model directly')
     args = parser.parse_args()
 
     main(args)
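The new --directly flag expects a checkpoint that was already saved in low-bit form. A minimal sketch of producing one, assuming bigdl-llm's save_low_bit is the saving counterpart of the load_low_bit call used in the second file (the model id and output path are placeholders):

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Quantize while loading, then persist the low-bit weights for fast reloads.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        load_in_4bit=True,
        trust_remote_code=True,
    )
    model.save_low_bit("./llama-2-7b-low-bit")  # assumed API, mirrors load_low_bit below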
@@ -130,6 +130,53 @@ class TransformersLLM(LLM):
             **kwargs,
         )
 
+    @classmethod
+    def from_model_id_low_bit(
+        cls,
+        model_id: str,
+        model_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> LLM:
+        """Construct a TransformersLLM from a previously saved low-bit model."""
+        try:
+            from bigdl.llm.transformers import (
+                AutoModel,
+                AutoModelForCausalLM,
+            )
+            from transformers import AutoTokenizer, LlamaTokenizer
+
+        except ImportError:
+            raise ValueError(
+                "Could not import the transformers or bigdl-llm python packages. "
+                "Please install them with `pip install transformers bigdl-llm`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        # TODO: may refactor this code in the future
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except Exception:
+            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        # TODO: may refactor this code in the future
+        try:
+            model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
+        except Exception:
+            model = AutoModel.load_low_bit(model_id, **_model_kwargs)
+
+        if "trust_remote_code" in _model_kwargs:
+            _model_kwargs = {
+                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+            }
+
+        return cls(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
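A hedged usage sketch for the new classmethod; the import path and checkpoint folder are assumptions, not taken from this diff:

    from bigdl.llm.langchain.llms import TransformersLLM  # assumed import path

    llm = TransformersLLM.from_model_id_low_bit(
        model_id="./llama-2-7b-low-bit",  # placeholder: folder with saved low-bit weights
        model_kwargs={"trust_remote_code": True},
    )
    print(llm("Q: What does BigDL-LLM do?\nA:"))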