# ipex-llm/python/llm/src/ipex_llm/transformers/speculative.py
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
# /utils.py
#
import torch
import time
import os
import copy
import logging
import transformers
from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import invalidInputError
from transformers.modeling_outputs import CausalLMOutputWithPast
# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16
logger = logging.getLogger("ipex_llm.speculative")
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
if hasattr(self, "draft_model"):
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()
if _enable_ipex and inputs.size(1) < 256:
logger.warning(
"IPEX_CPU optimized models have issues for speculative decoding with short prompts"
"(length < 256). Using normal generate() method instead."
)
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
'th_batch_num']:
value = kwargs.pop(var, None)
del self.draft_model
return original_generate(self,
inputs=inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs)
# do speculative decoding
# TODO: maybe add another way to double-check
new_speculative_kwargs = {}
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
'top_k', 'top_p', 'temperature', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
'attention_mask', 'min_step_draft']:
value = kwargs.pop(var, None)
if value is not None:
new_speculative_kwargs[var] = value
return self.speculative_generate(inputs=inputs,
draft_model=self.draft_model,
**new_speculative_kwargs)
else:
return original_generate(self,
inputs=inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs)
GenerationMixin.generate = generate
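# Dispatch behavior of the patched generate():
#   - if the model has a `draft_model` attribute, the draft-control kwargs
#     (max_step_draft, th_stop_draft, ...) are collected and the call is
#     routed to speculative_generate();
#   - with IPEX-optimized CPU models and prompts shorter than 256 tokens, the
#     draft model is dropped and the original HF generate() is used instead;
#   - otherwise the call falls through to the original HF generate().
# A usage sketch is included at the end of this file.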
def greedy(logits, return_probs: bool=False):
if return_probs:
all_probs = logits.softmax(-1)
probs, output_ids = torch.max(all_probs, dim=-1)
return output_ids, probs
else:
output_ids = torch.argmax(logits, dim=-1)
return output_ids
def deepmind_sample(logits, return_probs: bool=False, top_k: int=50,
top_p: float=0.7, temperature: float=0.7):
prob_list = logits_to_probs(logits, top_k=top_k, top_p=top_p, temperature=temperature)
output_ids = multinomial_sample_one_no_sync(prob_list)
if return_probs:
all_probs = logits.softmax(-1)
probs = torch.gather(all_probs, -1, output_ids.unsqueeze(-1)).squeeze(-1)
return output_ids, prob_list, probs
else:
return output_ids, prob_list
def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
                  "top_k must not be 1, and top_p and temperature must not be 0.0, "
                  "when do_sample=True")
_logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
top_k=top_k, top_p=top_p)
prob_list = _logits.softmax(-1)
return prob_list
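# multinomial_sample_one_no_sync draws one sample per row of `probs_sort`
# without a host/device synchronization: with q_i ~ Exponential(1),
# argmax_i(p_i / q_i) follows the categorical distribution p (the
# exponential-race / Gumbel-max trick), so torch.multinomial, which can
# force a sync on some devices, is not needed.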
def multinomial_sample_one_no_sync(probs_sort):
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int64)
def clear_benchmarks(self):
self.first_token_time = 0
self.generate_time = []
self.draft_time = []
self.verify_time = []
self.draft_num = []
self.accept_num = []
self.n_drafted = 0
self.n_matched = 0
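# KV-cache layouts assumed by the CPU helpers below (inferred from the
# permutes/reshapes in this file, not from upstream documentation):
#   default / llama-style : (batch, n_heads, seq_len, head_dim)
#   qwen                  : (batch, seq_len, n_heads, head_dim)
#   chatglm               : (seq_len, batch, n_heads, head_dim)
#   gpt_bigcode           : a single 2-D tensor (seq_len, kv_dim) per layer
# With IPEX (_enable_ipex), each layer's cache is a 4-tuple
# (placeholder, key_cache, value_cache, beam_idx) coming from the traced graph.
# The storage tensors are pre-allocated in fp32 with room for max_new_tokens
# extra positions so later draft steps can reuse slices without reallocating.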
def _prepare_past_key_values_storage_cpu(self, past_key_values,
max_new_tokens, _enable_ipex=False):
past_key_values_storage = []
# init ipex_past_key_values
if _enable_ipex:
ipex_past_key_values = []
cur_len = past_key_values[0][0].size(1)
if self.config.model_type == "chatglm":
len0 = past_key_values[0][1].size(0) # seq max length
len1 = past_key_values[0][1].size(1)
len2 = past_key_values[0][1].size(2)
len3 = past_key_values[0][1].size(3)
for pkv in past_key_values:
key = pkv[1]
value = pkv[2]
key = key.permute(1, 2, 0, 3).unsqueeze(-3)
key = key.expand(-1, -1, query_group_size, -1, -1)
key = key.contiguous().view(len1, len2 * query_group_size,
len0, len3).permute(2, 0, 1, 3)
value = value.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size,
len0, len3).permute(2, 0, 1, 3)
kv_pair = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
ipex_past_key_values.append(kv_pair)
elif self.config.model_type == "qwen":
ipex_past_key_values = [
[pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
for pkv in past_key_values
]
else:
ipex_past_key_values = [
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
for pkv in past_key_values
]
if not _enable_ipex:
len0 = past_key_values[0][0].size(0)
len1 = past_key_values[0][0].size(1)
# gpt_bigcode uses a 2-dimensional kv cache
if len(past_key_values[0][0].shape) == 4:
len2 = past_key_values[0][0].size(2)
len3 = past_key_values[0][0].size(3)
for i in range(len(past_key_values)):
if self.config.model_type == "qwen":
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.transpose(1, 2)
v0 = v0.transpose(1, 2)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "chatglm":
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "gpt_bigcode":
kv = torch.ones(len0 + max_new_tokens, len1,
dtype=torch.float32)
past_key_values_storage.append(kv[None, :, :])
past_key_values_storage[i][0][:len0, :] = past_key_values[i][0].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
torch.float32)
else:
len0 = past_key_values[0][1].size(1)
len1 = past_key_values[0][1].size(2)
len2 = past_key_values[0][0].size(2) # seq length
len3 = past_key_values[0][1].size(3)
for i in range(len(past_key_values)):
if self.config.model_type == "chatglm":
k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "qwen":
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(0, 2, 1, 3)
v0 = v0.permute(0, 2, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
torch.float32)
return past_key_values_storage
def _prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage, _enable_ipex):
tmp_past_key_values = []
for i in range(len(past_key_values)):
if self.config.model_type == "qwen":
len1 = past_key_values[0][0].size(1)
k0 = past_key_values_storage[i][0][:, :len1, :, :]
v0 = past_key_values_storage[i][1][:, :len1, :, :]
tmp_past_key_values.append((k0, v0))
elif self.config.model_type == "chatglm":
if not _enable_ipex:
len0 = past_key_values[0][0].size(0)
else:
len0 = past_key_values[0][0].size(1)
k0 = past_key_values_storage[i][0][:len0, :, :, :]
v0 = past_key_values_storage[i][1][:len0, :, :, :]
tmp_past_key_values.append((k0, v0))
elif self.config.model_type == "gpt_bigcode":
len0 = past_key_values[0][0].size(0)
kv = past_key_values_storage[i][0][:len0, :]
tmp_past_key_values.append(kv[None, :, :])
else:
len2 = past_key_values[0][0].size(2)
k0 = past_key_values_storage[i][0][:, :, :len2, :]
v0 = past_key_values_storage[i][1][:, :, :len2, :]
tmp_past_key_values.append((k0, v0))
return tmp_past_key_values
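# The CPU (non-IPEX) draft/verify cycle over the fp32 storage works as follows:
#   1. _prepare_past_key_values_storage_cpu copies the verified target cache
#      into the pre-allocated fp32 storage once, at the first verify step.
#   2. _prepare_draft_past_key_values_cpu hands the draft model views that
#      cover only the currently verified positions.
#   3. _update_past_key_values_storage_cpu (below) writes the newly verified
#      positions produced by the target model back into the storage, so the
#      next draft round starts from an up-to-date cache.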
def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
original_draft_past_key_values, _enable_ipex=False):
for i in range(len(past_key_values)):
if not _enable_ipex:
if self.config.model_type == "qwen":
size = original_draft_past_key_values[i][0].size(1)
size1 = past_key_values[i][0].size(1)
past_key_values_storage[i][0][:, size:size1, :, :] = \
past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
past_key_values_storage[i][1][:, size:size1, :, :] = \
past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
elif self.config.model_type == "chatglm":
size = original_draft_past_key_values[i][0].size(0)
size1 = past_key_values[i][0].size(0)
past_key_values_storage[i][0][size:size1, :, :, :] = \
past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
past_key_values_storage[i][1][size:size1, :, :, :] = \
past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
elif self.config.model_type == "gpt_bigcode":
size = original_draft_past_key_values[i][0].size(0)
size1 = past_key_values[i][0].size(0)
if size < size1:
past_key_values_storage[i][0][size:size1, :] = \
past_key_values[i][0][size:size1, :].to(torch.float32)
else:
size = original_draft_past_key_values[i][0].size(2)
size1 = past_key_values[i][0].size(2)
past_key_values_storage[i][0][:, :, size:size1, :] = \
past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
past_key_values_storage[i][1][:, :, size:size1, :] = \
past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
else:
size = original_draft_past_key_values[i][0].size(2)
size1 = past_key_values[i][0].size(1)
if self.config.model_type == "chatglm":
size = original_draft_past_key_values[0][0].size(0)
size1 = past_key_values[0][0].size(1)
len0 = past_key_values[0][1].size(0) # seq max_length
len1 = past_key_values[0][1].size(1)
len2 = past_key_values[0][1].size(2)
len3 = past_key_values[0][1].size(3)
key0 = past_key_values[i][1][size:size1, :, :, :]
value0 = past_key_values[i][2][size:size1, :, :, :]
key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
key = key.expand(-1, -1, query_group_size, -1, -1)
key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
key = key.permute(2, 0, 1, 3)
value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
value = value.permute(2, 0, 1, 3)
past_key_values_storage[i][0][size:size1, :, :, :] = \
key.to(torch.float32)
past_key_values_storage[i][1][size:size1, :, :, :] = \
value.to(torch.float32)
elif self.config.model_type == "qwen":
size = original_draft_past_key_values[0][0].size(1)
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
past_key_values_storage[i][0][:, size:size1, :, :] = \
delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, size:size1, :, :] = \
delta_past_value.to(torch.float32)
else:
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
past_key_values_storage[i][0][:, :, size:size1, :] = \
delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, :, size:size1, :] = \
delta_past_value.to(torch.float32)
def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
model_type="llama"):
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
extend_kv_cache
enough_kv_room = True
if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
"gptj", "opt"]:
return past_key_values, False
cache_k = past_key_values[0][0]
if model_type == "chatglm":
cache_k = cache_k.permute(1, 2, 0, 3)
elif model_type == "qwen":
cache_k = cache_k.transpose(1, 2)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
seq_len=max_step_draft)
bsz, num_heads, current_seq_len, head_dim = cache_k.shape
device = past_key_values[0][0].device
if not enough_kv_room:
past_key_values = list(past_key_values)
for i in range(len(past_key_values)):
cache_k = past_key_values[i][0]
cache_v = past_key_values[i][1]
if model_type == "chatglm":
cache_k = cache_k.permute(1, 2, 0, 3)
cache_v = cache_v.permute(1, 2, 0, 3)
elif model_type == "qwen":
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
new_cache_k, new_cache_v = extend_kv_cache(
bsz,
num_heads, # Support GQA
head_dim,
cache_k.size(2),
current_seq_len + max_step_draft + kv_alloc_block_len,
dtype=cache_v.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
if model_type == "chatglm":
past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
new_cache_v.permute(2, 0, 1, 3))
elif model_type == "qwen":
past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
else:
past_key_values[i] = (new_cache_k, new_cache_v)
return past_key_values, not enough_kv_room
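# _check_and_extend_kv_cache grows the target model's KV cache ahead of a draft
# round: if is_enough_kv_cache_room_4_31 reports that the current allocation
# cannot hold another max_step_draft tokens, every layer's cache is copied into
# a larger buffer with max_step_draft + kv_alloc_block_len extra slots.
# chatglm/qwen caches are permuted to the (batch, n_heads, seq_len, head_dim)
# layout before extension and permuted back afterwards; unsupported model types
# are returned unchanged.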
@torch.no_grad()
def speculative_generate(self,
inputs: Optional[torch.Tensor] = None,
draft_model=None,
max_new_tokens=10,
max_step_draft=8,
th_stop_draft=0.8,
auto_th_stop_draft=True,
auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
hf_adjust=False,
min_step_draft=3,
generation_config: Optional[GenerationConfig] = None,
attention_mask=None,
**sampling_kwargs):
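    """Speculative decoding: `draft_model` proposes up to `max_step_draft`
    tokens per step and the target model verifies them in one forward pass.

    Parameter summary (inferred from the implementation below, not from
    upstream documentation):
      draft_model        smaller/faster model that generates draft tokens (required).
      max_new_tokens     total number of new tokens to generate.
      max_step_draft     maximum draft tokens per verification step.
      th_stop_draft      stop drafting once the draft token probability falls
                         below this threshold (after min_step_draft tokens).
      auto_th_stop_draft / auto_parameters
                         adapt th_stop_draft from the smoothed acceptance rate
                         every auto_parameters[0] verification steps.
      hf_adjust          grow/shrink max_step_draft depending on whether the
                         whole draft window was accepted.
      min_step_draft     minimum number of draft tokens before early stopping.
      attention_mask     optional mask for the prompt tokens.
      sampling_kwargs    forwarded to generation_config.update() (do_sample,
                         top_k, top_p, temperature, ...).
    """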
invalidInputError(draft_model is not None,
"Draft model should be provided.")
# min_step_draft must be >= 1. Since max_step_draft may be adjusted at runtime,
# min_step_draft can end up greater than max_step_draft.
min_step_draft = min_step_draft if min_step_draft >= 1 else 1
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
# All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**sampling_kwargs)
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, "
"you may observe unexpected behavior. Please pass your input's "
"`attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:"
f"{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 2. Set generation parameters if not already defined
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
# 4. Define other model kwargs
# (unused kwargs-preparation steps from the original generate() are removed here)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: if `inputs_embeds` is used, this check does not work
# because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding "
"was detected! For correct generation results, please set "
"`padding_side='left'` when initializing the tokenizer."
)
else:
invalidInputError(False, "encoder-decoder models are not supported now.")
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# if streamer is not None:
# streamer.put(input_ids.cpu())
input_ids_length = input_ids.shape[-1]
# Here we use sample generation mode
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
step = 0
step_draft = 0
step_verify = 0
draft_gen_length = max_step_draft + 6 if hf_adjust else max_step_draft + 1
current_input_ids = input_ids
generate_ids = torch.empty([input_ids.size(0), max_new_tokens+max_step_draft],
dtype=torch.long, device=self.device)
draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
dtype=torch.long, device=self.device)
past_key_values = None
past_key_values_storage = []
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()
if _enable_ipex:
if not ((self.config.model_type == 'baichuan') or
('llama' in self.config.model_type) or
("mistral" in self.config.model_type) or
("qwen" in self.config.model_type) or
("chatglm" in self.config.model_type)):
invalidInputError(False, "BigDL Speculative Decoding with IPEX-LLM only supports \
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
if "chatglm" in self.config.model_type:
global query_group_size
query_group_size = draft_model.config.num_attention_heads // \
draft_model.config.multi_query_group_num
tmp_matchness = 0
e2e_tic = 0.0
self.clear_benchmarks()
if self.device.type == 'xpu':
torch.xpu.empty_cache()
# Example (greedy path):
# Target model forward for the first token
# Step 1. target_model(prompt) -> a
# Draft model generates k draft tokens, k = 3
# Step 2. draft_model(a) -> b, c, d
# Target model verifies the k drafts in one forward pass over k + 1 tokens
# Step 3. target_model(a, b, c, d) -> b, c, e, f
# Compare the drafts with the verified outputs
# Step 4. outputs (b, c, e) vs drafts (b, c, d) -> b, c accepted, d rejected
# The target's token at the first mismatch is always kept, so the accepted
# tokens are b, c, e, and e becomes the next input, just like a.
# Step 5. Final -> b, c, e
# (If every draft had matched, the extra token f would be kept instead.)
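# Loop structure (one iteration per verify step):
#   - draft phase: the draft model generates up to max_step_draft tokens
#     autoregressively, stopping early (subject to a random gate when
#     sampling) once its token probability drops below th_stop_draft, after
#     at least min_step_draft tokens, or when the max_new_tokens budget would
#     be exceeded;
#   - verify phase: the target model scores the previous token plus all draft
#     tokens in a single forward pass;
#   - accept phase: greedy decoding keeps the longest matching prefix plus the
#     target's own token at the first mismatch; with do_sample=True the
#     stochastic accept/reject rule further below is used instead;
#   - the target KV cache is rolled back to the accepted length and, on CPU,
#     copied into the fp32 draft storage before the next iteration.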
while True:
if step >= max_new_tokens:
break
if step == 0:
# the first token is generated by the full target model
tic = time.time()
output = self(input_ids=current_input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
return_dict=True,
use_cache=True)
if _enable_ipex:
output = CausalLMOutputWithPast(
logits=output[0],
past_key_values=output[1],
)
logits = output['logits']
logits = logits[:, -1:]
logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
if generation_config.do_sample:
output_ids, prob_list = deepmind_sample(logits,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
else:
output_ids = greedy(logits)
generate_ids[:, step] = output_ids
current_input_ids = output_ids
past_key_values = output['past_key_values']
step += 1
if self.device.type == 'xpu':
torch.xpu.synchronize()
toc = time.time()
self.first_token_time = toc - tic
e2e_tic = time.time()
else:
draft_current_input_ids = current_input_ids
# Hand the target model's KV cache over to the draft model
if self.device.type == 'cpu':
# init past_key_values_storage and assign initial fp32 value
if _enable_ipex:
draft_past_key_values = past_key_values
else:
if step == 1:
past_key_values_storage = \
_prepare_past_key_values_storage_cpu(self, past_key_values,
max_new_tokens, _enable_ipex)
# each iteration slices the currently verified positions out of past_key_values_storage
draft_past_key_values = \
_prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage, _enable_ipex)
original_draft_past_key_values = draft_past_key_values
else:
past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
max_step_draft,
max_new_tokens - step + 40,
self.config.model_type)
draft_past_key_values = past_key_values
draft_generate_ids[:, 0] = current_input_ids
draft_prob_list = []
tic = time.time()
random_probs = None
if generation_config.do_sample:
random_probs = torch.rand(max_step_draft, device=self.device, dtype=self.dtype)
# Draft model auto-regressively generates k tokens
# Stop drafting early when the draft probability is less than th_stop_draft
for step_draft in range(max_step_draft):
if attention_mask is None:
draft_attention_mask = None
else:
appended_len = step_draft + step
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
forward_args = {
"input_ids": draft_current_input_ids,
"past_key_values": draft_past_key_values,
"attention_mask": draft_attention_mask,
"return_dict": True,
"use_cache": True,
}
if self.config.model_type == "chatglm":
if _enable_ipex:
past_key_value_len = past_key_values[0][0].shape[1]
else:
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
past_length = draft_past_key_values[0][0].size(2)
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
forward_args["position_ids"] = position_ids
if _enable_ipex:
if any(keyword in self.config.model_type
for keyword in ["llama", "chatglm", "mistral"]):
past_key_value_len = draft_past_key_values[0][0].shape[2]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
if self.config.model_type == "chatglm":
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
return_last_logit=torch.tensor(False),
past_key_values=draft_past_key_values,
)
else:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
elif self.config.model_type == "baichuan":
if self.config.hidden_size == 4096:
past_key_value_len = draft_past_key_values[0][0].shape[2]
seq_len = draft_current_input_ids.shape[1]
seq_len_with_past = seq_len + past_key_value_len
position_ids = torch.arange(past_key_value_len,
seq_len_with_past,
dtype=torch.long,
device=draft_current_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
elif self.config.hidden_size == 5120:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
past_key_values=draft_past_key_values,
)
elif "qwen" in self.config.model_type:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
past_key_values=draft_past_key_values,
)
else:
invalidInputError(False, "BigDL Speculative Decoding with IPEX-LLM only supports \
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
draft_output = CausalLMOutputWithPast(
logits=draft_output[0],
past_key_values=draft_output[1],
)
else:
draft_output = draft_model(**forward_args)
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft+1]), dim=-1)
logits = draft_output['logits']
logits[:, -1, :] = logits_processor(temp_input_ids,
draft_output['logits'][:, -1, :])
if generation_config.do_sample:
draft_output_ids, draft_probs, draft_output_probs = deepmind_sample(
logits,
return_probs=True,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
draft_prob_list.append(draft_probs)
else:
draft_output_ids, draft_output_probs = greedy(
logits,
return_probs=True)
draft_generate_ids[:, step_draft+1] = draft_output_ids
draft_current_input_ids = draft_output_ids
draft_past_key_values = draft_output['past_key_values']
# check if the draft prob is less than th_stop_draft
# Draft number + step >= max output token number
th_random = 1 if random_probs is None else random_probs[step_draft]
if (draft_output_probs.item() < th_stop_draft and th_random > 0.3 and
step_draft + 1 >= min_step_draft) or \
step + step_draft + 2 >= max_new_tokens:
break
if self.device.type == 'xpu':
torch.xpu.synchronize()
toc = time.time()
self.draft_time.append(toc - tic)
drafted_n_tokens = step_draft + 1
# draft input + draft completion
drafted_input_ids = draft_generate_ids[:, :drafted_n_tokens+1]
self.draft_num.append(drafted_n_tokens)
tic = time.time()
# Target model verifies the drafts
# input size is k + 1: 1 previous token + k drafts
# verified output size is k + 1: k tokens + 1 final token
# Final token is always accepted
if attention_mask is None:
cur_attention_mask = None
else:
appended_len = drafted_input_ids.size(1) + step - 1
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
if _enable_ipex and hasattr(self, "trace_graph"):
if self.config.model_type == "baichuan":
if self.config.hidden_size == 4096:
past_key_value_len = past_key_values[0][0].shape[2]
seq_len = drafted_input_ids.shape[1]
seq_len_with_past = seq_len + past_key_value_len
position_ids = torch.arange(past_key_value_len,
seq_len_with_past,
dtype=torch.long,
device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
elif self.config.hidden_size == 5120:
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
past_key_values=past_key_values,
)
elif "llama" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device).unsqueeze(0)
position_ids = position_ids.repeat(1, 1) + past_key_value_len
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
)
elif "chatglm" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device).unsqueeze(0)
position_ids = position_ids.repeat(1, 1) + past_key_value_len
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
position_ids=position_ids,
return_last_logit=torch.tensor(False),
past_key_values=past_key_values,)
elif "qwen" in self.config.model_type:
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
past_key_values=past_key_values)
elif "mistral" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
seq_len = drafted_input_ids.shape[1]
position_ids = torch.arange(past_key_value_len,
seq_len + past_key_value_len,
dtype=torch.long,
device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
logits = output[0]
past_key_values = output[1]
else:
forward_args = {
"input_ids": drafted_input_ids,
"past_key_values": past_key_values,
"attention_mask": cur_attention_mask,
"return_dict": True,
"use_cache": True,
}
if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
past_length = past_key_values[0][0].size(2)
input_len = drafted_input_ids.shape[1]
position_ids = torch.arange(past_length, input_len + past_length,
dtype=torch.long, device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, input_len)
forward_args["position_ids"] = position_ids
output = self(**forward_args)
if isinstance(output, dict):
logits = output['logits']
past_key_values = output['past_key_values']
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
for i in range(logits.size(1)):
logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
logits[:, i, :])
if generation_config.do_sample:
target_probs = logits_to_probs(logits,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
else:
output_ids = greedy(logits)
if self.device.type == 'xpu':
torch.xpu.synchronize()
if extend_kv:
torch.xpu.empty_cache()
toc = time.time()
self.verify_time.append(toc - tic)
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
if past_key_values is None:
past_key_values = output['past_key_values']
if generation_config.do_sample:
draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
draft_probs = torch.stack(draft_prob_list).squeeze((1, 2))
# q: target prob, p: draft prob
# q >= p: always accept draft token
# q < p: q/p prob to accept draft token
p = draft_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
q = target_probs[torch.arange(0, drafted_n_tokens), draft_tokens]
accept_draft_prob = torch.minimum(torch.ones(()), q[:drafted_n_tokens] / p)
rejected_locations = (random_probs[:drafted_n_tokens] > accept_draft_prob).nonzero()
if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
max_matched = drafted_n_tokens + 1
last_token = multinomial_sample_one_no_sync(target_probs[-1])
output_ids = torch.cat([draft_tokens, last_token])
else:
max_matched = rejected_locations[0].item()
p = draft_probs[max_matched]
q = target_probs[max_matched]
resample_prob = q - p
resample_prob = torch.where(resample_prob > 0, resample_prob, 0.0)
resample_prob = resample_prob / resample_prob.sum()
next_token = multinomial_sample_one_no_sync(resample_prob)
output_ids = torch.cat([draft_tokens[:max_matched], next_token])
max_matched += 1
output_ids = output_ids.unsqueeze(0)
else:
# Compare the drafts with the target-verified outputs
# Drafts occupy positions [1, k] of drafted_input_ids
# Verified outputs occupy positions [0, k - 1],
# including the one generated by the base model
max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
max_of_max_matched = output_ids.size(1)
# Accept number is max_matched, min is 1
self.accept_num.append(max_matched)
# Clean up target model KV cache
if max_of_max_matched != max_matched:
output_ids = output_ids[:, :max_matched]
if _enable_ipex:
cur_len = past_key_values[0][0].size(1)
delta = max_of_max_matched - max_matched
tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
dtype=torch.long,
).contiguous()
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
for _, key_cache, value_cache, beam_idx in past_key_values]
else:
if self.config.model_type in ["qwen"]:
past_key_values = [
(k[:, :-(max_of_max_matched - max_matched), :],
v[:, :-(max_of_max_matched - max_matched), :])
for k, v in past_key_values
]
elif self.config.model_type == "chatglm":
# for chatglm, cache shape is [sl, bs, nh, hn]
past_key_values = [
(k[:-(max_of_max_matched - max_matched), :, :, :],
v[:-(max_of_max_matched - max_matched), :, :, :])
for k, v in past_key_values
]
elif self.config.model_type in ["baichuan", "gptj"]:
past_key_values = [
(k[:, :, :-(max_of_max_matched - max_matched), :],
v[:, :, :-(max_of_max_matched - max_matched), :])
for k, v in past_key_values
]
elif self.config.model_type == "gpt_bigcode":
past_key_values = [
kv[:, :-(max_of_max_matched - max_matched)]
for kv in past_key_values
]
else:
past_key_values = [
(k[:, :, :-(max_of_max_matched - max_matched)],
v[:, :, :-(max_of_max_matched - max_matched)])
for k, v in past_key_values
]
# Each iteration copies the newly matched kv cache back into past_key_values_storage
if self.device.type == 'cpu' and (not _enable_ipex):
_update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
original_draft_past_key_values,
_enable_ipex)
generate_ids[:, step:step+output_ids.size(1)] = output_ids
current_input_ids = output_ids[:, -1:]
step += output_ids.size(1)
# exclude the one token generated by the base model
self.n_matched += max_matched - 1
self.n_drafted += drafted_n_tokens
step_verify += 1
if auto_th_stop_draft and step_verify % auto_parameters[0] == 0:
tmp_matchness = auto_parameters[1]*(tmp_matchness) + \
(1-auto_parameters[1])*((max_matched - 1)/drafted_n_tokens)
if tmp_matchness < auto_parameters[2]:
new_th_stop_draft = th_stop_draft+auto_parameters[3]
else:
if drafted_n_tokens == max_step_draft:
new_th_stop_draft = th_stop_draft
else:
new_th_stop_draft = th_stop_draft - auto_parameters[3]
th_stop_draft = auto_parameters[4] * th_stop_draft + \
(1-auto_parameters[4]) * new_th_stop_draft
if hf_adjust:
if (max_matched - 1) == max_step_draft:
max_step_draft = min(draft_gen_length - 1, max_step_draft + 1)
else:
max_step_draft = max(1, max_step_draft - 1)
# Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist()
if generation_config.eos_token_id in output_ids_list:
idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1)
break
step = min(step, max_new_tokens)
e2e_toc = time.time()
self.n_token_generated = step
self.e2e_time_without_first = e2e_toc - e2e_tic
generate_ids = torch.cat([input_ids, generate_ids[:, :step]], dim=-1)
return generate_ids
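

# -----------------------------------------------------------------------------
# Minimal end-to-end usage sketch (illustrative only, not part of the library).
# The checkpoint names and loading flags below are placeholders/assumptions; the
# only behavior relied on is the patched generate() above (attaching
# `draft_model` routes the call through speculative_generate, which the
# ipex_llm-loaded model is assumed to expose as a bound method).
def _speculative_decoding_usage_sketch():
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    target_path = "meta-llama/Llama-2-7b-chat-hf"       # placeholder checkpoint
    draft_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # placeholder checkpoint

    tokenizer = AutoTokenizer.from_pretrained(target_path)
    target = AutoModelForCausalLM.from_pretrained(target_path, load_in_4bit=True)
    draft = AutoModelForCausalLM.from_pretrained(draft_path, load_in_4bit=True)

    # Attaching a draft model is what selects the speculative path in generate().
    target.draft_model = draft

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    output = target.generate(inputs.input_ids,
                             max_new_tokens=64,
                             max_step_draft=8,    # draft tokens per verify step
                             th_stop_draft=0.8)   # early-stop threshold for drafting
    return tokenizer.decode(output[0], skip_special_tokens=True)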