[LLM] Support Yi model in chat.py (#9778)

* Suppot Yi model

* code style& add reference link
This commit is contained in:
Ziteng Zhang 2023-12-26 10:03:39 +08:00 committed by GitHub
parent 11d883301b
commit 87b4100054

View file

@ -61,6 +61,9 @@ def get_stop_words_ids(chat_format, tokenizer):
# https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23 # https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23
if chat_format == "Qwen": if chat_format == "Qwen":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]] stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]]
# https://huggingface.co/01-ai/Yi-6B-Chat/blob/main/tokenizer_config.json#L38
elif chat_format == "Yi":
stop_words_ids = [tokenizer.encode("<|im_end|>")]
else: else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}") raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids return stop_words_ids
@ -226,6 +229,34 @@ def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_wor
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
) )
@torch.no_grad()
def yi_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
past_key_values = None
while True:
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
# let's stop the conversation when user input "stop"
if user_input == "stop":
break
# https://huggingface.co/01-ai/Yi-6B-Chat#31-use-the-chat-model
prompt = f"""
<|im_start|>system
You are a helpful assistant. If you don't understand what the user means, ask the user to provide more information.
<|im_end|>
<|im_start|>user
{user_input}
<|im_end|>
<|im_start|>assistant
"""
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
seq_len = input_ids.shape[1]
if kv_cache is not None:
space_needed = seq_len + max_gen_len
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
past_key_values = greedy_generate(
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
)
def auto_select_model(model_name): def auto_select_model(model_name):
try: try:
try: try:
@ -265,10 +296,14 @@ if __name__ == "__main__":
kv_cache = StartRecentKVCache(start_size=start_size) kv_cache = StartRecentKVCache(start_size=start_size)
qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words) qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel": elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
chatglm2_stream_chat(model=model, tokenizer=tokenizer) chatglm3_stream_chat(model=model, tokenizer=tokenizer)
elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM": elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
kv_cache = StartRecentKVCache(start_size=start_size) kv_cache = StartRecentKVCache(start_size=start_size)
llama_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache) if "yi" in model_path.lower():
stop_words = get_stop_words_ids("Yi", tokenizer=tokenizer)
yi_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
else:
llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
else: else:
kv_cache = StartRecentKVCache(start_size=start_size) kv_cache = StartRecentKVCache(start_size=start_size)
stream_chat(model=model, stream_chat(model=model,