[LLM] Use IPEX Optimization for Self Speculative Decoding (#9997)
parent ccf8f613fb
commit f57d0fda8b
1 changed file with 239 additions and 125 deletions
@@ -57,7 +57,8 @@ def generate(
        new_speculative_kwargs = {}
        for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                    'top_k', 'top_p', 'temperature', 'hf_adjust',
                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty']:
                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
                    'attention_mask']:
            value = kwargs.pop(var, None)
            if value is not None:
                new_speculative_kwargs[var] = value
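This hunk adds 'attention_mask' to the whitelist of keyword arguments that generate() forwards into the speculative path. A standalone sketch of the same pop-and-filter pattern, with made-up kwargs:

# Standalone sketch; the kwarg values here are illustrative only.
kwargs = {"max_new_tokens": 32, "attention_mask": "<tensor>", "unused_flag": True}

new_speculative_kwargs = {}
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
            'top_k', 'top_p', 'temperature', 'hf_adjust',
            'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
            'attention_mask']:
    value = kwargs.pop(var, None)
    if value is not None:
        new_speculative_kwargs[var] = value

# Whitelisted kwargs are moved over; everything else stays behind in kwargs.
assert set(new_speculative_kwargs) == {"max_new_tokens", "attention_mask"}
assert set(kwargs) == {"unused_flag"}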
@@ -115,6 +116,133 @@ def clear_benchmarks(self):
    self.n_matched = 0


def _prepare_past_key_values_storage_cpu(self, past_key_values,
                                         max_new_tokens, _enable_ipex=False):
    past_key_values_storage = []
    if _enable_ipex:
        ipex_past_key_values = []
        cur_len = past_key_values[0][0].size(1)
        ipex_past_key_values = [
            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
            for pkv in past_key_values
        ]

    for i in range(len(past_key_values)):
        if not _enable_ipex:
            len0 = past_key_values[i][0].size(0)
            len1 = past_key_values[i][0].size(1)
            len2 = past_key_values[i][0].size(2)
            len3 = past_key_values[i][0].size(3)
        else:
            len0 = past_key_values[i][1].size(1)
            len1 = past_key_values[i][1].size(2)
            len2 = past_key_values[i][0].size(2)  # seq length
            len3 = past_key_values[i][1].size(3)
        if self.config.model_type == "qwen":
            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
                            dtype=torch.float32)
            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
                            dtype=torch.float32)
            k0 = k0.transpose(1, 2)
            v0 = v0.transpose(1, 2)
            past_key_values_storage.append((k0, v0))
            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
                torch.float32)
            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
                torch.float32)
        elif self.config.model_type == "chatglm":
            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
                            dtype=torch.float32)
            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
                            dtype=torch.float32)
            k0 = k0.permute(2, 0, 1, 3)
            v0 = v0.permute(2, 0, 1, 3)
            past_key_values_storage.append((k0, v0))
            past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
                torch.float32)
            past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
                torch.float32)
        else:
            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                            dtype=torch.float32)
            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                            dtype=torch.float32)
            past_key_values_storage.append((k0, v0))
            if not _enable_ipex:
                past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
                    torch.float32)
                past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
                    torch.float32)
            else:
                past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
                    torch.float32)
                past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
                    torch.float32)

    return past_key_values_storage

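If, as the permute above suggests, the IPEX-optimized cache stores keys and values sequence-first (roughly (max_seq, batch, num_heads, head_dim)), the conversion to the usual batch-first layout can be checked in isolation. All shapes below are made up for illustration:

# Illustrative only: hypothetical shapes chosen to match the permute/slice above.
import torch

max_seq, batch, num_heads, head_dim, cur_len = 16, 1, 4, 8, 5

# Assumed sequence-first cache entry.
key_cache = torch.randn(max_seq, batch, num_heads, head_dim)

# Same transformation as in _prepare_past_key_values_storage_cpu:
# bring it to (batch, num_heads, seq, head_dim), then keep only the filled positions.
key_std = key_cache.permute(1, 2, 0, 3)[:, :, :cur_len, :]
assert key_std.shape == (batch, num_heads, cur_len, head_dim)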

def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
    tmp_past_key_values = []
    for i in range(len(past_key_values)):
        if self.config.model_type == "qwen":
            len1 = past_key_values[0][0].size(1)
            k0 = past_key_values_storage[i][0][:, :len1, :, :]
            v0 = past_key_values_storage[i][1][:, :len1, :, :]
            tmp_past_key_values.append((k0, v0))
        elif self.config.model_type == "chatglm":
            len0 = past_key_values[0][0].size(0)
            k0 = past_key_values_storage[i][0][:len0, :, :, :]
            v0 = past_key_values_storage[i][1][:len0, :, :, :]
            tmp_past_key_values.append((k0, v0))
        else:
            len2 = past_key_values[0][0].size(2)
            k0 = past_key_values_storage[i][0][:, :, :len2, :]
            v0 = past_key_values_storage[i][1][:, :, :len2, :]
            tmp_past_key_values.append((k0, v0))
    return tmp_past_key_values


def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
                                        original_draft_past_key_values, _enable_ipex=False):
    for i in range(len(past_key_values)):
        if not _enable_ipex:
            if self.config.model_type == "qwen":
                size = original_draft_past_key_values[i][0].size(1)
                size1 = past_key_values[i][0].size(1)
                past_key_values_storage[i][0][:, size:size1, :, :] = \
                    past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
                past_key_values_storage[i][1][:, size:size1, :, :] = \
                    past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
            elif self.config.model_type == "chatglm":
                size = original_draft_past_key_values[i][0].size(0)
                size1 = past_key_values[i][0].size(0)
                past_key_values_storage[i][0][size:size1, :, :, :] = \
                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
                past_key_values_storage[i][1][size:size1, :, :, :] = \
                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
            else:
                size = original_draft_past_key_values[i][0].size(2)
                size1 = past_key_values[i][0].size(2)
                past_key_values_storage[i][0][:, :, size:size1, :] = \
                    past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
                past_key_values_storage[i][1][:, :, size:size1, :] = \
                    past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
        else:
            size = original_draft_past_key_values[i][0].size(2)
            size1 = past_key_values[i][0].size(1)
            delta_past_key = \
                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
            delta_past_value = \
                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)

            past_key_values_storage[i][0][:, :, size:size1, :] = \
                delta_past_key.to(torch.float32)
            past_key_values_storage[i][1][:, :, size:size1, :] = \
                delta_past_value.to(torch.float32)

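The update helper copies back only the newly verified slice [size:size1] along the sequence dimension, rather than re-copying the whole cache into the fp32 storage. A toy version of that delta write-back, with made-up shapes and the usual (batch, heads, seq, head_dim) layout:

# Illustrative only: dummy shapes and data.
import torch

batch, heads, head_dim, max_len = 1, 4, 8, 32
storage_k = torch.zeros(batch, heads, max_len, head_dim, dtype=torch.float32)

old_len, new_len = 10, 14            # cache length before / after this verification step
verified_k = torch.randn(batch, heads, new_len, head_dim, dtype=torch.bfloat16)

# Copy only the newly appended positions, upcast to fp32 for the draft model.
storage_k[:, :, old_len:new_len, :] = verified_k[:, :, old_len:new_len, :].to(torch.float32)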

@torch.no_grad()
def speculative_generate(self,
                         inputs: Optional[torch.Tensor] = None,
@@ -126,6 +254,7 @@ def speculative_generate(self,
                         auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                         hf_adjust=False,
                         generation_config: Optional[GenerationConfig] = None,
                         attention_mask=None,
                         **sampling_kwargs):
    invalidInputError(draft_model is not None,
                      "Draft model should be provided.")
@@ -225,7 +354,15 @@ def speculative_generate(self,
    draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                     dtype=torch.long, device=self.device)
    past_key_values = None
    past_key_values1 = []
    past_key_values_storage = []

    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    if _enable_ipex:
        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
                ('llama' in self.config.model_type)):
            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
                              Llama and Baichuan2-13b models currently.")

    tmp_matchness = 0
    e2e_tic = 0.0
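The IPEX path is gated by the BIGDL_OPT_IPEX environment variable checked above. A minimal standalone check of the same parsing; setting the variable in-process here is only for illustration (normally it would be exported in the shell before launching):

# Sketch of the env-var toggle; "BIGDL_OPT_IPEX" is the variable used above.
import os

os.environ["BIGDL_OPT_IPEX"] = "true"   # e.g. `export BIGDL_OPT_IPEX=true`

flag = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (flag is not None) and (flag.lower() == "true")
assert _enable_ipex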
@@ -252,6 +389,7 @@ def speculative_generate(self,
            tic = time.time()
            output = self(input_ids=current_input_ids,
                          past_key_values=past_key_values,
                          attention_mask=attention_mask,
                          return_dict=True,
                          use_cache=True)
            logits = output['logits']
@@ -273,68 +411,17 @@ def speculative_generate(self,
            draft_current_input_ids = current_input_ids
            # Target model KV cache to draft model

            # init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
            if self.device.type == 'cpu' and step == 1:
                for i in range(len(past_key_values)):
                    len0 = past_key_values[i][0].size(0)
                    len1 = past_key_values[i][0].size(1)
                    len2 = past_key_values[i][0].size(2)
                    len3 = past_key_values[i][0].size(3)
                    if self.config.model_type == "qwen":
                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        k0 = k0.transpose(1, 2)
                        v0 = v0.transpose(1, 2)
                        past_key_values1.append((k0, v0))
                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
                            torch.float32)
                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
                            torch.float32)
                    elif self.config.model_type == "chatglm":
                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        k0 = k0.permute(2, 0, 1, 3)
                        v0 = v0.permute(2, 0, 1, 3)
                        past_key_values1.append((k0, v0))
                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
                            torch.float32)
                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
                            torch.float32)
                    else:
                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                                        dtype=torch.float32)
                        past_key_values1.append((k0, v0))
                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
                            torch.float32)
                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
                            torch.float32)

            # each iter cut off cur_len kv_cache from past_key_values1
            if self.device.type == 'cpu':
                tmp_past_key_values = []
                for i in range(len(past_key_values)):
                    if self.config.model_type == "qwen":
                        len1 = past_key_values[0][0].size(1)
                        k0 = past_key_values1[i][0][:, :len1, :, :]
                        v0 = past_key_values1[i][1][:, :len1, :, :]
                        tmp_past_key_values.append((k0, v0))
                    elif self.config.model_type == "chatglm":
                        len0 = past_key_values[0][0].size(0)
                        k0 = past_key_values1[i][0][:len0, :, :, :]
                        v0 = past_key_values1[i][1][:len0, :, :, :]
                        tmp_past_key_values.append((k0, v0))
                    else:
                        len2 = past_key_values[0][0].size(2)
                        k0 = past_key_values1[i][0][:, :, :len2, :]
                        v0 = past_key_values1[i][1][:, :, :len2, :]
                        tmp_past_key_values.append((k0, v0))
                draft_past_key_values = tmp_past_key_values
                # init past_key_values_storage and assign initial fp32 value
                if step == 1:
                    past_key_values_storage = \
                        _prepare_past_key_values_storage_cpu(self, past_key_values,
                                                             max_new_tokens, _enable_ipex)
                # each iter cut off cur_len kv_cache from past_key_values1
                draft_past_key_values = \
                    _prepare_draft_past_key_values_cpu(self, past_key_values,
                                                       past_key_values_storage)
                original_draft_past_key_values = draft_past_key_values
            else:
                draft_past_key_values = past_key_values
            draft_generate_ids[:, 0] = current_input_ids
@@ -342,17 +429,25 @@ def speculative_generate(self,
            # Draft model auto-regressively generate k tokens
            # Early stop when prob less than th_stop_draft
            for step_draft in range(max_step_draft):
                if attention_mask is None:
                    draft_attention_mask = None
                else:
                    appended_len = step_draft + step
                    ones_to_append = torch.ones(attention_mask.size(0), appended_len)
                    draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
                if self.config.model_type == "chatglm":
                    past_key_value_len = past_key_values[0][0].shape[0]
                    position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                    draft_output = draft_model(input_ids=draft_current_input_ids,
                                               past_key_values=draft_past_key_values,
                                               attention_mask=draft_attention_mask,
                                               return_dict=True,
                                               use_cache=True,
                                               position_ids=position_ids)
                else:
                    draft_output = draft_model(input_ids=draft_current_input_ids,
                                               past_key_values=draft_past_key_values,
                                               attention_mask=draft_attention_mask,
                                               return_dict=True,
                                               use_cache=True)
                temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
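For each draft pass, the prompt attention mask is extended with ones covering every token produced so far (the verified tokens plus the drafts from this inner loop). A toy check of that arithmetic with made-up sizes:

# Toy check of the mask growth used in the draft loop above; numbers are illustrative.
import torch

attention_mask = torch.ones(1, 7)          # mask for the original prompt (7 tokens)
step, step_draft = 3, 2                    # 3 verified tokens so far, 2 drafts this loop

appended_len = step_draft + step
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
assert draft_attention_mask.shape == (1, 7 + appended_len)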
@@ -388,6 +483,31 @@ def speculative_generate(self,
            # input.size is k + 1, 1 previous token + k drafts
            # verified output.size is k + 1, k token + 1 final
            # Final token is always accepted
            if attention_mask is None:
                cur_attention_mask = None
            else:
                appended_len = drafted_input_ids.size(1) + step - 1
                ones_to_append = torch.ones(attention_mask.size(0), appended_len)
                cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
            if _enable_ipex and hasattr(self, "trace_graph"):
                if self.config.model_type == "baichuan":
                    output = self.trace_graph(input_ids=drafted_input_ids,
                                              attention_mask=cur_attention_mask,
                                              past_key_values=past_key_values,
                                              )
                elif "llama" in self.config.model_type:
                    past_key_value_len = past_key_values[0][0].shape[2]
                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
                                                device=drafted_input_ids.device).unsqueeze(0)
                    position_ids = position_ids.repeat(1, 1) + past_key_value_len
                    output = self.trace_graph(input_ids=drafted_input_ids,
                                              attention_mask=cur_attention_mask,
                                              position_ids=position_ids,
                                              past_key_values=past_key_values,
                                              )
                logits = output[0]
                past_key_values = output[1]
            else:
                if self.config.model_type == "chatglm":
                    past_key_value_len = past_key_values[0][0].shape[0]
                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
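Because verification feeds the previous token plus all drafts to the target model in one forward pass, the position ids must continue from the current KV-cache length. The same arithmetic as above, checked standalone with made-up numbers:

# Standalone check of the position-id arithmetic used for the batched verification pass.
import torch

past_key_value_len = 10          # tokens already in the target model's KV cache
drafted_input_ids = torch.zeros(1, 5, dtype=torch.long)   # 1 previous token + 4 drafts

position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long).unsqueeze(0)
position_ids = position_ids.repeat(1, 1) + past_key_value_len
assert position_ids.tolist() == [[10, 11, 12, 13, 14]]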
@@ -395,20 +515,24 @@ def speculative_generate(self,
                    position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                    output = self(input_ids=drafted_input_ids,
                                  past_key_values=past_key_values,
                                  attention_mask=cur_attention_mask,
                                  return_dict=True,
                                  use_cache=True,
                                  position_ids=position_ids)
                else:
                    output = self(input_ids=drafted_input_ids,
                                  past_key_values=past_key_values,
                                  attention_mask=cur_attention_mask,
                                  return_dict=True,
                                  use_cache=True)
                if isinstance(output, dict):
                    logits = output['logits']
                    past_key_values = output['past_key_values']
            temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                        draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
            for i in range(logits.size(1)):
                logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
                                                   output['logits'][:, i, :])
                                                   logits[:, i, :])
            output_ids = sample(logits, do_sample=generation_config.do_sample,
                                top_k=generation_config.top_k, top_p=generation_config.top_p,
                                temperature=generation_config.temperature)
@@ -418,7 +542,6 @@ def speculative_generate(self,
            self.verify_time.append(toc - tic)
            self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

            past_key_values = output['past_key_values']
            # Compare drafts with target verified outputs
            # Drafts start from [1, k]
            # Verified output start from [0, k - 1]
@@ -432,6 +555,15 @@ def speculative_generate(self,
            if max_of_max_matched != max_matched:
                output_ids = output_ids[:, :max_matched]
                # For Qwen
                if _enable_ipex:
                    cur_len = past_key_values[0][0].size(1)
                    delta = max_of_max_matched - max_matched
                    tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
                                      dtype=torch.long,
                                      ).contiguous()
                    past_key_values = [[tmp, key_cache, value_cache, beam_idx]
                                       for _, key_cache, value_cache, beam_idx in past_key_values]
                else:
                    if self.config.model_type == "qwen":
                        past_key_values = [
                            (k[:, :-(max_of_max_matched - max_matched), :],
@@ -454,33 +586,15 @@ def speculative_generate(self,
                    else:
                        past_key_values = [
                            (k[:, :, :-(max_of_max_matched - max_matched)],
                             v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
                             v[:, :, :-(max_of_max_matched - max_matched)])
                            for k, v in past_key_values
                        ]

            # Each iter assign new_matched kv_cache to past_key_values1
            if self.device.type == 'cpu':
                for i in range(len(past_key_values)):
                    if self.config.model_type == "qwen":
                        size = tmp_past_key_values[i][0].size(1)
                        size1 = past_key_values[i][0].size(1)
                        past_key_values1[i][0][:, size:size1, :, :] = \
                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
                        past_key_values1[i][1][:, size:size1, :, :] = \
                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
                    elif self.config.model_type == "chatglm":
                        size = tmp_past_key_values[i][0].size(0)
                        size1 = past_key_values[i][0].size(0)
                        past_key_values1[i][0][size:size1, :, :, :] = \
                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
                        past_key_values1[i][1][size:size1, :, :, :] = \
                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
                    else:
                        size = tmp_past_key_values[i][0].size(2)
                        size1 = past_key_values[i][0].size(2)
                        past_key_values1[i][0][:, :, size:size1, :] = \
                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
                        past_key_values1[i][1][:, :, size:size1, :] = \
                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
                _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
                                                    original_draft_past_key_values,
                                                    _enable_ipex)

            generate_ids[:, step:step+output_ids.size(1)] = output_ids
            current_input_ids = output_ids[:, -1:]