[LLM] Correct chat format of llama and add llama_stream_chat in chat.py
* correct chat format of llama * add llama_stream_chat
This commit is contained in:
parent
0d41b7ba7b
commit
67cc155771
1 changed files with 28 additions and 0 deletions
|
|
@ -201,6 +201,31 @@ def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_word
|
|||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
|
||||
past_key_values = None
|
||||
while True:
|
||||
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
|
||||
# let's stop the conversation when user input "stop"
|
||||
if user_input == "stop":
|
||||
break
|
||||
# https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGML#prompt-template-llama-2-chat
|
||||
prompt = f"""
|
||||
[INST] <<SYS>>
|
||||
You are a helpful assistant.
|
||||
<</SYS>>
|
||||
{user_input}[/INST]
|
||||
"""
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
seq_len = input_ids.shape[1]
|
||||
if kv_cache is not None:
|
||||
space_needed = seq_len + max_gen_len
|
||||
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
|
||||
|
||||
past_key_values = greedy_generate(
|
||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||
)
|
||||
|
||||
def auto_select_model(model_name):
|
||||
try:
|
||||
try:
|
||||
|
|
@ -241,6 +266,9 @@ if __name__ == "__main__":
|
|||
qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
|
||||
elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
||||
chatglm2_stream_chat(model=model, tokenizer=tokenizer)
|
||||
elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
|
||||
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||
llama_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache)
|
||||
else:
|
||||
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||
stream_chat(model=model,
|
||||
|
|
|
|||
Loading…
Reference in a new issue