LLM: improve response speed in multi-turn chat (#9299)
* update * fix stop word and add chatglm2 support * remove system prompt
This commit is contained in:
		
							parent
							
								
									d4ab5904ef
								
							
						
					
					
						commit
						8ef8e25178
					
				
					 2 changed files with 283 additions and 55 deletions
				
			
		| 
						 | 
					@ -13,6 +13,31 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
 | 
				
			||||||
 | 
					# which is licensed under the MIT license:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# MIT License
 | 
				
			||||||
 | 
					# 
 | 
				
			||||||
 | 
					# Copyright (c) 2023 MIT HAN Lab
 | 
				
			||||||
 | 
					# 
 | 
				
			||||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
				
			||||||
 | 
					# of this software and associated documentation files (the "Software"), to deal
 | 
				
			||||||
 | 
					# in the Software without restriction, including without limitation the rights
 | 
				
			||||||
 | 
					# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
				
			||||||
 | 
					# copies of the Software, and to permit persons to whom the Software is
 | 
				
			||||||
 | 
					# furnished to do so, subject to the following conditions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The above copyright notice and this permission notice shall be included in all
 | 
				
			||||||
 | 
					# copies or substantial portions of the Software.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
				
			||||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
				
			||||||
 | 
					# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
				
			||||||
 | 
					# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
				
			||||||
 | 
					# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
				
			||||||
 | 
					# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
				
			||||||
 | 
					# SOFTWARE.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
| 
						 | 
					@ -27,46 +52,98 @@ from transformers.generation.stopping_criteria import StoppingCriteriaList
 | 
				
			||||||
from colorama import Fore
 | 
					from colorama import Fore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from bigdl.llm import optimize_model
 | 
					from bigdl.llm import optimize_model
 | 
				
			||||||
 | 
					from kv_cache import StartRecentKVCache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
 | 
					 | 
				
			||||||
The assistant gives helpful, detailed, and polite answers to the human's questions."
 | 
					 | 
				
			||||||
HUMAN_ID = "<human>"
 | 
					HUMAN_ID = "<human>"
 | 
				
			||||||
BOT_ID = "<bot>"
 | 
					BOT_ID = "<bot>"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# chat_history formated in [(iput_str, output_str)]
 | 
					@torch.no_grad()
 | 
				
			||||||
def format_prompt(input_str,
 | 
					def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
 | 
				
			||||||
                  chat_history):
 | 
					 | 
				
			||||||
    prompt = [f"{SYSTEM_PROMPT}\n"]
 | 
					 | 
				
			||||||
    for history_input_str, history_output_str in chat_history:
 | 
					 | 
				
			||||||
        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
 | 
					 | 
				
			||||||
    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return "".join(prompt)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def stream_chat(model,
 | 
					 | 
				
			||||||
                tokenizer,
 | 
					 | 
				
			||||||
                stopping_criteria,
 | 
					 | 
				
			||||||
                input_str,
 | 
					 | 
				
			||||||
                chat_history):
 | 
					 | 
				
			||||||
    prompt = format_prompt(input_str, chat_history)
 | 
					 | 
				
			||||||
    # print(prompt)
 | 
					 | 
				
			||||||
    input_ids = tokenizer([prompt], return_tensors="pt")
 | 
					 | 
				
			||||||
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 | 
					 | 
				
			||||||
    generate_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, stopping_criteria=stopping_criteria)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    from threading import Thread
 | 
					 | 
				
			||||||
    # to ensure non-blocking access to the generated text, generation process should be ran in a separate thread
 | 
					 | 
				
			||||||
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
 | 
					 | 
				
			||||||
    thread.start()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    output_str = []
 | 
					 | 
				
			||||||
    print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
 | 
					    print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
 | 
				
			||||||
    for partial_output_str in streamer:
 | 
					    outputs = model(
 | 
				
			||||||
        output_str.append(partial_output_str)
 | 
					        input_ids=input_ids,
 | 
				
			||||||
        # remove the last HUMAN_ID if exists
 | 
					        past_key_values=past_key_values,
 | 
				
			||||||
        print(partial_output_str.replace(f"{HUMAN_ID}", ""), end="")
 | 
					        use_cache=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    past_key_values = outputs.past_key_values
 | 
				
			||||||
 | 
					    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
 | 
				
			||||||
 | 
					    generated_ids = [pred_token_idx.item()]
 | 
				
			||||||
 | 
					    pos = 0
 | 
				
			||||||
 | 
					    for _ in range(max_gen_len - 1):
 | 
				
			||||||
 | 
					        outputs = model(
 | 
				
			||||||
 | 
					            input_ids=pred_token_idx,
 | 
				
			||||||
 | 
					            past_key_values=past_key_values,
 | 
				
			||||||
 | 
					            use_cache=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        past_key_values = outputs.past_key_values
 | 
				
			||||||
 | 
					        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
 | 
				
			||||||
 | 
					        generated_ids.append(pred_token_idx.item())
 | 
				
			||||||
 | 
					        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True,
 | 
				
			||||||
 | 
					                                          clean_up_tokenization_spaces=True,
 | 
				
			||||||
 | 
					                                          spaces_between_special_tokens=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    chat_history.append((input_str, "".join(output_str).replace(f"{HUMAN_ID}", "").rstrip()))
 | 
					        now = len(generated_text) - 1
 | 
				
			||||||
 | 
					        if now > pos:
 | 
				
			||||||
 | 
					            if '\n<' in generated_text:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                print("".join(generated_text[pos:now]), end="", flush=True)
 | 
				
			||||||
 | 
					                pos = now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if pred_token_idx == tokenizer.eos_token_id:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					    print(" ".join(generated_text[pos:]).strip('\n<'), flush=True)
 | 
				
			||||||
 | 
					    return past_key_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512):
 | 
				
			||||||
 | 
					    past_key_values = None
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
 | 
				
			||||||
 | 
					        if user_input == "stop": # let's stop the conversation when user input "stop"
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        prompt = f"{HUMAN_ID} {user_input}\n{BOT_ID} "
 | 
				
			||||||
 | 
					        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 | 
				
			||||||
 | 
					        seq_len = input_ids.shape[1]
 | 
				
			||||||
 | 
					        if kv_cache is not None:
 | 
				
			||||||
 | 
					            space_needed = seq_len + max_gen_len
 | 
				
			||||||
 | 
					            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        past_key_values = greedy_generate(
 | 
				
			||||||
 | 
					            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def chatglm2_stream_chat(model, tokenizer):
 | 
				
			||||||
 | 
					    chat_history = []
 | 
				
			||||||
 | 
					    past_key_values = None
 | 
				
			||||||
 | 
					    current_length = 0
 | 
				
			||||||
 | 
					    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
 | 
				
			||||||
 | 
					    max_past_length = 2048
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
 | 
				
			||||||
 | 
					        if user_input == "stop": # let's stop the conversation when user input "stop"
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
 | 
				
			||||||
 | 
					        prompt = f"问:{user_input}\n答:"
 | 
				
			||||||
 | 
					        for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
 | 
				
			||||||
 | 
					                                                                         history=chat_history,
 | 
				
			||||||
 | 
					                                                                         stopping_criteria=stopping_criteria,
 | 
				
			||||||
 | 
					                                                                         past_key_values=past_key_values,
 | 
				
			||||||
 | 
					                                                                         return_past_key_values=True):
 | 
				
			||||||
 | 
					            print(response[current_length:], end="", flush=True)
 | 
				
			||||||
 | 
					            current_length = len(response)
 | 
				
			||||||
 | 
					            if past_key_values[0][0].shape[0] > max_past_length:
 | 
				
			||||||
 | 
					                # To avoid out of memory, only keep recent key_values
 | 
				
			||||||
 | 
					                new_values_list = []
 | 
				
			||||||
 | 
					                for i in range(len(past_key_values)):
 | 
				
			||||||
 | 
					                    new_value = []
 | 
				
			||||||
 | 
					                    for val in past_key_values[i]:
 | 
				
			||||||
 | 
					                        new_v = val[-max_past_length:]
 | 
				
			||||||
 | 
					                        new_value.append(new_v)
 | 
				
			||||||
 | 
					                    new_values_list.append(tuple(new_value))
 | 
				
			||||||
 | 
					                past_key_values = tuple(new_values_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def auto_select_model(model_name):
 | 
					def auto_select_model(model_name):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
| 
						 | 
					@ -89,28 +166,21 @@ def auto_select_model(model_name):
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
  
 | 
					  
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
  parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
  parser.add_argument("--model-path", type=str, help="path to an llm")
 | 
					    parser.add_argument("--model-path", type=str, help="path to an llm")
 | 
				
			||||||
  args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  model_path = args.model_path
 | 
					    model_path = args.model_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  model = auto_select_model(model_path)
 | 
					    model = auto_select_model(model_path)
 | 
				
			||||||
  model = optimize_model(model)
 | 
					    model = optimize_model(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
					    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
 | 
					    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
 | 
				
			||||||
 | 
					        chatglm2_stream_chat(model=model, tokenizer=tokenizer)
 | 
				
			||||||
  chat_history = []
 | 
					    else:
 | 
				
			||||||
 | 
					        kv_cache = StartRecentKVCache()
 | 
				
			||||||
  while True:
 | 
					        stream_chat(model=model,
 | 
				
			||||||
      with torch.inference_mode():
 | 
					                    tokenizer=tokenizer,
 | 
				
			||||||
          user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
 | 
					                    kv_cache=kv_cache)
 | 
				
			||||||
          if user_input == "stop": # let's stop the conversation when user input "stop"
 | 
					 | 
				
			||||||
              break
 | 
					 | 
				
			||||||
          stream_chat(model=model,
 | 
					 | 
				
			||||||
                      tokenizer=tokenizer,
 | 
					 | 
				
			||||||
                      stopping_criteria=stopping_criteria,
 | 
					 | 
				
			||||||
                      input_str=user_input,
 | 
					 | 
				
			||||||
                      chat_history=chat_history)
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										158
									
								
								python/llm/portable-zip/kv_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								python/llm/portable-zip/kv_cache.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,158 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/kv_cache.py
 | 
				
			||||||
 | 
					# which is licensed under the MIT license:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# MIT License
 | 
				
			||||||
 | 
					# 
 | 
				
			||||||
 | 
					# Copyright (c) 2023 MIT HAN Lab
 | 
				
			||||||
 | 
					# 
 | 
				
			||||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
				
			||||||
 | 
					# of this software and associated documentation files (the "Software"), to deal
 | 
				
			||||||
 | 
					# in the Software without restriction, including without limitation the rights
 | 
				
			||||||
 | 
					# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
				
			||||||
 | 
					# copies of the Software, and to permit persons to whom the Software is
 | 
				
			||||||
 | 
					# furnished to do so, subject to the following conditions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The above copyright notice and this permission notice shall be included in all
 | 
				
			||||||
 | 
					# copies or substantial portions of the Software.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
				
			||||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
				
			||||||
 | 
					# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
				
			||||||
 | 
					# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
				
			||||||
 | 
					# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
				
			||||||
 | 
					# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
				
			||||||
 | 
					# SOFTWARE.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def slice1d(x, start, end):
 | 
				
			||||||
 | 
					    return x[:, start:end, ...]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def slice2d(x, start, end):
 | 
				
			||||||
 | 
					    return x[:, :, start:end, ...]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def slice3d(x, start, end):
 | 
				
			||||||
 | 
					    return x[:, :, :, start:end, ...]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DIM_TO_SLICE = {
 | 
				
			||||||
 | 
					    1: slice1d,
 | 
				
			||||||
 | 
					    2: slice2d,
 | 
				
			||||||
 | 
					    3: slice3d,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StartRecentKVCache:
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        start_size=4,
 | 
				
			||||||
 | 
					        recent_size=512,
 | 
				
			||||||
 | 
					        k_seq_dim=2,
 | 
				
			||||||
 | 
					        v_seq_dim=2,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        print(f"StartRecentKVCache: {start_size}, {recent_size}")
 | 
				
			||||||
 | 
					        self.start_size = start_size
 | 
				
			||||||
 | 
					        self.recent_size = recent_size
 | 
				
			||||||
 | 
					        self.cache_size = start_size + recent_size
 | 
				
			||||||
 | 
					        self.k_seq_dim = k_seq_dim
 | 
				
			||||||
 | 
					        self.v_seq_dim = v_seq_dim
 | 
				
			||||||
 | 
					        self.k_slice = DIM_TO_SLICE[k_seq_dim]
 | 
				
			||||||
 | 
					        self.v_slice = DIM_TO_SLICE[v_seq_dim]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, past_key_values):
 | 
				
			||||||
 | 
					        if past_key_values is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        seq_len = past_key_values[0][0].size(self.k_seq_dim)
 | 
				
			||||||
 | 
					        if seq_len <= self.cache_size:
 | 
				
			||||||
 | 
					            return past_key_values
 | 
				
			||||||
 | 
					        return [
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.k_slice(k, 0, self.start_size),
 | 
				
			||||||
 | 
					                        self.k_slice(k, seq_len - self.recent_size, seq_len),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.k_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.v_slice(v, 0, self.start_size),
 | 
				
			||||||
 | 
					                        self.v_slice(v, seq_len - self.recent_size, seq_len),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.v_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            for k, v in past_key_values
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def evict_for_space(self, past_key_values, num_coming):
 | 
				
			||||||
 | 
					        if past_key_values is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        seq_len = past_key_values[0][0].size(self.k_seq_dim)
 | 
				
			||||||
 | 
					        if seq_len + num_coming <= self.cache_size:
 | 
				
			||||||
 | 
					            return past_key_values
 | 
				
			||||||
 | 
					        return [
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.k_slice(k, 0, self.start_size),
 | 
				
			||||||
 | 
					                        self.k_slice(
 | 
				
			||||||
 | 
					                            k, seq_len - self.recent_size + num_coming, seq_len
 | 
				
			||||||
 | 
					                        ),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.k_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.v_slice(v, 0, self.start_size),
 | 
				
			||||||
 | 
					                        self.v_slice(
 | 
				
			||||||
 | 
					                            v, seq_len - self.recent_size + num_coming, seq_len
 | 
				
			||||||
 | 
					                        ),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.v_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            for k, v in past_key_values
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def evict_range(self, past_key_values, start, end):
 | 
				
			||||||
 | 
					        if past_key_values is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        seq_len = past_key_values[0][0].size(self.k_seq_dim)
 | 
				
			||||||
 | 
					        assert start <= end and end <= seq_len
 | 
				
			||||||
 | 
					        return [
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.k_slice(k, 0, start),
 | 
				
			||||||
 | 
					                        self.k_slice(k, end, seq_len),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.k_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                torch.cat(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        self.v_slice(v, 0, start),
 | 
				
			||||||
 | 
					                        self.v_slice(v, end, seq_len),
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    dim=self.v_seq_dim,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            for k, v in past_key_values
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
		Loading…
	
		Reference in a new issue