[LLM] Support Yi model in chat.py (#9778)
* Suppot Yi model * code style& add reference link
This commit is contained in:
parent
11d883301b
commit
87b4100054
1 changed files with 37 additions and 2 deletions
|
|
@ -61,6 +61,9 @@ def get_stop_words_ids(chat_format, tokenizer):
|
|||
# https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23
|
||||
if chat_format == "Qwen":
|
||||
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]]
|
||||
# https://huggingface.co/01-ai/Yi-6B-Chat/blob/main/tokenizer_config.json#L38
|
||||
elif chat_format == "Yi":
|
||||
stop_words_ids = [tokenizer.encode("<|im_end|>")]
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
return stop_words_ids
|
||||
|
|
@ -226,6 +229,34 @@ def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_wor
|
|||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def yi_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
|
||||
past_key_values = None
|
||||
while True:
|
||||
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
|
||||
# let's stop the conversation when user input "stop"
|
||||
if user_input == "stop":
|
||||
break
|
||||
# https://huggingface.co/01-ai/Yi-6B-Chat#31-use-the-chat-model
|
||||
prompt = f"""
|
||||
<|im_start|>system
|
||||
You are a helpful assistant. If you don't understand what the user means, ask the user to provide more information.
|
||||
<|im_end|>
|
||||
<|im_start|>user
|
||||
{user_input}
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
seq_len = input_ids.shape[1]
|
||||
if kv_cache is not None:
|
||||
space_needed = seq_len + max_gen_len
|
||||
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
|
||||
|
||||
past_key_values = greedy_generate(
|
||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
|
||||
)
|
||||
|
||||
def auto_select_model(model_name):
|
||||
try:
|
||||
try:
|
||||
|
|
@ -265,9 +296,13 @@ if __name__ == "__main__":
|
|||
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||
qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
|
||||
elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
||||
chatglm2_stream_chat(model=model, tokenizer=tokenizer)
|
||||
chatglm3_stream_chat(model=model, tokenizer=tokenizer)
|
||||
elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
|
||||
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||
if "yi" in model_path.lower():
|
||||
stop_words = get_stop_words_ids("Yi", tokenizer=tokenizer)
|
||||
yi_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
|
||||
else:
|
||||
llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
|
||||
else:
|
||||
kv_cache = StartRecentKVCache(start_size=start_size)
|
||||
|
|
|
|||
Loading…
Reference in a new issue