[LLM] Use IPEX Optimization for Self Speculative Decoding (#9997)

Use IPEX Optimization for Self Speculative Decoding
Xiangyu Tian 2024-01-30 09:11:06 +08:00 committed by GitHub
parent ccf8f613fb
commit f57d0fda8b

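For context: the IPEX BF16 path added in this commit is gated by the `BIGDL_OPT_IPEX` environment variable and, per the check added below, only claims support for Llama-family and Baichuan2-13B models; the commit also threads `attention_mask` through `generate`/`speculative_generate`. A minimal usage sketch follows; the model-loading call, checkpoint name, and loading flags are illustrative assumptions based on typical BigDL-LLM usage, not part of this diff.

    import os
    import torch
    from transformers import AutoTokenizer

    # Enable the IPEX-optimized self-speculative decoding path (this commit).
    os.environ["BIGDL_OPT_IPEX"] = "true"

    # Hypothetical loading call; exact flags are assumptions, not defined in this diff.
    from bigdl.llm.transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                                 load_in_low_bit="bf16",
                                                 torch_dtype=torch.bfloat16,
                                                 speculative=True)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    with torch.inference_mode():
        # attention_mask from the tokenizer is now forwarded into speculative_generate
        # (see the first hunk below).
        output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))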

@@ -57,7 +57,8 @@ def generate(
         new_speculative_kwargs = {}
         for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                     'top_k', 'top_p', 'temperature', 'hf_adjust',
-                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty']:
+                    'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
+                    'attention_mask']:
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
@@ -115,6 +116,133 @@ def clear_benchmarks(self):
     self.n_matched = 0
 
 
+def _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                         max_new_tokens, _enable_ipex=False):
+    past_key_values_storage = []
+    if _enable_ipex:
+        ipex_past_key_values = []
+        cur_len = past_key_values[0][0].size(1)
+        ipex_past_key_values = [
+            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
+             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
+            for pkv in past_key_values
+        ]
+
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            len0 = past_key_values[i][0].size(0)
+            len1 = past_key_values[i][0].size(1)
+            len2 = past_key_values[i][0].size(2)
+            len3 = past_key_values[i][0].size(3)
+        else:
+            len0 = past_key_values[i][1].size(1)
+            len1 = past_key_values[i][1].size(2)
+            len2 = past_key_values[i][0].size(2)  # seq length
+            len3 = past_key_values[i][1].size(3)
+        if self.config.model_type == "qwen":
+            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.transpose(1, 2)
+            v0 = v0.transpose(1, 2)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        elif self.config.model_type == "chatglm":
+            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.permute(2, 0, 1, 3)
+            v0 = v0.permute(2, 0, 1, 3)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                torch.float32)
+        else:
+            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            past_key_values_storage.append((k0, v0))
+            if not _enable_ipex:
+                past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                    torch.float32)
+            else:
+                past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
+                    torch.float32)
+
+    return past_key_values_storage
+
+
+def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
+    tmp_past_key_values = []
+    for i in range(len(past_key_values)):
+        if self.config.model_type == "qwen":
+            len1 = past_key_values[0][0].size(1)
+            k0 = past_key_values_storage[i][0][:, :len1, :, :]
+            v0 = past_key_values_storage[i][1][:, :len1, :, :]
+            tmp_past_key_values.append((k0, v0))
+        elif self.config.model_type == "chatglm":
+            len0 = past_key_values[0][0].size(0)
+            k0 = past_key_values_storage[i][0][:len0, :, :, :]
+            v0 = past_key_values_storage[i][1][:len0, :, :, :]
+            tmp_past_key_values.append((k0, v0))
+        else:
+            len2 = past_key_values[0][0].size(2)
+            k0 = past_key_values_storage[i][0][:, :, :len2, :]
+            v0 = past_key_values_storage[i][1][:, :, :len2, :]
+            tmp_past_key_values.append((k0, v0))
+    return tmp_past_key_values
+
+
+def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                        original_draft_past_key_values, _enable_ipex=False):
+    for i in range(len(past_key_values)):
+        if not _enable_ipex:
+            if self.config.model_type == "qwen":
+                size = original_draft_past_key_values[i][0].size(1)
+                size1 = past_key_values[i][0].size(1)
+                past_key_values_storage[i][0][:, size:size1, :, :] = \
+                    past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                past_key_values_storage[i][1][:, size:size1, :, :] = \
+                    past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+            elif self.config.model_type == "chatglm":
+                size = original_draft_past_key_values[i][0].size(0)
+                size1 = past_key_values[i][0].size(0)
+                past_key_values_storage[i][0][size:size1, :, :, :] = \
+                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                past_key_values_storage[i][1][size:size1, :, :, :] = \
+                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+            else:
+                size = original_draft_past_key_values[i][0].size(2)
+                size1 = past_key_values[i][0].size(2)
+                past_key_values_storage[i][0][:, :, size:size1, :] = \
+                    past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                past_key_values_storage[i][1][:, :, size:size1, :] = \
+                    past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+        else:
+            size = original_draft_past_key_values[i][0].size(2)
+            size1 = past_key_values[i][0].size(1)
+            delta_past_key = \
+                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
+            delta_past_value = \
+                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
+
+            past_key_values_storage[i][0][:, :, size:size1, :] = \
+                delta_past_key.to(torch.float32)
+            past_key_values_storage[i][1][:, :, size:size1, :] = \
+                delta_past_value.to(torch.float32)
+
+
 @torch.no_grad()
 def speculative_generate(self,
                          inputs: Optional[torch.Tensor] = None,
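The `_enable_ipex` branch above rebuilds the target model's KV cache in the standard [batch, heads, seq, head_dim] layout via `permute(1, 2, 0, 3)` before copying it into the fp32 draft-side storage. A standalone sketch of that reshaping, assuming IPEX-style key/value tensors laid out as [seq, batch, heads, head_dim] (the layout is inferred from the permute in the diff, not stated explicitly):

    import torch

    # Assumed IPEX-style cache entry: (tmp, key, value, beam_idx), with key/value
    # stored as [seq, batch, n_heads, head_dim]; inferred from pkv[1]/pkv[2] above.
    seq, batch, n_heads, head_dim = 16, 1, 32, 128
    key = torch.randn(seq, batch, n_heads, head_dim)
    value = torch.randn(seq, batch, n_heads, head_dim)

    cur_len = 10  # only the first cur_len positions hold valid cache entries
    std_key = key.permute(1, 2, 0, 3)[:, :, :cur_len, :]    # -> [batch, n_heads, cur_len, head_dim]
    std_value = value.permute(1, 2, 0, 3)[:, :, :cur_len, :]
    print(std_key.shape)  # torch.Size([1, 32, 10, 128])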
@@ -126,6 +254,7 @@ def speculative_generate(self,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                          hf_adjust=False,
                          generation_config: Optional[GenerationConfig] = None,
+                         attention_mask=None,
                          **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
@@ -225,7 +354,15 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
-    past_key_values1 = []
+    past_key_values_storage = []
+
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    if _enable_ipex:
+        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
+                ('llama' in self.config.model_type)):
+            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
+                              Llama and Baichuan2-13b models currently.")
 
     tmp_matchness = 0
     e2e_tic = 0.0
@@ -252,6 +389,7 @@ def speculative_generate(self,
             tic = time.time()
             output = self(input_ids=current_input_ids,
                           past_key_values=past_key_values,
+                          attention_mask=attention_mask,
                           return_dict=True,
                           use_cache=True)
            logits = output['logits']
@@ -273,68 +411,17 @@ def speculative_generate(self,
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
-            # init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
-            if self.device.type == 'cpu' and step == 1:
-                for i in range(len(past_key_values)):
-                    len0 = past_key_values[i][0].size(0)
-                    len1 = past_key_values[i][0].size(1)
-                    len2 = past_key_values[i][0].size(2)
-                    len3 = past_key_values[i][0].size(3)
-                    if self.config.model_type == "qwen":
-                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        k0 = k0.transpose(1, 2)
-                        v0 = v0.transpose(1, 2)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
-                            torch.float32)
-                    elif self.config.model_type == "chatglm":
-                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        k0 = k0.permute(2, 0, 1, 3)
-                        v0 = v0.permute(2, 0, 1, 3)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
-                            torch.float32)
-                    else:
-                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                                        dtype=torch.float32)
-                        past_key_values1.append((k0, v0))
-                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
-                            torch.float32)
-                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
-                            torch.float32)
-            # each iter cut off cur_len kv_cache from past_key_values1
             if self.device.type == 'cpu':
-                tmp_past_key_values = []
-                for i in range(len(past_key_values)):
-                    if self.config.model_type == "qwen":
-                        len1 = past_key_values[0][0].size(1)
-                        k0 = past_key_values1[i][0][:, :len1, :, :]
-                        v0 = past_key_values1[i][1][:, :len1, :, :]
-                        tmp_past_key_values.append((k0, v0))
-                    elif self.config.model_type == "chatglm":
-                        len0 = past_key_values[0][0].size(0)
-                        k0 = past_key_values1[i][0][:len0, :, :, :]
-                        v0 = past_key_values1[i][1][:len0, :, :, :]
-                        tmp_past_key_values.append((k0, v0))
-                    else:
-                        len2 = past_key_values[0][0].size(2)
-                        k0 = past_key_values1[i][0][:, :, :len2, :]
-                        v0 = past_key_values1[i][1][:, :, :len2, :]
-                        tmp_past_key_values.append((k0, v0))
-                draft_past_key_values = tmp_past_key_values
+                # init past_key_values_storage and assign initial fp32 value
+                if step == 1:
+                    past_key_values_storage = \
+                        _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                                             max_new_tokens, _enable_ipex)
+                # each iter cut off cur_len kv_cache from past_key_values1
+                draft_past_key_values = \
+                    _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                                       past_key_values_storage)
+                original_draft_past_key_values = draft_past_key_values
             else:
                 draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
@@ -342,17 +429,25 @@ def speculative_generate(self,
             # Draft model auto-regressively generate k tokens
             # Early stop when prob less then th_stop_draft
             for step_draft in range(max_step_draft):
+                if attention_mask is None:
+                    draft_attention_mask = None
+                else:
+                    appended_len = step_draft + step
+                    ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                    draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
                 if self.config.model_type == "chatglm":
                     past_key_value_len = past_key_values[0][0].shape[0]
                     position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                     draft_output = draft_model(input_ids=draft_current_input_ids,
                                                past_key_values=draft_past_key_values,
+                                               attention_mask=draft_attention_mask,
                                                return_dict=True,
                                                use_cache=True,
                                                position_ids=position_ids)
                 else:
                     draft_output = draft_model(input_ids=draft_current_input_ids,
                                                past_key_values=draft_past_key_values,
+                                               attention_mask=draft_attention_mask,
                                                return_dict=True,
                                                use_cache=True)
                 temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
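The draft loop above keeps the original prompt mask and pads it with ones for every position generated so far (`step` accepted tokens plus `step_draft` drafted ones this round). A minimal sketch of that bookkeeping, with made-up sizes:

    import torch

    attention_mask = torch.ones(1, 5)  # prompt of length 5, batch size 1
    step, step_draft = 3, 2            # 3 tokens already accepted, 2 drafted this round

    appended_len = step_draft + step
    ones_to_append = torch.ones(attention_mask.size(0), appended_len)
    draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
    print(draft_attention_mask.shape)  # torch.Size([1, 10]) -> 5 prompt + 5 generated positions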
@@ -388,6 +483,31 @@ def speculative_generate(self,
             # input.size is k + 1, 1 previous token + k drafts
             # verified output.size is k + 1, k token + 1 final
             # Final token is always accepted
+            if attention_mask is None:
+                cur_attention_mask = None
+            else:
+                appended_len = drafted_input_ids.size(1) + step - 1
+                ones_to_append = torch.ones(attention_mask.size(0), appended_len)
+                cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
+            if _enable_ipex and hasattr(self, "trace_graph"):
+                if self.config.model_type == "baichuan":
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              past_key_values=past_key_values,
+                                              )
+                elif "llama" in self.config.model_type:
+                    past_key_value_len = past_key_values[0][0].shape[2]
+                    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                device=drafted_input_ids.device).unsqueeze(0)
+                    position_ids = position_ids.repeat(1, 1) + past_key_value_len
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              position_ids=position_ids,
+                                              past_key_values=past_key_values,
+                                              )
+                logits = output[0]
+                past_key_values = output[1]
+            else:
             if self.config.model_type == "chatglm":
                 past_key_value_len = past_key_values[0][0].shape[0]
                 position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
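The IPEX Llama path above feeds explicit `position_ids` into the traced graph: one absolute position per token in the verification batch, offset by the number of positions already in the KV cache. A toy recreation of that arithmetic with made-up lengths:

    import torch

    drafted_input_ids = torch.tensor([[11, 12, 13, 14]])  # 1 previous token + 3 drafts
    past_key_value_len = 7                                # cached positions so far (hypothetical)

    position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
                                device=drafted_input_ids.device).unsqueeze(0)
    position_ids = position_ids.repeat(1, 1) + past_key_value_len
    print(position_ids)  # tensor([[ 7,  8,  9, 10]])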
@@ -395,20 +515,24 @@ def speculative_generate(self,
                 position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                 output = self(input_ids=drafted_input_ids,
                               past_key_values=past_key_values,
+                              attention_mask=cur_attention_mask,
                               return_dict=True,
                               use_cache=True,
                               position_ids=position_ids)
             else:
                 output = self(input_ids=drafted_input_ids,
                               past_key_values=past_key_values,
+                              attention_mask=cur_attention_mask,
                               return_dict=True,
                               use_cache=True)
+            if isinstance(output, dict):
                 logits = output['logits']
+                past_key_values = output['past_key_values']
             temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                         draft_generate_ids[:, 1:step_draft + 2]), dim=-1)
             for i in range(logits.size(1)):
                 logits[:, i, :] = logits_processor(temp_input_ids[:, :input_ids.size(1)+step+i],
-                                                   output['logits'][:, i, :])
+                                                   logits[:, i, :])
             output_ids = sample(logits, do_sample=generation_config.do_sample,
                                 top_k=generation_config.top_k, top_p=generation_config.top_p,
                                 temperature=generation_config.temperature)
@@ -418,7 +542,6 @@ def speculative_generate(self,
             self.verify_time.append(toc - tic)
             self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
 
-            past_key_values = output['past_key_values']
             # Compare drafts with target verified outputs
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
@@ -432,6 +555,15 @@ def speculative_generate(self,
             if max_of_max_matched != max_matched:
                 output_ids = output_ids[:, :max_matched]
                 # For Qwen
+                if _enable_ipex:
+                    cur_len = past_key_values[0][0].size(1)
+                    delta = max_of_max_matched - max_matched
+                    tmp = torch.empty(1, (cur_len - delta), (cur_len - delta), 1,
+                                      dtype=torch.long,
+                                      ).contiguous()
+                    past_key_values = [[tmp, key_cache, value_cache, beam_idx]
+                                       for _, key_cache, value_cache, beam_idx in past_key_values]
+                else:
                 if self.config.model_type == "qwen":
                     past_key_values = [
                         (k[:, :-(max_of_max_matched - max_matched), :],
@@ -454,33 +586,15 @@ def speculative_generate(self,
                 else:
                     past_key_values = [
                         (k[:, :, :-(max_of_max_matched - max_matched)],
-                         v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
+                         v[:, :, :-(max_of_max_matched - max_matched)])
+                        for k, v in past_key_values
                     ]
             # Each iter assign new_matched kv_cache to past_key_values1
             if self.device.type == 'cpu':
-                for i in range(len(past_key_values)):
-                    if self.config.model_type == "qwen":
-                        size = tmp_past_key_values[i][0].size(1)
-                        size1 = past_key_values[i][0].size(1)
-                        past_key_values1[i][0][:, size:size1, :, :] = \
-                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
-                        past_key_values1[i][1][:, size:size1, :, :] = \
-                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
-                    elif self.config.model_type == "chatglm":
-                        size = tmp_past_key_values[i][0].size(0)
-                        size1 = past_key_values[i][0].size(0)
-                        past_key_values1[i][0][size:size1, :, :, :] = \
-                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
-                        past_key_values1[i][1][size:size1, :, :, :] = \
-                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
-                    else:
-                        size = tmp_past_key_values[i][0].size(2)
-                        size1 = past_key_values[i][0].size(2)
-                        past_key_values1[i][0][:, :, size:size1, :] = \
-                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
-                        past_key_values1[i][1][:, :, size:size1, :] = \
-                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+                _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
+                                                    original_draft_past_key_values,
+                                                    _enable_ipex)
 
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
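After verification, any draft tokens beyond the last match are rolled back by slicing the sequence dimension of the target model's cache (the non-IPEX branches in the last two hunks). A toy example of that trim for the common [batch, heads, seq, head_dim] layout, with hypothetical sizes:

    import torch

    # One layer's cache with 12 cached positions; 2 draft tokens were rejected.
    k = torch.randn(1, 32, 12, 128)
    v = torch.randn(1, 32, 12, 128)
    max_of_max_matched, max_matched = 6, 4   # illustrative values, not from the diff
    past_key_values = [(k, v)]

    past_key_values = [
        (k[:, :, :-(max_of_max_matched - max_matched)],
         v[:, :, :-(max_of_max_matched - max_matched)])
        for k, v in past_key_values
    ]
    print(past_key_values[0][0].shape)  # torch.Size([1, 32, 10, 128])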