diff --git a/python/llm/example/langchain/transformers_int4/voiceassistant.py b/python/llm/example/langchain/transformers_int4/voiceassistant.py
index b9fdf956..7c649fdb 100644
--- a/python/llm/example/langchain/transformers_int4/voiceassistant.py
+++ b/python/llm/example/langchain/transformers_int4/voiceassistant.py
@@ -33,6 +33,23 @@ import pyttsx3
 import argparse
 import time
 
+english_template = """
+{history}
+Q: {human_input}
+A:"""
+
+chinese_template = """{history}\n\n问:{human_input}\n\n答:"""
+
+
+template_dict = {
+    "english": english_template,
+    "chinese": chinese_template
+}
+
+llm_load_methods = (
+    TransformersLLM.from_model_id,
+    TransformersLLM.from_model_id_low_bit,
+)
 
 def prepare_chain(args):
 
@@ -41,16 +58,13 @@ def prepare_chain(args):
     # Use a easy prompt could bring good-enough result
     # For Chinese Prompt
     # template = """{history}\n\n问:{human_input}\n\n答:"""
-    template = """
-    {history}
-    Q: {human_input}
-    A:"""
+    template = template_dict[args.language]
 
     prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
-    llm = TransformersLLM.from_model_id(
+    method_index = 1 if args.directly else 0
+    llm = llm_load_methods[method_index](
         model_id=llm_model_path,
         model_kwargs={"temperature": 0,
-                      "max_length": args.max_length,
                       "trust_remote_code": True},
     )
 
@@ -59,6 +73,7 @@ def prepare_chain(args):
         llm=llm,
         prompt=prompt,
         verbose=True,
+        llm_kwargs={"max_new_tokens": args.max_new_tokens},
         memory=ConversationBufferWindowMemory(k=2),
     )
 
@@ -126,10 +141,12 @@ if __name__ == '__main__':
                         help="the path to the huggingface speech recognition model")
     parser.add_argument('-m','--llm-model-path', type=str, required=True,
                         help='the path to the huggingface llm model')
-    parser.add_argument('-x','--max-length', type=int, default=256,
-                        help='the max length of model tokens input')
+    parser.add_argument('-x','--max-new-tokens', type=int, default=32,
+                        help='the maximum number of new tokens to generate')
    parser.add_argument('-l', '--language', type=str, default="english",
-                        help='language to be transcribed')
+                        help='the language to be transcribed')
+    parser.add_argument('-d', '--directly', action='store_true',
+                        help='whether to load the saved low-bit model directly')
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
index d0801041..ade3cd01 100644
--- a/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
+++ b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
@@ -130,6 +130,53 @@ class TransformersLLM(LLM):
             **kwargs,
         )
 
+    @classmethod
+    def from_model_id_low_bit(
+        cls,
+        model_id: str,
+        model_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> LLM:
+        """Construct a low-bit object from a saved model_id"""
+        try:
+            from bigdl.llm.transformers import (
+                AutoModel,
+                AutoModelForCausalLM,
+            )
+            from transformers import AutoTokenizer, LlamaTokenizer
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        # TODO: may refactor this code in the future
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except Exception:
+            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        # TODO: may refactor this code in the future
+        try:
+            model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
+        except Exception:
+            model = AutoModel.load_low_bit(model_id, **_model_kwargs)
+
+        if "trust_remote_code" in _model_kwargs:
+            _model_kwargs = {
+                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+            }
+
+        return cls(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
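For anyone trying the new loader, here is a minimal sketch of the save/load round trip that `from_model_id_low_bit` is designed for. It assumes the standard BigDL-LLM `load_in_4bit`/`save_low_bit` API; both checkpoint paths are hypothetical, and the tokenizer files are saved next to the low-bit weights because the new classmethod loads the tokenizer from the same `model_id` directory.

```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from bigdl.llm.langchain.llms import TransformersLLM

fp16_path = "/path/to/llama-2-7b-chat-hf"     # hypothetical original checkpoint
low_bit_path = "/path/to/llama-2-7b-low-bit"  # hypothetical save directory

# One-time conversion: quantize to 4-bit on load, then persist the
# low-bit weights so later runs can skip quantization entirely.
model = AutoModelForCausalLM.from_pretrained(fp16_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model.save_low_bit(low_bit_path)

# Save the tokenizer alongside the weights; from_model_id_low_bit reads
# the tokenizer from the same directory as the model.
AutoTokenizer.from_pretrained(fp16_path).save_pretrained(low_bit_path)

# Later runs: load the saved low-bit model directly through the new
# classmethod added in this diff.
llm = TransformersLLM.from_model_id_low_bit(
    model_id=low_bit_path,
    model_kwargs={"temperature": 0, "trust_remote_code": True},
)
```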
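And a sketch of how the reworked `prepare_chain` wires the pieces together, reusing the `llm` object from the snippet above; the chain variable name and the test question are made up for illustration. Note that `max_new_tokens` bounds only the generated continuation, unlike the old `max_length`, which also counted the growing conversation history.

```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory

template = """
{history}
Q: {human_input}
A:"""
prompt = PromptTemplate(input_variables=["history", "human_input"],
                        template=template)

voiceassistant_chain = LLMChain(
    llm=llm,
    prompt=prompt,
    verbose=True,
    llm_kwargs={"max_new_tokens": 32},           # per-call generation cap
    memory=ConversationBufferWindowMemory(k=2),  # keep last 2 exchanges
)

# The memory fills in {history}; predict() supplies {human_input}.
print(voiceassistant_chain.predict(human_input="What is AI?"))
```

With the new flags, the example script itself would be launched along the lines of `python voiceassistant.py -m /path/to/llama-2-7b-low-bit -d -x 64 -l english` (plus the speech-recognition-model argument), where `-d` makes `prepare_chain` pick `from_model_id_low_bit` instead of `from_model_id`.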