LLM: Support speculative decoding in bigdl-llm (#9951)
* first commit
* fix error, add llama example
* hidden print
* update api usage
* change to api v3
* update
* meet code review
* meet code review, fix style
* add reference, fix style
* fix style
* fix first token time
This commit is contained in:
parent 6341c498b3
commit 3e601f9a5d
3 changed files with 470 additions and 0 deletions
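For context, a minimal usage sketch (not part of the diff below) of the user-facing API this commit adds, shown here on the CPU path: per the example in this commit, speculative decoding is enabled with speculative=True together with load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on a recent Intel Xeon CPU. The model id, dtype, prompt, and generation arguments below are illustrative assumptions, not taken from this commit.

import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# bf16 CPU variant of the fp16/XPU example in this commit; kwargs mirror the diff, values are assumptions
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             optimize_model=True,
                                             torch_dtype=torch.bfloat16,  # assumption: bf16 weights on CPU
                                             load_in_low_bit="bf16",
                                             speculative=True,            # attaches a sym_int4 draft model
                                             trust_remote_code=True,
                                             use_cache=True)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
input_ids = tokenizer("What is speculative decoding?", return_tensors='pt').input_ids

with torch.inference_mode():
    # generate() is patched to dispatch to speculative_generate() when a draft model is attached
    output = model.generate(input_ids, max_new_tokens=128, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
    print(f"Tokens generated {model.n_token_generated}")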
@@ -0,0 +1,104 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import LlamaTokenizer, AutoTokenizer
import argparse
import time
import numpy as np


torch.nn.Linear.reset_parameters = lambda x: None
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# You can tune the prompt based on your own model;
# the prompt style here follows https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
LLAMA2_PROMPT_FORMAT = """### HUMAN:
[inst]{prompt}[/inst]

### RESPONSE:
"""

long_input = """In the year 2048, the world was a very different place from what it had been just two decades before. The pace of technological progress had quickened to an almost unimaginable degree, and the changes that had swept through society as a result were nothing short of revolutionary.
In many ways, the year 2048 represented the culmination of a long and tumultuous journey that humanity had been on since the dawn of civilization. The great leaps forward in science and technology that had occurred over the course of the previous century had laid the groundwork for a future that was beyond anything anyone could have imagined.
One of the most striking aspects of life in 2048 was the degree to which technology had become an integral part of nearly every aspect of daily existence. From the moment people woke up in the morning until they went to bed at night, they were surrounded by devices and systems that were powered by advanced artificial intelligence and machine learning algorithms.
In fact, it was hard to find anything in people's lives that wasn't touched by technology in some way. Every aspect of society had been transformed, from the way people communicated with one another to the way they worked, played, and even socialized. And as the years went on, it seemed as though there was no limit to what technology could achieve.
Despite all of these advances, however, not everyone was happy with the state of the world in 2048. Some people saw the increasing reliance on technology as a sign that humanity was losing touch with its own humanity, and they worried about the implications of this for the future.
Others were more pragmatic, recognizing that while technology had brought many benefits, it also posed new challenges and risks that needed to be addressed. As a result, there was a growing movement of people who were working to ensure that the advances of technology were used in ways that were safe, ethical, and beneficial for everyone.
One person who was at the forefront of this movement was a young woman named Maya. Maya was a brilliant and ambitious researcher who had dedicated her life to understanding the implications of emerging technologies like artificial intelligence and biotechnology. She was deeply concerned about the potential risks and unintended consequences of these technologies, and she worked tirelessly to raise awareness about the need for responsible innovation.
Maya's work had earned her a reputation as one of the most influential voices in the field of technology and ethics, and she was widely respected for her deep understanding of the issues and her ability to communicate complex ideas in ways that were accessible and engaging. She was also known for her passionate and inspiring speeches, which often left her audiences with a sense of purpose and determination to make the world a better place through their own efforts.
One day, Maya received an invitation to speak at a major conference on technology and ethics, which was being held in a large convention center in the heart of the city. The conference was expected to attract thousands of people from all over the world, and there was a great deal of excitement and anticipation about what Maya would say.
As she prepared for her speech, Maya knew that she had a big responsibility on her shoulders. She felt a deep sense of obligation to use her platform to inspire others to take action and make a difference in the world, and she was determined to do everything in her power to live up to this responsibility.
When the day of the conference arrived, Maya was filled with a mixture of excitement and nerves. She spent hours rehearsing her speech and fine-tuning her ideas, making sure that she had everything just right. Finally, after what felt like an eternity, it was time for her to take the stage.
As she stepped up to the podium, Maya could feel the energy of the crowd surging around her. She took a deep breath and began to speak, her voice strong and clear as she outlined the challenges and opportunities facing society in the age of technology. She spoke passionately about the need for responsible innovation and the importance of considering the ethical implications of our actions, and she inspired many people in the audience to take up this cause and make a difference in their own lives.
Overall, Maya's speech was a resounding success, and she received countless messages of gratitude and appreciation from those who had heard her speak. She knew that there was still much work to be done, but she felt hopeful about the future and the role that technology could play in creating a better world for all.
As Maya left the stage and made her way back to her seat, she couldn't help but feel a sense of pride and accomplishment at what she had just accomplished. She knew that her words had the power to inspire others and make a real difference in the world, and she was grateful for the opportunity to have played a part in this important work.
For Maya, the future was full of promise and possibility, and she was determined to continue doing everything in her power to help create a brighter, more ethical world for everyone.
As she """


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default=long_input,
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=128,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    # Load the model in optimized fp16 here.
    # Set `speculative=True` to enable speculative decoding;
    # it only works with load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on the latest Intel Xeon CPU.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 optimize_model=True,
                                                 torch_dtype=torch.float16,
                                                 load_in_low_bit="fp16",
                                                 speculative=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

    tokenizer = LlamaTokenizer.from_pretrained(model_path)

    with torch.inference_mode():
        prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)

        # warmup
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                do_sample=False)
        output_str = tokenizer.decode(output[0])

        # speculative decoding
        st = time.perf_counter()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                do_sample=False)
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        torch.xpu.synchronize()
        end = time.perf_counter()

        print(output_str)
        print(f"Tokens generated {model.n_token_generated}")
        print(f"E2E Generation time {(end - st):.4f}s")
        print(f"First token latency {model.first_token_time:.4f}s")

@@ -49,6 +49,7 @@ import torch
import warnings
import copy
from .utils import logger
from .speculative import speculative_generate, clear_benchmarks


def save_low_bit(self, *args, **kwargs):
@@ -115,6 +116,8 @@ class _BaseAutoModelClass:
        Default to be ``True``.
        :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
            conducting model optimizations. Default to be ``None``.
        :param speculative: boolean value, whether to use speculative decoding.
            Default to be ``False``.
        :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
            to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
        :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
@@ -136,6 +139,7 @@ class _BaseAutoModelClass:
        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
        optimize_model = kwargs.pop("optimize_model", True)
        user_quantization_config = kwargs.pop("quantization_config", None)
        speculative = kwargs.pop("speculative", False)

        if user_quantization_config is not None and \
                "BitsAndBytesConfig" in str(user_quantization_config.__class__):
@@ -241,6 +245,15 @@ class _BaseAutoModelClass:
                kwargs["pretraining_tp"] = 1
            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

            if speculative:
                # load a sym_int4 model as the draft model
                draft_model = cls.load_convert('sym_int4', optimize_model, *args, **kwargs)
                model.draft_model = draft_model
                import types
                # dynamically add speculative_generate to the pretrained model
                model.clear_benchmarks = types.MethodType(clear_benchmarks, model)
                model.speculative_generate = types.MethodType(speculative_generate, model)
        else:
            # load default
            model = cls.HF_Model.from_pretrained(*args, **kwargs)

353  python/llm/src/bigdl/llm/transformers/speculative.py  Normal file
@@ -0,0 +1,353 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
#

import torch
import time
import os
import copy
import logging
import warnings
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
from bigdl.llm.utils.common import invalidInputError

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    if hasattr(self, "draft_model"):
        # do speculative decoding
        # TODO: maybe add another way to double-check
        new_speculative_kwargs = {}
        for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                    'top_k', 'top_p', 'temperature', 'hf_adjust',
                    'auto_th_stop_draft', 'auto_parameters']:
            value = kwargs.pop(var, None)
            if value is not None:
                new_speculative_kwargs[var] = value
        return self.speculative_generate(input_ids=inputs,
                                         draft_model=self.draft_model,
                                         **new_speculative_kwargs)
    else:
        return original_generate(self,
                                 inputs=inputs,
                                 generation_config=generation_config,
                                 logits_processor=logits_processor,
                                 stopping_criteria=stopping_criteria,
                                 prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                 synced_gpus=synced_gpus,
                                 assistant_model=assistant_model,
                                 streamer=streamer,
                                 **kwargs)

GenerationMixin.generate = generate


def sample(logits, return_probs: bool=False, do_sample: bool=False, top_k: int=50,
           top_p: float=0.7, temperature: float=0.7):

    if return_probs:
        all_probs = logits.softmax(-1)
        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
                                            top_k=top_k, top_p=top_p)
            output_ids = torch.multinomial(_logits.softmax(-1),
                                           num_samples=1).view(logits.shape[:-1])
            probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
        else:
            probs, output_ids = torch.max(all_probs, dim=-1)
        return output_ids, probs
    else:
        if do_sample and top_k != 1 and top_p != 0.0 and temperature != 0.0:
            _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
                                            top_k=top_k, top_p=top_p)
            output_ids = torch.multinomial(_logits.softmax(-1),
                                           num_samples=1).view(logits.shape[:-1])
        else:
            output_ids = torch.argmax(logits, dim=-1)
        return output_ids


def clear_benchmarks(self):
    self.first_token_time = 0
    self.generate_time = []
    self.draft_time = []
    self.verify_time = []
    self.draft_num = []
    self.accept_num = []
    self.n_drafted = 0
    self.n_matched = 0


@torch.no_grad()
def speculative_generate(self,
                         input_ids: Optional[torch.Tensor] = None,
                         draft_model=None,
                         max_new_tokens=10,
                         max_step_draft=8,
                         th_stop_draft=0.8,
                         auto_th_stop_draft=True,
                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                         do_sample=False,
                         top_k=0,
                         top_p=0.85,
                         temperature=0.2,
                         hf_adjust=False):
    invalidInputError(draft_model is not None,
                      "Draft model should be provided.")
    step = 0
    step_draft = 0
    step_verify = 0

    draft_gen_length = max_step_draft + 6 if hf_adjust else max_step_draft + 1
    current_input_ids = input_ids
    generate_ids = torch.empty([input_ids.size(0), max_new_tokens+max_step_draft],
                               dtype=torch.long, device=self.device)
    draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                     dtype=torch.long, device=self.device)
    past_key_values = None

    tmp_matchness = 0
    e2e_tic = 0.0

    self.clear_benchmarks()

    if self.config.model_type == "qwen":
        from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
        logit_processor = RepetitionPenaltyLogitsProcessor(
            penalty=self.generation_config.repetition_penalty)
    # Example with k = 3 drafts:
    # Target model forward for the first token
    # Step 1. target_model(prompt) -> a
    # Draft model generates k drafts
    # Step 2. draft_model(a) -> b, c, d
    # Target model verifies the k drafts in one forward, giving k + 1 results
    # Step 3. target_model(a, b, c, d) -> b, c, e, f
    # Compare the drafts with the results
    # Step 4. (b, c, e) vs (b, c, d) -> b, c match, d is rejected
    # The correction e is still accepted and becomes the next input, just like a
    # Step 5. Final -> b, c, e
    while True:
        if step >= max_new_tokens:
            break

        if step == 0:
            # the first token is generated by the full (target) model
            tic = time.time()
            output = self(input_ids=current_input_ids,
                          past_key_values=past_key_values,
                          return_dict=True,
                          use_cache=True)
            logits = output['logits']
            logits = logits[:, -1:]
            if self.config.model_type == "qwen":
                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step]), dim=-1)
                logits[:, -1, :] = logit_processor(temp_input_ids, logits[:, -1, :])
            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
                                top_p=top_p, temperature=temperature)
            generate_ids[:, step] = output_ids
            current_input_ids = output_ids
            past_key_values = output['past_key_values']
            step += 1
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.first_token_time = toc - tic
            e2e_tic = time.time()
        else:
            draft_current_input_ids = current_input_ids
            # pass the target model's KV cache to the draft model
            draft_past_key_values = past_key_values
            draft_generate_ids[:, 0] = current_input_ids
            tic = time.time()
            # the draft model generates up to k tokens auto-regressively,
            # stopping early when the draft prob drops below th_stop_draft
            for step_draft in range(max_step_draft):
                if self.config.model_type == "chatglm":
                    past_key_value_len = past_key_values[0][0].shape[0]
                    position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                    draft_output = draft_model(input_ids=draft_current_input_ids,
                                               past_key_values=draft_past_key_values,
                                               return_dict=True,
                                               use_cache=True,
                                               position_ids=position_ids)
                else:
                    draft_output = draft_model(input_ids=draft_current_input_ids,
                                               past_key_values=draft_past_key_values,
                                               return_dict=True,
                                               use_cache=True)
                if self.config.model_type == "qwen":
                    temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                                draft_generate_ids[:, 1:step_draft+1]), dim=-1)
                    draft_output['logits'][:, -1, :] = logit_processor(
                        temp_input_ids,
                        draft_output['logits'][:, -1, :])
                draft_output_ids, draft_output_probs = sample(
                    draft_output['logits'], return_probs=True, do_sample=do_sample,
                    top_k=top_k, top_p=top_p, temperature=temperature)
                draft_generate_ids[:, step_draft+1] = draft_output_ids
                draft_current_input_ids = draft_output_ids
                draft_past_key_values = draft_output['past_key_values']
                # stop drafting if the draft prob is less than th_stop_draft,
                # or if drafted tokens + step would reach the max output token number
                if draft_output_probs.item() < th_stop_draft or \
                        step + step_draft + 2 >= max_new_tokens:
                    break
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.draft_time.append(toc - tic)
            drafted_n_tokens = step_draft + 1
            # draft input + draft completion
            drafted_input_ids = draft_generate_ids[:, :drafted_n_tokens+1]
            self.draft_num.append(drafted_n_tokens)
            tic = time.time()
            # The target model verifies the drafts in one forward pass:
            # input size is k + 1 (1 previous token + k drafts),
            # verified output size is k + 1 (k tokens + 1 final token),
            # and at least one verified token is always accepted
            if self.config.model_type == "chatglm":
                past_key_value_len = past_key_values[0][0].shape[0]
                position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
                                            device=drafted_input_ids.device)
                position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                output = self(input_ids=drafted_input_ids,
                              past_key_values=past_key_values,
                              return_dict=True,
                              use_cache=True,
                              position_ids=position_ids)
            else:
                output = self(input_ids=drafted_input_ids,
                              past_key_values=past_key_values,
                              return_dict=True,
                              use_cache=True)
            logits = output['logits']
            if self.config.model_type == "qwen":
                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                            draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
                for i in range(logits.size(1)):
                    logits[:, i, :] = logit_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
                                                      output['logits'][:, i, :])
            output_ids = sample(logits, do_sample=do_sample, top_k=top_k,
                                top_p=top_p, temperature=temperature)
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
            toc = time.time()
            self.verify_time.append(toc - tic)
            self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

            past_key_values = output['past_key_values']
            # Compare the drafts with the target model's verified outputs:
            # drafts occupy positions [1, k] of drafted_input_ids,
            # verified outputs occupy positions [0, k - 1] of output_ids,
            # and the count includes the one token generated by the target model itself
            max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
            max_matched = max_matched.sum(-1).item() + 1
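            # e.g. with drafts (b, c, d) and verified outputs (b, c, e, f):
            # output_ids[:, :-1] = (b, c, e) vs drafted_input_ids[:, 1:] = (b, c, d)
            # gives 2 leading matches, so max_matched = 2 + 1 = 3 and (b, c, e) are kept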
            max_of_max_matched = output_ids.size(1)
            # the accepted number is max_matched, with a minimum of 1
            self.accept_num.append(max_matched)
            # clean up the target model's KV cache
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                # For Qwen
                if self.config.model_type == "qwen":
                    past_key_values = [
                        (k[:, :-(max_of_max_matched - max_matched), :],
                         v[:, :-(max_of_max_matched - max_matched), :])
                        for k, v in past_key_values
                    ]
                elif self.config.model_type == "chatglm":
                    # for chatglm, cache shape is [sl, bs, nh, hn]
                    past_key_values = [
                        (k[:-(max_of_max_matched - max_matched), :, :, :],
                         v[:-(max_of_max_matched - max_matched), :, :, :])
                        for k, v in past_key_values
                    ]
                elif self.config.model_type == "baichuan":
                    past_key_values = [
                        (k[:, :, :-(max_of_max_matched - max_matched), :],
                         v[:, :, :-(max_of_max_matched - max_matched), :])
                        for k, v in past_key_values
                    ]
                else:
                    past_key_values = [
                        (k[:, :, :-(max_of_max_matched - max_matched)],
                         v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
                    ]

            generate_ids[:, step:step+output_ids.size(1)] = output_ids
            current_input_ids = output_ids[:, -1:]

            step += output_ids.size(1)

            # exclude the one token generated by the target (base) model itself
            self.n_matched += max_matched - 1
            self.n_drafted += drafted_n_tokens
            step_verify += 1

            if auto_th_stop_draft and step_verify % auto_parameters[0] == 0:
                tmp_matchness = auto_parameters[1]*(tmp_matchness) + \
                    (1-auto_parameters[1])*((max_matched - 1)/drafted_n_tokens)
                if tmp_matchness < auto_parameters[2]:
                    new_th_stop_draft = th_stop_draft+auto_parameters[3]
                else:
                    if drafted_n_tokens == max_step_draft:
                        new_th_stop_draft = th_stop_draft
                    else:
                        new_th_stop_draft = th_stop_draft - auto_parameters[3]
                th_stop_draft = auto_parameters[4] * th_stop_draft + \
                    (1-auto_parameters[4]) * new_th_stop_draft

            if hf_adjust:
                if (max_matched - 1) == max_step_draft:
                    max_step_draft = min(draft_gen_length - 1, max_step_draft + 1)
                else:
                    max_step_draft = max(1, max_step_draft - 1)

            # stop on eos and remove any content after eos
            output_ids_list = output_ids[0].tolist()
            if self.config.eos_token_id in output_ids_list:
                idx = output_ids_list.index(self.config.eos_token_id)
                step -= (len(output_ids_list) - idx - 1)
                break

    step = min(step, max_new_tokens)
    e2e_toc = time.time()
    self.n_token_generated = step
    self.e2e_time_without_first = e2e_toc - e2e_tic

    generate_ids = torch.cat([input_ids, generate_ids[:, :step]], dim=-1)

    return generate_ids