[LLM] Correct Llama chat format and add llama_stream_chat in chat.py

* correct the chat prompt format for Llama
* add llama_stream_chat
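For reference, the corrected prompt follows the Llama-2 chat template from the model card linked in the code (https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGML#prompt-template-llama-2-chat):

    [INST] <<SYS>>
    {system_message}
    <</SYS>>
    {prompt}[/INST]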
Ziteng Zhang 2023-12-15 16:36:46 +08:00 committed by GitHub
parent 0d41b7ba7b
commit 67cc155771


@@ -201,6 +201,31 @@ def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
        )

@torch.no_grad()
def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
    past_key_values = None
    while True:
        user_input = input(Fore.GREEN + "\nHuman: " + Fore.RESET)
        # stop the conversation when the user types "stop"
        if user_input == "stop":
            break
        # Llama-2 chat prompt template:
        # https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGML#prompt-template-llama-2-chat
        prompt = f"""
[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
{user_input}[/INST]
"""
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        seq_len = input_ids.shape[1]
        if kv_cache is not None:
            # reserve cache room for the prompt plus the longest possible reply
            space_needed = seq_len + max_gen_len
            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
        )

def auto_select_model(model_name):
    try:
        try:
@@ -241,6 +266,9 @@ if __name__ == "__main__":
        qwen_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
    elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
        chatglm2_stream_chat(model=model, tokenizer=tokenizer)
    elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
        kv_cache = StartRecentKVCache(start_size=start_size)
        llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
    else:
        kv_cache = StartRecentKVCache(start_size=start_size)
        stream_chat(model=model,
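
For illustration, a minimal sketch of how the new Llama branch is exercised. The checkpoint name and start_size value are assumptions, not part of the commit; StartRecentKVCache and llama_stream_chat are the names used in chat.py itself:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: any Llama-2 chat checkpoint whose config reports LlamaForCausalLM.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    # StartRecentKVCache keeps the first `start_size` tokens plus the most recent
    # ones; evict_for_space() drops the middle of the cache when room runs out.
    kv_cache = StartRecentKVCache(start_size=4)  # assumption: start_size of 4
    llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)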