Fix baichuan-13b issue on portable zip under transformers 4.36 (#10746)
* fix baichuan-13b issue
* update
* update
parent 9e668a5bf0
commit a9a6b6b7af

2 changed files with 63 additions and 5 deletions
			
@@ -11,12 +11,17 @@ This portable zip includes everything you need to run an LLM with IPEX-LLM optim
 </p>
 
 ### Verified Models
+- Llama-2-7b-chat-hf
+- Yi-6B-Chat
+- Mixtral-8x7B-Instruct-v0.1
+- Mistral-7B-Instruct-v0
 - ChatGLM2-6b
+- ChatGLM3-6b
 - Baichuan-13B-Chat
 - Baichuan2-7B-Chat
 - internlm-chat-7b
-- Llama-2-7b-chat-hf
+- internlm2-chat-7b
+- Qwen-7B-Chat
 
 ## How to use
@@ -252,6 +252,51 @@ def yi_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=
             model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len, stop_words=stop_words
         )
+
+
+def format_prompt_with_history(input_str,
+                  chat_history):
+    SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
+    The assistant gives helpful, detailed, and polite answers to the human's questions."
+    prompt = [f"{SYSTEM_PROMPT}\n"]
+    # prompt = []
+    for history_input_str, history_output_str in chat_history:
+        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
+    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
+
+    return "".join(prompt)
+
+
+def stream_chat_with_history(model, tokenizer):
+    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
+
+    chat_history = []
+
+    while True:
+        with torch.inference_mode():
+            user_input = input(Fore.GREEN + "\nHuman: " + Fore.RESET)
+            if user_input == "stop":  # let's stop the conversation when user input "stop"
+                break
+            prompt = format_prompt_with_history(user_input, chat_history)
+            # print(prompt)
+            input_ids = tokenizer([prompt], return_tensors="pt")
+            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            generate_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512,
+                                   stopping_criteria=stopping_criteria)
+
+            from threading import Thread
+            # to ensure non-blocking access to the generated text, generation process should be ran in a separate thread
+            thread = Thread(target=model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            output_str = []
+            print(Fore.BLUE + "IPEX-LLM: " + Fore.RESET, end="")
+            for partial_output_str in streamer:
+                output_str.append(partial_output_str)
+                # remove the last HUMAN_ID if exists
+                print(partial_output_str.replace(f"{HUMAN_ID}", ""), end="")
+
+            chat_history.append((user_input, "".join(output_str).replace(f"{HUMAN_ID}", "").rstrip()))
+
+
 def auto_select_model(model_name):
     try:
         try:
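For orientation, the prompt that format_prompt_with_history assembles is a flat transcript: the system prompt, then alternating human/bot turns, with the final bot turn left open for the model to complete. A minimal sketch, under the assumption that HUMAN_ID and BOT_ID (defined elsewhere in the script) are the "<human>" and "<bot>" tags mentioned in the system prompt:

# Illustrative sketch only: HUMAN_ID and BOT_ID are assumed values; the real
# constants live elsewhere in the chat script.
HUMAN_ID = "<human>"
BOT_ID = "<bot>"

SYSTEM_PROMPT = ("A chat between a curious human <human> and an artificial intelligence "
                 "assistant <bot>. The assistant gives helpful, detailed, and polite "
                 "answers to the human's questions.")

def format_prompt_with_history(input_str, chat_history):
    # Mirrors the committed helper: system prompt, then past turns,
    # ending with an open bot tag for generation.
    prompt = [f"{SYSTEM_PROMPT}\n"]
    for history_input_str, history_output_str in chat_history:
        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
    return "".join(prompt)

history = [("What is IPEX-LLM?", "A library for running LLMs on Intel CPUs and GPUs.")]
print(format_prompt_with_history("Does it support Baichuan-13B-Chat?", history))
# <system prompt>
# <human> What is IPEX-LLM?
# <bot> A library for running LLMs on Intel CPUs and GPUs.
# <human> Does it support Baichuan-13B-Chat?
# <bot>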
@@ -276,10 +321,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", type=str, help="path to an llm")
     parser.add_argument("--start-size", type=int, default=4, help="start_size of kv_cahce")
+    parser.add_argument("--recent-size", type=int, default=2000)
     args = parser.parse_args()
 
     model_path = args.model_path
     start_size = args.start_size
+    recent_size = args.recent_size
 
     model = auto_select_model(model_path)
     model = optimize_model(model)
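The new --recent-size flag (default 2000) complements --start-size; together they bound the sliding KV cache used by the non-Baichuan branches below. A small sketch of the resulting arithmetic, assuming StartRecentKVCache keeps start_size initial positions plus recent_size recent positions (the usual attention-sink recipe); the model path here is a placeholder, not taken from the commit:

import argparse

# Re-creates the argument setup above to show the new flag in use.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, help="path to an llm")
parser.add_argument("--start-size", type=int, default=4)
parser.add_argument("--recent-size", type=int, default=2000)

args = parser.parse_args(["--model-path", "/models/Baichuan2-7B-Chat",
                          "--recent-size", "1024"])

# Under the attention-sink layout, the cache holds at most
# start_size + recent_size positions per layer.
print(args.start_size + args.recent_size)  # 4 + 1024 = 1028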
@@ -288,19 +335,25 @@ if __name__ == "__main__":
 
     if model.config.architectures is not None and model.config.architectures[0] == "QWenLMHeadModel":
         stop_words = get_stop_words_ids("Qwen", tokenizer=tokenizer)
-        kv_cache = StartRecentKVCache(start_size=start_size, k_seq_dim=1, v_seq_dim=1)
+        kv_cache = StartRecentKVCache(start_size=start_size,
+                                      k_seq_dim=1,
+                                      v_seq_dim=1,
+                                      recent_size=recent_size)
         qwen_stream_chat(model=model, tokenizer=tokenizer,kv_cache=kv_cache, stop_words=stop_words)
     elif model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
         chatglm3_stream_chat(model=model, tokenizer=tokenizer)
     elif model.config.architectures is not None and model.config.architectures[0] == "LlamaForCausalLM":
-        kv_cache = StartRecentKVCache(start_size=start_size)
+        kv_cache = StartRecentKVCache(start_size=start_size, recent_size=recent_size)
         if "yi" in model_path.lower():
             stop_words = get_stop_words_ids("Yi", tokenizer=tokenizer)
             yi_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache, stop_words=stop_words)
         else:
             llama_stream_chat(model=model, tokenizer=tokenizer, kv_cache=kv_cache)
+    elif model.config.architectures[0] == "BaichuanForCausalLM" and model.config.vocab_size == 64000:
+        # Baichuan-13B-Chat
+        stream_chat_with_history(model=model, tokenizer=tokenizer)
     else:
-        kv_cache = StartRecentKVCache(start_size=start_size)
+        kv_cache = StartRecentKVCache(start_size=start_size, recent_size=recent_size)
         stream_chat(model=model,
                     tokenizer=tokenizer,
                     kv_cache=kv_cache)
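Two notes on the dispatch above. Baichuan-13B-Chat shares the "BaichuanForCausalLM" architecture string with Baichuan2-7B-Chat, so the new branch keys on the vocabulary size (64000 for Baichuan-13B-Chat, larger for Baichuan2) and routes the 13B model to the plain stream_chat_with_history path instead of the StartRecentKVCache path, which appears to be the source of the transformers 4.36 issue the commit title refers to.

Second, StartRecentKVCache itself is defined elsewhere in the repo; judging by its name and parameters it follows the StreamingLLM recipe of keeping the first start_size positions (the attention sink) plus the most recent recent_size positions. A self-contained sketch of that eviction idea, with an assumed seq_dim matching the k_seq_dim=1 / v_seq_dim=1 used in the Qwen branch above:

import torch

def evict_middle(k, v, start_size=4, recent_size=2000, seq_dim=1):
    """Keep the first `start_size` and last `recent_size` positions along seq_dim,
    dropping everything in between (illustrative sketch of the start+recent policy)."""
    seq_len = k.size(seq_dim)
    if seq_len <= start_size + recent_size:
        return k, v  # nothing to evict yet
    keep = torch.cat([torch.arange(start_size),
                      torch.arange(seq_len - recent_size, seq_len)])
    return k.index_select(seq_dim, keep), v.index_select(seq_dim, keep)

# Example: a fake cache of 2100 positions shrinks to 4 + 2000 = 2004.
k = torch.randn(1, 2100, 32, 128)
v = torch.randn(1, 2100, 32, 128)
k, v = evict_middle(k, v)
print(k.shape)  # torch.Size([1, 2004, 32, 128])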