Speculative Starcoder on CPU (#10138)
* Speculative Starcoder on CPU * enable kv-cache pre-allocation * refine codes * refine * fix style * fix style * fix style * refine * refine * Update speculative.py * Update gptbigcode.py * fix style * Update speculative.py * enable mixed-datatype layernorm on top of torch API * adaptive dtype * Update README.md
This commit is contained in:
parent
a47989c860
commit
36a9e88104
7 changed files with 304 additions and 5 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
# Mistral
|
# Mistral
|
||||||
In this directory, you will find examples on how you could run Baichuan2 BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
|
In this directory, you will find examples on how you could run Mistral BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
|
||||||
|
|
||||||
## 0. Requirements
|
## 0. Requirements
|
||||||
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
|
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
# Starcoder
|
||||||
|
In this directory, you will find examples on how you could run Starcoder BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) and [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py) as reference Starcoder models.
|
||||||
|
|
||||||
|
## 0. Requirements
|
||||||
|
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
|
||||||
|
|
||||||
|
## Example: Predict Tokens using `generate()` API
|
||||||
|
In the example [speculative.py](./speculative.py), we show a basic use case for a Starcoder model to predict the next N tokens using `generate()` API, with BigDL-LLM speculative decoding optimizations on Intel CPUs.
|
||||||
|
### 1. Install
|
||||||
|
We suggest using conda to manage environment:
|
||||||
|
```bash
|
||||||
|
conda create -n llm python=3.9
|
||||||
|
conda activate llm
|
||||||
|
pip install --pre --upgrade bigdl-llm[all]
|
||||||
|
pip install intel_extension_for_pytorch==2.1.0
|
||||||
|
pip install transformers==4.31.0
|
||||||
|
```
|
||||||
|
### 2. Configures high-performing processor environment variables
|
||||||
|
```bash
|
||||||
|
source bigdl-llm-init -t
|
||||||
|
export OMP_NUM_THREADS=48 # you can change 48 here to #cores of one processor socket
|
||||||
|
```
|
||||||
|
### 3. Run
|
||||||
|
|
||||||
|
We recommend to use `numactl` to bind the program to a specified processor socket:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
numactl -C 0-47 -m 0 python ./speculative.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, 0-47 means bind the python program to core list 0-47 for a 48-core socket.
|
||||||
|
|
||||||
|
Arguments info:
|
||||||
|
|
||||||
|
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `bigcode/starcoder`.
|
||||||
|
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). A default prompt is provided.
|
||||||
|
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `128`.
|
||||||
|
|
||||||
|
#### Sample Output
|
||||||
|
#### [bigcode/starcoder](https://huggingface.co/bigcode/starcoder)
|
||||||
|
|
||||||
|
```log
|
||||||
|
def dfs_print_Fibonacci_sequence(n):
|
||||||
|
if n == 0:
|
||||||
|
return
|
||||||
|
elif n == 1:
|
||||||
|
print(0)
|
||||||
|
return
|
||||||
|
elif n == 2:
|
||||||
|
print(0)
|
||||||
|
print(1)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print(0)
|
||||||
|
print(1)
|
||||||
|
dfs_print_Fibonacci_sequence(n-2)
|
||||||
|
print(dfs_Fibonacci_sequence(n-1))
|
||||||
|
|
||||||
|
def dfs_Fibonacci_sequence(n):
|
||||||
|
if n == 0:
|
||||||
|
return 0
|
||||||
|
elif n == 1:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return dfs_Fibonacci_sequence
|
||||||
|
Tokens generated 128
|
||||||
|
E2E Generation time xx.xxxxs
|
||||||
|
First token latency xx.xxxxs
|
||||||
|
```
|
||||||
|
|
||||||
|
#### [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py)
|
||||||
|
```log
|
||||||
|
def dfs_print_Fibonacci_sequence(n):
|
||||||
|
if n == 0:
|
||||||
|
return
|
||||||
|
print(n)
|
||||||
|
for i in range(2, n):
|
||||||
|
print(dfs_print_Fibonacci_sequence(i))
|
||||||
|
|
||||||
|
|
||||||
|
def dfs_print_Fibonacci_sequence_2(n):
|
||||||
|
if n == 0:
|
||||||
|
return
|
||||||
|
print(n)
|
||||||
|
for i in range(2, n):
|
||||||
|
print(dfs_print_Fibonacci_sequence_2(i))
|
||||||
|
|
||||||
|
|
||||||
|
def dfs_print_Fibonacci_sequence_3(n):
|
||||||
|
if n == 0:
|
||||||
|
return
|
||||||
|
print(n)
|
||||||
|
for i in
|
||||||
|
Tokens generated 128
|
||||||
|
E2E Generation time xx.xxxxs
|
||||||
|
First token latency xx.xxxxs
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
torch.nn.Linear.reset_parameters = lambda x: None
|
||||||
|
seed=42
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
STARCODER_PROMPT_FORMAT = "{prompt}"
|
||||||
|
prompt = "def dfs_print_Fibonacci_sequence(n):"
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Mistral model')
|
||||||
|
parser.add_argument('--repo-id-or-model-path', type=str, default="bigcode/starcoder",
|
||||||
|
help='The huggingface repo id for the Mistral (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded'
|
||||||
|
', or the path to the huggingface checkpoint folder')
|
||||||
|
parser.add_argument('--prompt', type=str, default=prompt,
|
||||||
|
help='Prompt to infer')
|
||||||
|
parser.add_argument('--n-predict', type=int, default=128,
|
||||||
|
help='Max tokens to predict')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
model_path = args.repo_id_or_model_path
|
||||||
|
|
||||||
|
# Load model in optimized bf16 here.
|
||||||
|
# Set `speculative=True`` to enable speculative decoding,
|
||||||
|
# it only works when load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on latest Intel Xeon CPU
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||||
|
optimize_model=True,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
load_in_low_bit="bf16",
|
||||||
|
speculative=True,
|
||||||
|
torchscript=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
use_cache=True)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
prompt = STARCODER_PROMPT_FORMAT.format(prompt=args.prompt)
|
||||||
|
inputs = tokenizer(prompt, return_tensors='pt')
|
||||||
|
input_ids = inputs.input_ids.to(model.device)
|
||||||
|
actual_in_len = input_ids.shape[1]
|
||||||
|
print("actual input_ids length:" + str(actual_in_len))
|
||||||
|
attention_mask = inputs.attention_mask.to(model.device)
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
output = model.generate(input_ids,
|
||||||
|
max_new_tokens=args.n_predict,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
do_sample=False)
|
||||||
|
output_str = tokenizer.decode(output[0])
|
||||||
|
|
||||||
|
# speculative decoding
|
||||||
|
st = time.perf_counter()
|
||||||
|
output = model.generate(input_ids,
|
||||||
|
max_new_tokens=args.n_predict,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
do_sample=False)
|
||||||
|
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
end = time.perf_counter()
|
||||||
|
|
||||||
|
print(output_str)
|
||||||
|
print(f"Tokens generated {model.n_token_generated}")
|
||||||
|
print(f"E2E Generation time {(end - st):.4f}s")
|
||||||
|
print(f"First token latency {model.first_token_time:.4f}s")
|
||||||
|
|
@ -1173,6 +1173,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
|
from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
|
||||||
|
from bigdl.llm.transformers.models.gptbigcode import gptbigcode_attention_forward
|
||||||
|
convert_forward(model,
|
||||||
|
module.GPTBigCodeAttention,
|
||||||
|
gptbigcode_attention_forward)
|
||||||
_attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
|
_attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
|
||||||
replace_func(model,
|
replace_func(model,
|
||||||
module.GPTBigCodeAttention,
|
module.GPTBigCodeAttention,
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,10 @@ def bloom_layer_norm_forward(self, hidden_states):
|
||||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||||
if result.nelement != 0:
|
if result.nelement != 0:
|
||||||
return result
|
return result
|
||||||
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
input_dtype = hidden_states.dtype
|
||||||
|
result = F.layer_norm(hidden_states.to(self.weight.dtype),
|
||||||
|
self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
return result.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
def bloom_attention_forward(
|
def bloom_attention_forward(
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,14 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _attn_wrapper(origin_attn):
|
def _attn_wrapper(origin_attn):
|
||||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||||
attn_output, attn_weights = origin_attn(self,
|
attn_output, attn_weights = origin_attn(self,
|
||||||
query=query,
|
query=query.to(key.dtype),
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|
@ -27,3 +31,84 @@ def _attn_wrapper(origin_attn):
|
||||||
attn_output = attn_output.clone()
|
attn_output = attn_output.clone()
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
return _attn
|
return _attn
|
||||||
|
|
||||||
|
|
||||||
|
def gptbigcode_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
layer_past: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
**kwargs):
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37." +
|
||||||
|
"Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "q_attn") or not self.is_cross_attention:
|
||||||
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
"If class is used as cross attention," +
|
||||||
|
"the weights `q_attn` have to be defined. " +
|
||||||
|
"Please make sure to instantiate class with " +
|
||||||
|
"`GPTBigCodeAttention(..., is_cross_attention=True)`."
|
||||||
|
)
|
||||||
|
query = self.q_attn(hidden_states)
|
||||||
|
key_value = self.c_attn(encoder_hidden_states)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif self.multi_query:
|
||||||
|
query, key_value = self.c_attn(hidden_states).split(
|
||||||
|
(self.embed_dim, 2 * self.kv_dim), dim=2)
|
||||||
|
else:
|
||||||
|
# Note: We split as (self.num_heads, 3, self.head_dim)
|
||||||
|
# instead of (3, self.num_heads, self.head_dim),
|
||||||
|
# i.e., the memory layout is not the same as GPT2.
|
||||||
|
# This makes the concatenation with past_key_value more efficient.
|
||||||
|
query, key_value = (
|
||||||
|
self.c_attn(hidden_states)
|
||||||
|
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.split((self.head_dim, 2 * self.head_dim), dim=3)
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
if layer_past.shape[-2] == key_value.shape[-2]:
|
||||||
|
key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||||
|
else:
|
||||||
|
fill_zeros = torch.zeros(layer_past.shape[0],
|
||||||
|
layer_past.shape[1],
|
||||||
|
key_value.shape[2] - layer_past.shape[2],
|
||||||
|
dtype=layer_past.dtype,
|
||||||
|
device=layer_past.device)
|
||||||
|
layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
|
||||||
|
key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||||
|
|
||||||
|
present = key_value if use_cache else None
|
||||||
|
|
||||||
|
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
||||||
|
|
||||||
|
attn_output, attn_weights = self._attn(query,
|
||||||
|
key.transpose(-1, -2),
|
||||||
|
value,
|
||||||
|
attention_mask,
|
||||||
|
head_mask)
|
||||||
|
|
||||||
|
if not self.multi_query:
|
||||||
|
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
||||||
|
attn_output = self.c_proj(attn_output)
|
||||||
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
|
outputs = (attn_output, present)
|
||||||
|
if output_attentions:
|
||||||
|
if self.multi_query:
|
||||||
|
attn_weights = attn_weights.transpose(1, 2)
|
||||||
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
|
||||||
|
|
@ -168,8 +168,10 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
||||||
if not _enable_ipex:
|
if not _enable_ipex:
|
||||||
len0 = past_key_values[0][0].size(0)
|
len0 = past_key_values[0][0].size(0)
|
||||||
len1 = past_key_values[0][0].size(1)
|
len1 = past_key_values[0][0].size(1)
|
||||||
len2 = past_key_values[0][0].size(2)
|
# gpt_bigcode has only 2-dimension kv
|
||||||
len3 = past_key_values[0][0].size(3)
|
if len(past_key_values[0][0].shape) == 4:
|
||||||
|
len2 = past_key_values[0][0].size(2)
|
||||||
|
len3 = past_key_values[0][0].size(3)
|
||||||
for i in range(len(past_key_values)):
|
for i in range(len(past_key_values)):
|
||||||
if self.config.model_type == "qwen":
|
if self.config.model_type == "qwen":
|
||||||
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||||
|
|
@ -195,6 +197,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
||||||
torch.float32)
|
torch.float32)
|
||||||
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
|
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
|
||||||
torch.float32)
|
torch.float32)
|
||||||
|
elif self.config.model_type == "gpt_bigcode":
|
||||||
|
kv = torch.ones(len0 + max_new_tokens, len1,
|
||||||
|
dtype=torch.float32)
|
||||||
|
past_key_values_storage.append(kv[None, :, :])
|
||||||
|
past_key_values_storage[i][0][:len0, :] = past_key_values[i][0].to(
|
||||||
|
torch.float32)
|
||||||
else:
|
else:
|
||||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
|
|
@ -266,6 +274,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values,
|
||||||
k0 = past_key_values_storage[i][0][:len0, :, :, :]
|
k0 = past_key_values_storage[i][0][:len0, :, :, :]
|
||||||
v0 = past_key_values_storage[i][1][:len0, :, :, :]
|
v0 = past_key_values_storage[i][1][:len0, :, :, :]
|
||||||
tmp_past_key_values.append((k0, v0))
|
tmp_past_key_values.append((k0, v0))
|
||||||
|
elif self.config.model_type == "gpt_bigcode":
|
||||||
|
len0 = past_key_values[0][0].size(0)
|
||||||
|
kv = past_key_values_storage[i][0][:len0, :]
|
||||||
|
tmp_past_key_values.append(kv[None, :, :])
|
||||||
else:
|
else:
|
||||||
len2 = past_key_values[0][0].size(2)
|
len2 = past_key_values[0][0].size(2)
|
||||||
k0 = past_key_values_storage[i][0][:, :, :len2, :]
|
k0 = past_key_values_storage[i][0][:, :, :len2, :]
|
||||||
|
|
@ -292,6 +304,12 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
|
||||||
past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
|
past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
|
||||||
past_key_values_storage[i][1][size:size1, :, :, :] = \
|
past_key_values_storage[i][1][size:size1, :, :, :] = \
|
||||||
past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
|
past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
|
||||||
|
elif self.config.model_type == "gpt_bigcode":
|
||||||
|
size = original_draft_past_key_values[i][0].size(0)
|
||||||
|
size1 = past_key_values[i][0].size(0)
|
||||||
|
if size < size1:
|
||||||
|
past_key_values_storage[i][0][size:size1, :] = \
|
||||||
|
past_key_values[i][0][size:size1, :].to(torch.float32)
|
||||||
else:
|
else:
|
||||||
size = original_draft_past_key_values[i][0].size(2)
|
size = original_draft_past_key_values[i][0].size(2)
|
||||||
size1 = past_key_values[i][0].size(2)
|
size1 = past_key_values[i][0].size(2)
|
||||||
|
|
@ -801,6 +819,11 @@ def speculative_generate(self,
|
||||||
v[:, :, :-(max_of_max_matched - max_matched), :])
|
v[:, :, :-(max_of_max_matched - max_matched), :])
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
|
elif self.config.model_type == "gpt_bigcode":
|
||||||
|
past_key_values = [
|
||||||
|
kv[:, :-(max_of_max_matched - max_matched)]
|
||||||
|
for kv in past_key_values
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:, :, :-(max_of_max_matched - max_matched)],
|
(k[:, :, :-(max_of_max_matched - max_matched)],
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue