LLM: fix langchain native int4 voiceasistant example (#8750)

This commit is contained in:
binbin Deng 2023-08-14 17:23:33 +08:00 committed by GitHub
parent d28ad8f7db
commit be2ae6eb7c
2 changed files with 11 additions and 3 deletions

View file

@ -58,13 +58,14 @@ pip install soundfile
```
```bash
python native_int4/voiceassistant.py -x MODEL_FAMILY -m CONVERTED_MODEL_PATH -t THREAD_NUM
python native_int4/voiceassistant.py -x MODEL_FAMILY -m CONVERTED_MODEL_PATH -t THREAD_NUM -c CONTEXT_SIZE
```
arguments info:
- `-m CONVERTED_MODEL_PATH`: **required**, path to the converted model
- `-x MODEL_FAMILY`: **required**, the model family of the model specified in `-m`, available options are `llama`, `gptneox` and `bloom`
- `-t THREAD_NUM`: specify the number of threads to use for inference. Default is `2`.
- `-c CONTEXT_SIZE`: specify maximum context size. Default to be 512.
When you see output says
> listening now...

View file

@ -37,10 +37,13 @@ def prepare_chain(args):
model_path = args.model_path
model_family = args.model_family
n_threads = args.thread_num
n_ctx = args.context_size
# Use a easy prompt could bring good-enough result
# You could tune the prompt based on your own model to perform better
template = """
{history}
Q: {human_input}
A:"""
prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
@ -51,8 +54,10 @@ def prepare_chain(args):
model_path=model_path,
model_family=model_family,
n_threads=n_threads,
callback_manager=callback_manager,
verbose=True
callback_manager=callback_manager,
verbose=True,
n_ctx=n_ctx,
stop=['\n\n'] # You could tune the stop words based on your own model to perform better
)
# Following code are complete the same as the use-case
@ -116,6 +121,8 @@ if __name__ == '__main__':
help='the path to the converted llm model')
parser.add_argument('-t','--thread-num', type=int, default=2,
help='Number of threads to use for inference')
parser.add_argument('-c','--context-size', type=int, default=512,
help='Maximum context size')
args = parser.parse_args()
main(args)