[LLM] Support Yi model in chat.py (#9778)
* Support Yi model
* code style & add reference link
This commit is contained in:
parent 11d883301b
commit 87b4100054
1 changed file with 37 additions and 2 deletions
chat.py

@@ -61,6 +61,9 @@ def get_stop_words_ids(chat_format, tokenizer):
     # https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23
     if chat_format == "Qwen":
         stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]]
+    # https://huggingface.co/01-ai/Yi-6B-Chat/blob/main/tokenizer_config.json#L38
+    elif chat_format == "Yi":
+        stop_words_ids = [tokenizer.encode("<|im_end|>")]
     else:
         raise NotImplementedError(f"Unknown chat format {chat_format!r}")
     return stop_words_ids
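Note: the new "Yi" branch reuses the ChatML-style `<|im_end|>` terminator instead of dedicated special-token ids. Below is a minimal sketch of how such a stop-word list can be consumed by a generation loop; the tokenizer load and the `hits_stop_word` helper are illustrative assumptions, not code from this commit.

from transformers import AutoTokenizer

# Illustrative only: load the Yi chat tokenizer referenced in the diff.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-6B-Chat", trust_remote_code=True)

# Mirrors the new "Yi" branch of get_stop_words_ids: a single stop sequence,
# the token ids that encode the ChatML end marker.
stop_words_ids = [tokenizer.encode("<|im_end|>")]

def hits_stop_word(generated_ids, stop_words_ids):
    # Hypothetical helper: True when the tail of generated_ids matches any stop sequence.
    return any(
        len(stop) > 0 and generated_ids[-len(stop):] == stop
        for stop in stop_words_ids
    )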
@@ -226,6 +229,34 @@ def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_wor
             model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
         )
 
+
+@torch.no_grad()
+def yi_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
+    past_key_values = None
+    while True:
+        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+        # let's stop the conversation when user input "stop"
+        if user_input == "stop":
+            break
+        # https://huggingface.co/01-ai/Yi-6B-Chat#31-use-the-chat-model
+        prompt = f"""
+<|im_start|>system
+You are a helpful assistant. If you don't understand what the user means, ask the user to provide more information.
+<|im_end|>
+<|im_start|>user
+{user_input}
+<|im_end|>
+<|im_start|>assistant
+"""
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        seq_len = input_ids.shape[1]
+        if kv_cache is not None:
+            space_needed = seq_len + max_gen_len
+            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
+
+        past_key_values = greedy_generate(
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
+        )
+
+
 def auto_select_model(model_name):
     try:
         try:
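The hand-written prompt in `yi_stream_chat` follows the ChatML layout described in the linked Yi model card. A hedged alternative sketch, assuming the checkpoint ships a `chat_template` in the tokenizer config referenced in the first hunk and that transformers >= 4.34 is available, would let the tokenizer render the same structure:

# Sketch only: `tokenizer` and `user_input` are assumed to be the same objects
# used inside yi_stream_chat; whether the rendered prompt matches the manual
# template depends on the chat_template shipped with the checkpoint.
messages = [
    {"role": "system", "content": "You are a helpful assistant. If you don't understand "
                                  "what the user means, ask the user to provide more information."},
    {"role": "user", "content": user_input},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids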
@@ -265,10 +296,14 @@ if __name__ == "__main__":
         kv_cache = StartRecentKVCache(start_size=start_size)
         qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
     elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
-        chatglm2_stream_chat(model=model, tokenizer=tokenizer)
+        chatglm3_stream_chat(model=model, tokenizer=tokenizer)
     elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
         kv_cache = StartRecentKVCache(start_size=start_size)
-        llama_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache)
+        if "yi" in model_path.lower():
+            stop_words = get_stop_words_ids("Yi", tokenizer=tokenizer)
+            yi_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
+        else:
+            llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
     else:
         kv_cache = StartRecentKVCache(start_size=start_size)
         stream_chat(model=model,
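Because the Yi chat checkpoints are handled under the same `LlamaForCausalLM` branch, the dispatch in `__main__` keys off the checkpoint path string rather than the architecture name. A small self-contained illustration of that selection logic; `picks_yi_branch` and the example paths are hypothetical, added only to show the condition:

def picks_yi_branch(model_path: str, architecture: str) -> bool:
    # Mirrors the new dispatch: Yi shares the LlamaForCausalLM branch,
    # so the path string decides which stream-chat loop runs.
    return architecture == "LlamaForCausalLM" and "yi" in model_path.lower()

# Hypothetical paths, for illustration only.
assert picks_yi_branch("01-ai/Yi-6B-Chat", "LlamaForCausalLM")
assert not picks_yi_branch("meta-llama/Llama-2-7b-chat-hf", "LlamaForCausalLM")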