From 8ef8e251786b9f9d526de1be9b302d1d833334a7 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Wed, 1 Nov 2023 10:30:44 +0800
Subject: [PATCH] LLM: improve response speed in multi-turn chat (#9299)

* update

* fix stop word and add chatglm2 support

* remove system prompt
---
 python/llm/portable-zip/chat.py     | 180 +++++++++++++++++++---------
 python/llm/portable-zip/kv_cache.py | 158 ++++++++++++++++++++++++
 2 files changed, 283 insertions(+), 55 deletions(-)
 create mode 100644 python/llm/portable-zip/kv_cache.py

diff --git a/python/llm/portable-zip/chat.py b/python/llm/portable-zip/chat.py
index 8a282a97..cecf9700 100644
--- a/python/llm/portable-zip/chat.py
+++ b/python/llm/portable-zip/chat.py
@@ -13,6 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# Some parts of this file is adapted from
+# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
+# which is licensed under the MIT license:
+#
+# MIT License
+#
+# Copyright (c) 2023 MIT HAN Lab
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
 
 import torch
 import argparse
@@ -27,46 +52,98 @@ from transformers.generation.stopping_criteria import StoppingCriteriaList
 from colorama import Fore
 
 from bigdl.llm import optimize_model
+from kv_cache import StartRecentKVCache
 
-SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
-The assistant gives helpful, detailed, and polite answers to the human's questions."
 HUMAN_ID = "<human>"
 BOT_ID = "<bot>"
 
-# chat_history formated in [(iput_str, output_str)]
-def format_prompt(input_str,
-                  chat_history):
-    prompt = [f"{SYSTEM_PROMPT}\n"]
-    for history_input_str, history_output_str in chat_history:
-        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
-    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
-
-    return "".join(prompt)
-
-def stream_chat(model,
-                tokenizer,
-                stopping_criteria,
-                input_str,
-                chat_history):
-    prompt = format_prompt(input_str, chat_history)
-    # print(prompt)
-    input_ids = tokenizer([prompt], return_tensors="pt")
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, stopping_criteria=stopping_criteria)
-
-    from threading import Thread
-    # to ensure non-blocking access to the generated text, generation process should be ran in a separate thread
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
-    thread.start()
-
-    output_str = []
+@torch.no_grad()
+def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
     print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
-    for partial_output_str in streamer:
-        output_str.append(partial_output_str)
-        # remove the last HUMAN_ID if exists
-        print(partial_output_str.replace(f"{HUMAN_ID}", ""), end="")
+    outputs = model(
+        input_ids=input_ids,
+        past_key_values=past_key_values,
+        use_cache=True,
+    )
+    past_key_values = outputs.past_key_values
+    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
+    generated_ids = [pred_token_idx.item()]
+    pos = 0
+    for _ in range(max_gen_len - 1):
+        outputs = model(
+            input_ids=pred_token_idx,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        past_key_values = outputs.past_key_values
+        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
+        generated_ids.append(pred_token_idx.item())
+        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True,
                                          clean_up_tokenization_spaces=True,
                                          spaces_between_special_tokens=False)
 
-    chat_history.append((input_str, "".join(output_str).replace(f"{HUMAN_ID}", "").rstrip()))
+        now = len(generated_text) - 1
+        if now > pos:
+            if '\n<' in generated_text:
+                break
+            else:
+                print("".join(generated_text[pos:now]), end="", flush=True)
+            pos = now
+
+        if pred_token_idx == tokenizer.eos_token_id:
+            break
+    print(" ".join(generated_text[pos:]).strip('\n<'), flush=True)
+    return past_key_values
+
+@torch.no_grad()
+def stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512):
+    past_key_values = None
+    while True:
+        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+        if user_input == "stop": # let's stop the conversation when user input "stop"
+            break
+        prompt = f"{HUMAN_ID} {user_input}\n{BOT_ID} "
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        seq_len = input_ids.shape[1]
+        if kv_cache is not None:
+            space_needed = seq_len + max_gen_len
+            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
+
+        past_key_values = greedy_generate(
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
+        )
+
+@torch.no_grad()
+def chatglm2_stream_chat(model, tokenizer):
+    chat_history = []
+    past_key_values = None
+    current_length = 0
+    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
+    max_past_length = 2048
+
+    while True:
+        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+        if user_input == "stop": # let's stop the conversation when user input "stop"
+            break
+        print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
+        prompt = f"问:{user_input}\n答:"
+        for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
+                                                                         history=chat_history,
+                                                                         stopping_criteria=stopping_criteria,
+                                                                         past_key_values=past_key_values,
+                                                                         return_past_key_values=True):
+            print(response[current_length:], end="", flush=True)
+            current_length = len(response)
+        if past_key_values[0][0].shape[0] > max_past_length:
+            # To avoid out of memory, only keep recent key_values
+            new_values_list = []
+            for i in range(len(past_key_values)):
+                new_value = []
+                for val in past_key_values[i]:
+                    new_v = val[-max_past_length:]
+                    new_value.append(new_v)
+                new_values_list.append(tuple(new_value))
+            past_key_values = tuple(new_values_list)
 
 def auto_select_model(model_name):
     try:
@@ -89,28 +166,21 @@ def auto_select_model(model_name):
     return model
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-path", type=str, help="path to an llm")
-    args = parser.parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, help="path to an llm")
+    args = parser.parse_args()
 
-    model_path = args.model_path
-
-    model = auto_select_model(model_path)
-    model = optimize_model(model)
+    model_path = args.model_path
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = auto_select_model(model_path)
+    model = optimize_model(model)
 
-    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
-    chat_history = []
-
-    while True:
-        with torch.inference_mode():
-            user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
-            if user_input == "stop": # let's stop the conversation when user input "stop"
-                break
-            stream_chat(model=model,
-                        tokenizer=tokenizer,
-                        stopping_criteria=stopping_criteria,
-                        input_str=user_input,
-                        chat_history=chat_history)
\ No newline at end of file
+    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
+        chatglm2_stream_chat(model=model, tokenizer=tokenizer)
+    else:
+        kv_cache = StartRecentKVCache()
+        stream_chat(model=model,
+                    tokenizer=tokenizer,
+                    kv_cache=kv_cache)
diff --git a/python/llm/portable-zip/kv_cache.py b/python/llm/portable-zip/kv_cache.py
new file mode 100644
index 00000000..d372cbce
--- /dev/null
+++ b/python/llm/portable-zip/kv_cache.py
@@ -0,0 +1,158 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/kv_cache.py
+# which is licensed under the MIT license:
+#
+# MIT License
+#
+# Copyright (c) 2023 MIT HAN Lab
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+
+def slice1d(x, start, end):
+    return x[:, start:end, ...]
+
+def slice2d(x, start, end):
+    return x[:, :, start:end, ...]
+
+def slice3d(x, start, end):
+    return x[:, :, :, start:end, ...]
+
+
+DIM_TO_SLICE = {
+    1: slice1d,
+    2: slice2d,
+    3: slice3d,
+}
+
+
+class StartRecentKVCache:
+    def __init__(
+        self,
+        start_size=4,
+        recent_size=512,
+        k_seq_dim=2,
+        v_seq_dim=2,
+    ):
+        print(f"StartRecentKVCache: {start_size}, {recent_size}")
+        self.start_size = start_size
+        self.recent_size = recent_size
+        self.cache_size = start_size + recent_size
+        self.k_seq_dim = k_seq_dim
+        self.v_seq_dim = v_seq_dim
+        self.k_slice = DIM_TO_SLICE[k_seq_dim]
+        self.v_slice = DIM_TO_SLICE[v_seq_dim]
+
+    def __call__(self, past_key_values):
+        if past_key_values is None:
+            return None
+        seq_len = past_key_values[0][0].size(self.k_seq_dim)
+        if seq_len <= self.cache_size:
+            return past_key_values
+        return [
+            [
+                torch.cat(
+                    [
+                        self.k_slice(k, 0, self.start_size),
+                        self.k_slice(k, seq_len - self.recent_size, seq_len),
+                    ],
+                    dim=self.k_seq_dim,
+                ),
+                torch.cat(
+                    [
+                        self.v_slice(v, 0, self.start_size),
+                        self.v_slice(v, seq_len - self.recent_size, seq_len),
+                    ],
+                    dim=self.v_seq_dim,
+                ),
+            ]
+            for k, v in past_key_values
+        ]
+
+    def evict_for_space(self, past_key_values, num_coming):
+        if past_key_values is None:
+            return None
+        seq_len = past_key_values[0][0].size(self.k_seq_dim)
+        if seq_len + num_coming <= self.cache_size:
+            return past_key_values
+        return [
+            [
+                torch.cat(
+                    [
+                        self.k_slice(k, 0, self.start_size),
+                        self.k_slice(
+                            k, seq_len - self.recent_size + num_coming, seq_len
+                        ),
+                    ],
+                    dim=self.k_seq_dim,
+                ),
+                torch.cat(
+                    [
+                        self.v_slice(v, 0, self.start_size),
+                        self.v_slice(
+                            v, seq_len - self.recent_size + num_coming, seq_len
+                        ),
+                    ],
+                    dim=self.v_seq_dim,
+                ),
+            ]
+            for k, v in past_key_values
+        ]
+
+    def evict_range(self, past_key_values, start, end):
+        if past_key_values is None:
+            return None
+        seq_len = past_key_values[0][0].size(self.k_seq_dim)
+        assert start <= end and end <= seq_len
+        return [
+            [
+                torch.cat(
+                    [
+                        self.k_slice(k, 0, start),
+                        self.k_slice(k, end, seq_len),
+                    ],
+                    dim=self.k_seq_dim,
+                ),
+                torch.cat(
+                    [
+                        self.v_slice(v, 0, start),
+                        self.v_slice(v, end, seq_len),
+                    ],
+                    dim=self.v_seq_dim,
+                ),
+            ]
+            for k, v in past_key_values
+        ]
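
Note (illustration only, not part of the patch): the sketch below shows how StartRecentKVCache.evict_for_space is meant to be driven, mirroring the new stream_chat loop above, which reserves room for the incoming prompt plus max_gen_len tokens before every turn so the cache never outgrows start_size + recent_size. The layer count, head count and token counts here are invented for the example; only the kv_cache module added by this patch is assumed.

import torch
from kv_cache import StartRecentKVCache  # the module added by this patch

def toy_past_key_values(seq_len, layers=2, heads=4, head_dim=8):
    # Shape [batch, heads, seq_len, head_dim]: the sequence dimension is 2,
    # matching the cache's default k_seq_dim / v_seq_dim.
    return [
        (torch.zeros(1, heads, seq_len, head_dim),
         torch.zeros(1, heads, seq_len, head_dim))
        for _ in range(layers)
    ]

kv_cache = StartRecentKVCache(start_size=4, recent_size=512)  # cache_size = 516

past = toy_past_key_values(seq_len=400)   # pretend 400 tokens are already cached
space_needed = 20 + 180                   # next prompt tokens + tokens to generate
past = kv_cache.evict_for_space(past, space_needed)

# The 4 "sink" tokens plus the most recent 312 tokens survive (316 in total), so
# after the 200 new tokens are appended the cache is back at the 516-token limit.
print(past[0][0].size(2))                 # -> 316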