diff --git a/python/llm/portable-zip/chat.py b/python/llm/portable-zip/chat.py
index cecf9700..5e463cf8 100644
--- a/python/llm/portable-zip/chat.py
+++ b/python/llm/portable-zip/chat.py
@@ -57,8 +57,16 @@ from kv_cache import StartRecentKVCache
 HUMAN_ID = ""
 BOT_ID = ""
 
+def get_stop_words_ids(chat_format, tokenizer):
+    # https://github.com/QwenLM/Qwen/blob/main/examples/vllm_wrapper.py#L23
+    if chat_format == "Qwen":
+        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id], [tokenizer.eod_id]]
+    else:
+        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+    return stop_words_ids
+
 @torch.no_grad()
-def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
+def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len, stop_words=[]):
     print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
     outputs = model(
         input_ids=input_ids,
@@ -69,6 +77,7 @@ def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
     pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
     generated_ids = [pred_token_idx.item()]
     pos = 0
+    stop = False
     for _ in range(max_gen_len - 1):
         outputs = model(
             input_ids=pred_token_idx,
@@ -78,6 +87,15 @@ def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
         past_key_values = outputs.past_key_values
         pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
         generated_ids.append(pred_token_idx.item())
+
+        if stop_words is not None:
+            for stop_str in stop_words:
+                if generated_ids[-1 * len(stop_str):] == stop_str:
+                    stop = True
+                    break
+        if stop:
+            break
+
         generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=True,
                                           spaces_between_special_tokens=False)
@@ -96,11 +114,12 @@ def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
     return past_key_values
 
 @torch.no_grad()
-def stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512):
+def stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
     past_key_values = None
     while True:
         user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
-        if user_input == "stop": # let's stop the conversation when user input "stop"
+        # let's stop the conversation when the user inputs "stop"
+        if user_input == "stop":
             break
         prompt = f"{HUMAN_ID} {user_input}\n{BOT_ID} "
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
@@ -110,7 +129,7 @@ def stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512):
             past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
 
         past_key_values = greedy_generate(
-            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
         )
 
 @torch.no_grad()
@@ -123,7 +142,8 @@ def chatglm2_stream_chat(model, tokenizer):
 
     while True:
         user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
-        if user_input == "stop": # let's stop the conversation when user input "stop"
+        # let's stop the conversation when the user inputs "stop"
+        if user_input == "stop":
             break
         print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
         prompt = f"问:{user_input}\n答:"
@@ -145,6 +165,34 @@ def chatglm2_stream_chat(model, tokenizer):
             new_values_list.append(tuple(new_value))
         past_key_values = tuple(new_values_list)
 
+@torch.no_grad()
+def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
+    past_key_values = None
+    while True:
+        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+        # let's stop the conversation when the user inputs "stop"
+        if user_input == "stop":
+            break
+        # https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/generation_config.json#L2
+        prompt = f"""
+        <|im_start|>system
+        You are a helpful assistant.
+        <|im_end|>
+        <|im_start|>user
+        {user_input}
+        <|im_end|>
+        <|im_start|>assistant
+        """
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        seq_len = input_ids.shape[1]
+        if kv_cache is not None:
+            space_needed = seq_len + max_gen_len
+            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
+
+        past_key_values = greedy_generate(
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
+        )
+
 def auto_select_model(model_name):
     try:
         try:
@@ -168,19 +216,25 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, help="path to an llm")
+    parser.add_argument("--start-size", type=int, default=4, help="start_size of kv_cache")
 
     args = parser.parse_args()
     model_path = args.model_path
+    start_size = args.start_size
 
     model = auto_select_model(model_path)
 
     model = optimize_model(model)
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
+    if model.config.architectures is not None and model.config.architectures[0] == "QWenLMHeadModel":
+        stop_words = get_stop_words_ids("Qwen", tokenizer=tokenizer)
+        kv_cache = StartRecentKVCache(start_size=start_size)
+        qwen_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
+    elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
         chatglm2_stream_chat(model=model, tokenizer=tokenizer)
     else:
-        kv_cache = StartRecentKVCache()
+        kv_cache = StartRecentKVCache(start_size=start_size)
         stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
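Reviewer note: the stop check added to greedy_generate compares the tail of generated_ids against each entry of stop_words, and get_stop_words_ids returns those entries as single-token-id lists. A minimal standalone sketch of that matching logic, assuming made-up placeholder token ids (they are not values taken from this patch):

    # Illustrative sketch only; the ids are placeholders, not real Qwen special-token ids.
    stop_words = [[151645], [151644], [151643]]   # e.g. [[im_end_id], [im_start_id], [eod_id]]
    generated_ids = [9906, 11, 1917, 151645]      # hypothetical generated token ids

    def ends_with_stop_word(generated_ids, stop_words):
        # True once the generated sequence ends with any stop-word token sequence
        return any(generated_ids[-len(sw):] == sw for sw in stop_words)

    assert ends_with_stop_word(generated_ids, stop_words)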