[LLM] Use IPEX Optimization for Self Speculative Decoding (#9997)

parent ccf8f613fb
commit f57d0fda8b

1 changed file with 239 additions and 125 deletions
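For context, here is a minimal usage sketch of the code path this commit touches. It is not part of the diff: the BIGDL_OPT_IPEX flag, the attention_mask forwarding, and the th_stop_draft kwarg all appear in the hunks below, while the model-loading call and its options are assumptions about the surrounding BigDL-LLM API rather than something this commit defines.

# Hypothetical usage sketch (not part of this commit).
import os
os.environ["BIGDL_OPT_IPEX"] = "true"      # enables the IPEX trace_graph verify path (see diff)

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM   # assumed import path

model_path = "meta-llama/Llama-2-7b-chat-hf"   # the IPEX path only accepts Llama / Baichuan2-13B
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="bf16",   # assumed loading options
                                             optimize_model=True,
                                             speculative=True,
                                             use_cache=True)

inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(inputs.input_ids,
                         attention_mask=inputs.attention_mask,  # now forwarded by this commit
                         max_new_tokens=32,
                         th_stop_draft=0.6,   # draft early-stop threshold popped in generate()
                         do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))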
@@ -57,7 +57,8 @@ def generate(
         new_speculative_kwargs = {}
         for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                     'top_k', 'top_p', 'temperature', 'hf_adjust',
-                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty']:
+                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
+                    'attention_mask']:
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
@@ -115,6 +116,133 @@ def clear_benchmarks(self):
     self.n_matched = 0
 
 
+def _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                         max_new_tokens, _enable_ipex=False):
+    past_key_values_storage = []
+    if _enable_ipex:
+        ipex_past_key_values = []
+        cur_len = past_key_values[0][0].size(1)
+        ipex_past_key_values = [
+            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
+             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
+            for pkv in past_key_values
+        ]
+
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            len0 = past_key_values[i][0].size(0)
+            len1 = past_key_values[i][0].size(1)
+            len2 = past_key_values[i][0].size(2)
+            len3 = past_key_values[i][0].size(3)
+        else:
+            len0 = past_key_values[i][1].size(1)
+            len1 = past_key_values[i][1].size(2)
+            len2 = past_key_values[i][0].size(2)  # seq length
+            len3 = past_key_values[i][1].size(3)
+        if self.config.model_type == "qwen":
+            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.transpose(1, 2)
+            v0 = v0.transpose(1, 2)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        elif self.config.model_type == "chatglm":
+            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.permute(2, 0, 1, 3)
+            v0 = v0.permute(2, 0, 1, 3)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        else:
+            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            past_key_values_storage.append((k0, v0))
+            if not _enable_ipex:
+                past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                    torch.float32)
+            else:
+                past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
+                    torch.float32)
+
+    return past_key_values_storage
+
+
+def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
+    tmp_past_key_values = []
+    for i in range(len(past_key_values)):
+        if self.config.model_type == "qwen":
+            len1 = past_key_values[0][0].size(1)
+            k0 = past_key_values_storage[i][0][:, :len1, :, :]
+            v0 = past_key_values_storage[i][1][:, :len1, :, :]
+            tmp_past_key_values.append((k0, v0))
+        elif self.config.model_type == "chatglm":
+            len0 = past_key_values[0][0].size(0)
+            k0 = past_key_values_storage[i][0][:len0, :, :, :]
+            v0 = past_key_values_storage[i][1][:len0, :, :, :]
+            tmp_past_key_values.append((k0, v0))
+        else:
+            len2 = past_key_values[0][0].size(2)
+            k0 = past_key_values_storage[i][0][:, :, :len2, :]
+            v0 = past_key_values_storage[i][1][:, :, :len2, :]
+            tmp_past_key_values.append((k0, v0))
+    return tmp_past_key_values
+
+
+def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                        original_draft_past_key_values, _enable_ipex=False):
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            if self.config.model_type == "qwen":
+                size = original_draft_past_key_values[i][0].size(1)
+                size1 = past_key_values[i][0].size(1)
+                past_key_values_storage[i][0][:, size:size1, :, :] = \
+                    past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                past_key_values_storage[i][1][:, size:size1, :, :] = \
+                    past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+            elif self.config.model_type == "chatglm":
+                size = original_draft_past_key_values[i][0].size(0)
+                size1 = past_key_values[i][0].size(0)
+                past_key_values_storage[i][0][size:size1, :, :, :] = \
+                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                past_key_values_storage[i][1][size:size1, :, :, :] = \
+                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+            else:
+                size = original_draft_past_key_values[i][0].size(2)
+                size1 = past_key_values[i][0].size(2)
+                past_key_values_storage[i][0][:, :, size:size1, :] = \
+                    past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                past_key_values_storage[i][1][:, :, size:size1, :] = \
+                    past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+        else:
+            size = original_draft_past_key_values[i][0].size(2)
+            size1 = past_key_values[i][0].size(1)
+            delta_past_key = \
+                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
+            delta_past_value = \
+                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
+
+            past_key_values_storage[i][0][:, :, size:size1, :] = \
+                delta_past_key.to(torch.float32)
+            past_key_values_storage[i][1][:, :, size:size1, :] = \
+                delta_past_value.to(torch.float32)
+
+
 @torch.no_grad()
 def speculative_generate(self,
                          inputs: Optional[torch.Tensor] = None,
@@ -126,6 +254,7 @@ def speculative_generate(self,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                          hf_adjust=False,
                          generation_config: Optional[GenerationConfig] = None,
+                         attention_mask=None,
                          **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
@@ -225,7 +354,15 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
-    past_key_values1 = []
+    past_key_values_storage = []
 
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    if _enable_ipex:
+        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
+                ('llama' in self.config.model_type)):
+            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
+                              Llama and Baichuan2-13b models currently.")
+
     tmp_matchness = 0
     e2e_tic = 0.0
@@ -252,6 +389,7 @@ def speculative_generate(self,
             tic = time.time()
             output = self(input_ids=current_input_ids,
                           past_key_values=past_key_values,
+                          attention_mask=attention_mask,
                           return_dict=True,
                           use_cache=True)
             logits = output['logits']
@@ -273,68 +411,17 @@ def speculative_generate(self,
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
 
-            # init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
-            if self.device.type == 'cpu' and step == 1:
-                for i in range(len(past_key_values)):
-                    len0 = past_key_values[i][0].size(0)
-                    len1 = past_key_values[i][0].size(1)
-                    len2 = past_key_values[i][0].size(2)
-                    len3 = past_key_values[i][0].size(3)
-                    if self.config.model_type == "qwen":
-                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        k0 = k0.transpose(1, 2)
-                        v0 = v0.transpose(1, 2)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
-                            torch.float32)
-                    elif self.config.model_type == "chatglm":
-                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        k0 = k0.permute(2, 0, 1, 3)
-                        v0 = v0.permute(2, 0, 1, 3)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
-                            torch.float32)
-                    else:
-                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
-                            torch.float32)
-
-            # each iter cut off cur_len kv_cache from past_key_values1
             if self.device.type == 'cpu':
-                tmp_past_key_values = []
-                for i in range(len(past_key_values)):
-                    if self.config.model_type == "qwen":
-                        len1 = past_key_values[0][0].size(1)
-                        k0 = past_key_values1[i][0][:, :len1, :, :]
-                        v0 = past_key_values1[i][1][:, :len1, :, :]
-                        tmp_past_key_values.append((k0, v0))
-                    elif self.config.model_type == "chatglm":
-                        len0 = past_key_values[0][0].size(0)
-                        k0 = past_key_values1[i][0][:len0, :, :, :]
-                        v0 = past_key_values1[i][1][:len0, :, :, :]
-                        tmp_past_key_values.append((k0, v0))
-                    else:
-                        len2 = past_key_values[0][0].size(2)
-                        k0 = past_key_values1[i][0][:, :, :len2, :]
-                        v0 = past_key_values1[i][1][:, :, :len2, :]
-                        tmp_past_key_values.append((k0, v0))
-                draft_past_key_values = tmp_past_key_values
+                # init past_key_values_storage and assign initial fp32 value
+                if step == 1:
+                    past_key_values_storage = \
+                        _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                                             max_new_tokens, _enable_ipex)
+                # each iter cut off cur_len kv_cache from past_key_values1
+                draft_past_key_values = \
+                    _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                                       past_key_values_storage)
+                original_draft_past_key_values = draft_past_key_values
             else:
                 draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
@@ -342,17 +429,25 @@ def speculative_generate(self,
             # Draft model auto-regressively generate k tokens
             # Early stop when prob less then th_stop_draft
             for step_draft in range(max_step_draft):
+                if attention_mask is None:
+                    draft_attention_mask = None
+                else:
+                    appended_len = step_draft + step
+                    ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                    draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
                 if self.config.model_type == "chatglm":
                     past_key_value_len = past_key_values[0][0].shape[0]
                     position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                     draft_output = draft_model(input_ids=draft_current_input_ids,
                                                past_key_values=draft_past_key_values,
+                                               attention_mask=draft_attention_mask,
                                                return_dict=True,
                                                use_cache=True,
                                                position_ids=position_ids)
                 else:
                     draft_output = draft_model(input_ids=draft_current_input_ids,
                                                past_key_values=draft_past_key_values,
+                                               attention_mask=draft_attention_mask,
                                                return_dict=True,
                                                use_cache=True)
                 temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
@@ -388,6 +483,31 @@ def speculative_generate(self,
             # input.size is k + 1, 1 previous token + k drafts
             # verified output.size is k + 1, k token + 1 final
             # Final token is always accepted
+            if attention_mask is None:
+                cur_attention_mask = None
+            else:
+                appended_len = drafted_input_ids.size(1) + step - 1
+                ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
+            if _enable_ipex and hasattr(self, "trace_graph"):
+                if self.config.model_type == "baichuan":
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              past_key_values=past_key_values,
+                                              )
+                elif "llama" in self.config.model_type:
+                    past_key_value_len = past_key_values[0][0].shape[2]
+                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                device=drafted_input_ids.device).unsqueeze(0)
+                    position_ids = position_ids.repeat(1, 1) + past_key_value_len
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              position_ids=position_ids,
+                                              past_key_values=past_key_values,
+                                              )
+                logits = output[0]
+                past_key_values = output[1]
+            else:
                 if self.config.model_type == "chatglm":
                     past_key_value_len = past_key_values[0][0].shape[0]
                     position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
@@ -395,20 +515,24 @@ def speculative_generate(self,
                     position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                     output = self(input_ids=drafted_input_ids,
                                   past_key_values=past_key_values,
+                                  attention_mask=cur_attention_mask,
                                   return_dict=True,
                                   use_cache=True,
                                   position_ids=position_ids)
                 else:
                     output = self(input_ids=drafted_input_ids,
                                   past_key_values=past_key_values,
+                                  attention_mask=cur_attention_mask,
                                   return_dict=True,
                                   use_cache=True)
+            if isinstance(output, dict):
                 logits = output['logits']
+                past_key_values = output['past_key_values']
             temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                         draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
             for i in range(logits.size(1)):
                 logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                   output['logits'][:, i, :])
+                                                   logits[:, i, :])
             output_ids = sample(logits, do_sample=generation_config.do_sample,
                                 top_k=generation_config.top_k, top_p=generation_config.top_p,
                                 temperature=generation_config.temperature)
@@ -418,7 +542,6 @@ def speculative_generate(self,
             self.verify_time.append(toc - tic)
             self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
 
-            past_key_values = output['past_key_values']
             # Compare drafts with target verified outputs
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
@@ -432,6 +555,15 @@ def speculative_generate(self,
             if max_of_max_matched != max_matched:
                 output_ids = output_ids[:, :max_matched]
                 # For Qwen
+                if _enable_ipex:
+                    cur_len = past_key_values[0][0].size(1)
+                    delta = max_of_max_matched - max_matched
+                    tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
+                                      dtype=torch.long,
+                                      ).contiguous()
+                    past_key_values = [[tmp, key_cache, value_cache, beam_idx]
+                                       for _, key_cache, value_cache, beam_idx in past_key_values]
+                else:
                     if self.config.model_type == "qwen":
                         past_key_values = [
                             (k[:, :-(max_of_max_matched - max_matched), :],
@@ -454,33 +586,15 @@ def speculative_generate(self,
                     else:
                         past_key_values = [
                             (k[:, :, :-(max_of_max_matched - max_matched)],
-                             v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
+                             v[:, :, :-(max_of_max_matched - max_matched)])
+                            for k, v in past_key_values
                         ]
 
             # Each iter assign new_matched kv_cache to past_key_values1
             if self.device.type == 'cpu':
-                for i in range(len(past_key_values)):
-                    if self.config.model_type == "qwen":
-                        size = tmp_past_key_values[i][0].size(1)
-                        size1 = past_key_values[i][0].size(1)
-                        past_key_values1[i][0][:, size:size1, :, :] = \
-                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
-                        past_key_values1[i][1][:, size:size1, :, :] = \
-                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
-                    elif self.config.model_type == "chatglm":
-                        size = tmp_past_key_values[i][0].size(0)
-                        size1 = past_key_values[i][0].size(0)
-                        past_key_values1[i][0][size:size1, :, :, :] = \
-                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
-                        past_key_values1[i][1][size:size1, :, :, :] = \
-                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
-                    else:
-                        size = tmp_past_key_values[i][0].size(2)
-                        size1 = past_key_values[i][0].size(2)
-                        past_key_values1[i][0][:, :, size:size1, :] = \
-                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
-                        past_key_values1[i][1][:, :, size:size1, :] = \
-                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+                _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                                    original_draft_past_key_values,
+                                                    _enable_ipex)
 
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
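A note on the IPEX cache layout handled in the hunks above, as a small sketch with assumed shapes. It is inferred from the permute(1, 2, 0, 3) calls and the [tmp, key_cache, value_cache, beam_idx] per-layer structure in this diff, not from IPEX documentation.

import torch

# With BIGDL_OPT_IPEX, each layer's cache appears to hold keys/values sequence-first,
# roughly (seq, batch, num_heads, head_dim); the helpers convert to the usual
# (batch, num_heads, seq, head_dim) before copying into the fp32 draft storage.
seq, batch, heads, head_dim = 16, 1, 32, 128           # assumed sizes, for illustration
ipex_key = torch.randn(seq, batch, heads, head_dim, dtype=torch.bfloat16)

std_key = ipex_key.permute(1, 2, 0, 3)                 # -> (batch, heads, seq, head_dim)

# fp32 storage pre-allocated with room for max_new_tokens extra positions,
# mirroring _prepare_past_key_values_storage_cpu above
max_new_tokens = 8
storage_key = torch.ones(batch, heads, seq + max_new_tokens, head_dim, dtype=torch.float32)
storage_key[:, :, :seq, :] = std_key.to(torch.float32)
print(storage_key.shape)                               # torch.Size([1, 32, 24, 128])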