diff --git a/python/llm/portable-zip/chat.py b/python/llm/portable-zip/chat.py
index 06005bd7..94e454d1 100644
--- a/python/llm/portable-zip/chat.py
+++ b/python/llm/portable-zip/chat.py
@@ -61,6 +61,9 @@ def get_stop_words_ids(chat_format, tokenizer):
     # https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23
     if chat_format == "Qwen":
         stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]]
+    # https://huggingface.co/01-ai/Yi-6B-Chat/blob/main/tokenizer_config.json#L38
+    elif chat_format == "Yi":
+        stop_words_ids = [tokenizer.encode("<|im_end|>")]
     else:
         raise NotImplementedError(f"Unknown chat format {chat_format!r}")
     return stop_words_ids
@@ -226,6 +229,34 @@ def llama_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_wor
         model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
     )
 
+@torch.no_grad()
+def yi_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
+    past_key_values = None
+    while True:
+        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+        # let's stop the conversation when user input "stop"
+        if user_input == "stop":
+            break
+        # https://huggingface.co/01-ai/Yi-6B-Chat#31-use-the-chat-model
+        prompt = f"""
+<|im_start|>system
+You are a helpful assistant. If you don't understand what the user means, ask the user to provide more information.
+<|im_end|>
+<|im_start|>user
+{user_input}
+<|im_end|>
+<|im_start|>assistant
+"""
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        seq_len = input_ids.shape[1]
+        if kv_cache is not None:
+            space_needed = seq_len + max_gen_len
+            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
+
+        past_key_values = greedy_generate(
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
+        )
+
 def auto_select_model(model_name):
     try:
         try:
@@ -265,10 +296,14 @@ if __name__ == "__main__":
         kv_cache = StartRecentKVCache(start_size=start_size)
         qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
     elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
-        chatglm2_stream_chat(model=model, tokenizer=tokenizer)
+        chatglm3_stream_chat(model=model, tokenizer=tokenizer)
     elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
         kv_cache = StartRecentKVCache(start_size=start_size)
-        llama_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache)
+        if "yi" in model_path.lower():
+            stop_words = get_stop_words_ids("Yi", tokenizer=tokenizer)
+            yi_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
+        else:
+            llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
     else:
         kv_cache = StartRecentKVCache(start_size=start_size)
         stream_chat(model=model,
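
Outside the diff itself, here is a minimal sketch of how the new "Yi" pieces fit together: the ChatML-style prompt that yi_stream_chat builds and the stop words that get_stop_words_ids("Yi", tokenizer) returns. The model id (01-ai/Yi-6B-Chat) and the sample question are illustrative only; the snippet assumes the Hugging Face transformers tokenizer is available.

    from transformers import AutoTokenizer

    # Load the Yi chat tokenizer (assumed to be available from the Hub or a local copy).
    tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-6B-Chat", trust_remote_code=True)

    # Yi uses the ChatML format: every turn is wrapped in <|im_start|>role ... <|im_end|>,
    # mirroring the prompt assembled inside yi_stream_chat above.
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nWhat can you do?<|im_end|>\n"  # sample question, illustrative only
        "<|im_start|>assistant\n"
    )

    # get_stop_words_ids("Yi", tokenizer) returns [tokenizer.encode("<|im_end|>")],
    # so generation stops as soon as the model emits the end-of-turn token.
    stop_words_ids = [tokenizer.encode("<|im_end|>")]
    print(stop_words_ids)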