[LLM] Correct chat format of llama and add llama_stream_chat in chat.py
* correct chat format of llama * add llama_stream_chat
This commit is contained in:
parent
0d41b7ba7b
commit
67cc155771
1 changed files with 28 additions and 0 deletions
|
|
@ -201,6 +201,31 @@ def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_word
|
||||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
|
||||||
|
past_key_values = None
|
||||||
|
while True:
|
||||||
|
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
|
||||||
|
# let's stop the conversation when user input "stop"
|
||||||
|
if user_input == "stop":
|
||||||
|
break
|
||||||
|
# https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGML#prompt-template-llama-2-chat
|
||||||
|
prompt = f"""
|
||||||
|
[INST] <<SYS>>
|
||||||
|
You are a helpful assistant.
|
||||||
|
<</SYS>>
|
||||||
|
{user_input}[/INST]
|
||||||
|
"""
|
||||||
|
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||||
|
seq_len = input_ids.shape[1]
|
||||||
|
if kv_cache is not None:
|
||||||
|
space_needed = seq_len + max_gen_len
|
||||||
|
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
|
||||||
|
|
||||||
|
past_key_values = greedy_generate(
|
||||||
|
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||||
|
)
|
||||||
|
|
||||||
def auto_select_model(model_name):
|
def auto_select_model(model_name):
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
|
|
@ -241,6 +266,9 @@ if __name__ == "__main__":
|
||||||
qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
|
qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
|
||||||
elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
||||||
chatglm2_stream_chat(model=model, tokenizer=tokenizer)
|
chatglm2_stream_chat(model=model, tokenizer=tokenizer)
|
||||||
|
elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
|
||||||
|
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||||
|
llama_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache)
|
||||||
else:
|
else:
|
||||||
kv_cache = StartRecentKVCache(start_size=start_size)
|
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||||
stream_chat(model=model,
|
stream_chat(model=model,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue