LLM: enable chatglm3-6b target_model ipex (#10085)
* init * always make casual_mask * not return last tensor * update * optimize_model = False * enable optimized=False * enable optimized_model=true * speed_up ipex target_model * remove if True * use group_size * update python style * update * update
This commit is contained in:
parent
177273c1a4
commit
f2417e083c
3 changed files with 204 additions and 71 deletions
|
|
@ -518,6 +518,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
f"format......")
|
||||
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
# using ipex optimizer before changing to bigdl linear
|
||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
|
||||
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
|
||||
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
|
||||
if _enable_ipex:
|
||||
model = _optimize_ipex(model)
|
||||
return model
|
||||
|
||||
if optimize_model:
|
||||
model = _optimize_pre(model)
|
||||
|
||||
|
|
@ -543,14 +553,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
# Do nothing here for weights are empty.
|
||||
pass
|
||||
|
||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
|
||||
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
|
||||
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
|
||||
if _enable_ipex:
|
||||
model = _optimize_ipex(model)
|
||||
return model
|
||||
if optimize_model:
|
||||
model = _optimize_post(model, lightweight_bmm)
|
||||
return model
|
||||
|
|
@ -590,13 +592,17 @@ def _optimize_ipex(model):
|
|||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from bigdl.llm.transformers.convert_ipex import (
|
||||
_ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
|
||||
_ipex_optimize_rmsnorm, _llama_model_forward_4_35
|
||||
_ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
|
||||
)
|
||||
|
||||
AttentionMaskConverter._make_causal_mask = _make_causal_mask
|
||||
convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
|
||||
_llama_model_forward_4_35)
|
||||
model = model_convert_reference(model)
|
||||
if model.config.architectures is not None \
|
||||
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
|
||||
convert_function(model.transformer, "get_masks", GLM_get_masks)
|
||||
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
|
||||
_ipex_optimize_rmsnorm(model)
|
||||
_ipex_optimize_attention(model)
|
||||
_ipex_optimize_decoder(model)
|
||||
|
|
|
|||
|
|
@ -142,6 +142,8 @@ def _ipex_jit(model):
|
|||
sample_inputs = (
|
||||
get_dummy_input(model, return_dict=True)
|
||||
)
|
||||
if "return_last_logit" in sample_inputs:
|
||||
del sample_inputs["return_last_logit"]
|
||||
with torch.no_grad(), torch.cpu.amp.autocast(
|
||||
enabled=True
|
||||
):
|
||||
|
|
@ -159,6 +161,47 @@ def _ipex_jit(model):
|
|||
return model.eval()
|
||||
|
||||
|
||||
def convert_function(m, func_name, new_function):
|
||||
bound_method = new_function.__get__(m, m.__class__)
|
||||
setattr(m, func_name, bound_method)
|
||||
|
||||
|
||||
def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
full_attention_mask = torch.ones(
|
||||
batch_size, seq_length, seq_length, device=input_ids.device
|
||||
)
|
||||
full_attention_mask.tril_()
|
||||
past_length = 0
|
||||
if past_key_values:
|
||||
if len(past_key_values[0]) != 4: # not discrete kv cache
|
||||
past_length = past_key_values[0][0].shape[0]
|
||||
else: # discrete kv cache
|
||||
past_length = past_key_values[0][0].shape[-2]
|
||||
|
||||
import os
|
||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||
# always call for jit
|
||||
if past_length or _enable_ipex:
|
||||
full_attention_mask = torch.cat(
|
||||
(
|
||||
torch.ones(
|
||||
batch_size, seq_length, past_length, device=input_ids.device
|
||||
),
|
||||
full_attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
if padding_mask is not None:
|
||||
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
||||
# if not past_length and padding_mask is not None:
|
||||
# full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
||||
full_attention_mask = (full_attention_mask < 0.5).bool()
|
||||
full_attention_mask.unsqueeze_(1)
|
||||
return full_attention_mask
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from bigdl.llm.utils.common import invalidInputError
|
|||
# patch GenerationMixin.generate
|
||||
from transformers import GenerationMixin
|
||||
original_generate = GenerationMixin.generate
|
||||
|
||||
query_group_size = 16
|
||||
logger = logging.getLogger("bigdl.llm.speculative")
|
||||
|
||||
|
||||
|
|
@ -131,62 +131,98 @@ def clear_benchmarks(self):
|
|||
def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
||||
max_new_tokens, _enable_ipex=False):
|
||||
past_key_values_storage = []
|
||||
# init ipex_past_key_values
|
||||
if _enable_ipex:
|
||||
ipex_past_key_values = []
|
||||
cur_len = past_key_values[0][0].size(1)
|
||||
ipex_past_key_values = [
|
||||
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
|
||||
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
|
||||
for pkv in past_key_values
|
||||
]
|
||||
|
||||
for i in range(len(past_key_values)):
|
||||
if not _enable_ipex:
|
||||
len0 = past_key_values[i][0].size(0)
|
||||
len1 = past_key_values[i][0].size(1)
|
||||
len2 = past_key_values[i][0].size(2)
|
||||
len3 = past_key_values[i][0].size(3)
|
||||
if self.config.model_type == "chatglm":
|
||||
len0 = past_key_values[0][1].size(0) # seq max length
|
||||
len1 = past_key_values[0][1].size(1)
|
||||
len2 = past_key_values[0][1].size(2)
|
||||
len3 = past_key_values[0][1].size(3)
|
||||
for pkv in past_key_values:
|
||||
key = pkv[1]
|
||||
value = pkv[2]
|
||||
key = key.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||
key = key.expand(-1, -1, query_group_size, -1, -1)
|
||||
key = key.contiguous().view(len1, len2 * query_group_size,
|
||||
len0, len3).permute(2, 0, 1, 3)
|
||||
value = value.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||
value = value.expand(-1, -1, query_group_size, -1, -1)
|
||||
value = value.contiguous().view(len1, len2 * query_group_size,
|
||||
len0, len3).permute(2, 0, 1, 3)
|
||||
list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
|
||||
ipex_past_key_values.append(list)
|
||||
else:
|
||||
len0 = past_key_values[i][1].size(1)
|
||||
len1 = past_key_values[i][1].size(2)
|
||||
len2 = past_key_values[i][0].size(2) # seq length
|
||||
len3 = past_key_values[i][1].size(3)
|
||||
if self.config.model_type == "qwen":
|
||||
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.transpose(1, 2)
|
||||
v0 = v0.transpose(1, 2)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
elif self.config.model_type == "chatglm":
|
||||
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.permute(2, 0, 1, 3)
|
||||
v0 = v0.permute(2, 0, 1, 3)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
if not _enable_ipex:
|
||||
ipex_past_key_values = [
|
||||
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
|
||||
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
|
||||
for pkv in past_key_values
|
||||
]
|
||||
if not _enable_ipex:
|
||||
len0 = past_key_values[0][0].size(0)
|
||||
len1 = past_key_values[0][0].size(1)
|
||||
len2 = past_key_values[0][0].size(2)
|
||||
len3 = past_key_values[0][0].size(3)
|
||||
for i in range(len(past_key_values)):
|
||||
if self.config.model_type == "qwen":
|
||||
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.transpose(1, 2)
|
||||
v0 = v0.transpose(1, 2)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
elif self.config.model_type == "chatglm":
|
||||
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.permute(2, 0, 1, 3)
|
||||
v0 = v0.permute(2, 0, 1, 3)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
len0 = past_key_values[0][1].size(1)
|
||||
len1 = past_key_values[0][1].size(2)
|
||||
len2 = past_key_values[0][0].size(2) # seq length
|
||||
len3 = past_key_values[0][1].size(3)
|
||||
for i in range(len(past_key_values)):
|
||||
if self.config.model_type == "chatglm":
|
||||
k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.permute(2, 0, 1, 3)
|
||||
v0 = v0.permute(2, 0, 1, 3)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
|
||||
|
|
@ -195,7 +231,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
|||
return past_key_values_storage
|
||||
|
||||
|
||||
def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
|
||||
def _prepare_draft_past_key_values_cpu(self, past_key_values,
|
||||
past_key_values_storage, _enable_ipex):
|
||||
tmp_past_key_values = []
|
||||
for i in range(len(past_key_values)):
|
||||
if self.config.model_type == "qwen":
|
||||
|
|
@ -204,7 +241,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_st
|
|||
v0 = past_key_values_storage[i][1][:, :len1, :, :]
|
||||
tmp_past_key_values.append((k0, v0))
|
||||
elif self.config.model_type == "chatglm":
|
||||
len0 = past_key_values[0][0].size(0)
|
||||
if not _enable_ipex:
|
||||
len0 = past_key_values[0][0].size(0)
|
||||
else:
|
||||
len0 = past_key_values[0][0].size(1)
|
||||
k0 = past_key_values_storage[i][0][:len0, :, :, :]
|
||||
v0 = past_key_values_storage[i][1][:len0, :, :, :]
|
||||
tmp_past_key_values.append((k0, v0))
|
||||
|
|
@ -244,15 +284,41 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
|
|||
else:
|
||||
size = original_draft_past_key_values[i][0].size(2)
|
||||
size1 = past_key_values[i][0].size(1)
|
||||
delta_past_key = \
|
||||
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
|
||||
delta_past_value = \
|
||||
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
|
||||
if self.config.model_type == "chatglm":
|
||||
size = original_draft_past_key_values[0][0].size(0)
|
||||
size1 = past_key_values[0][0].size(1)
|
||||
len0 = past_key_values[0][1].size(0) # seq max_length
|
||||
len1 = past_key_values[0][1].size(1)
|
||||
len2 = past_key_values[0][1].size(2)
|
||||
len3 = past_key_values[0][1].size(3)
|
||||
key0 = torch.ones(size1-size, len1, len2, len3,
|
||||
dtype=torch.float32)
|
||||
value0 = torch.ones(size1-size, len1, len2, len3,
|
||||
dtype=torch.float32)
|
||||
key0 = past_key_values[i][1][size:size1, :, :, :]
|
||||
value0 = past_key_values[i][2][size:size1, :, :, :]
|
||||
key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||
key = key.expand(-1, -1, query_group_size, -1, -1)
|
||||
key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
|
||||
key = key.permute(2, 0, 1, 3)
|
||||
value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||
value = value.expand(-1, -1, query_group_size, -1, -1)
|
||||
value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
|
||||
value = value.permute(2, 0, 1, 3)
|
||||
past_key_values_storage[i][0][size:size1, :, :, :] = \
|
||||
key.to(torch.float32)
|
||||
past_key_values_storage[i][1][size:size1, :, :, :] = \
|
||||
value.to(torch.float32)
|
||||
else:
|
||||
delta_past_key = \
|
||||
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
|
||||
delta_past_value = \
|
||||
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
|
||||
|
||||
past_key_values_storage[i][0][:, :, size:size1, :] = \
|
||||
delta_past_key.to(torch.float32)
|
||||
past_key_values_storage[i][1][:, :, size:size1, :] = \
|
||||
delta_past_value.to(torch.float32)
|
||||
past_key_values_storage[i][0][:, :, size:size1, :] = \
|
||||
delta_past_key.to(torch.float32)
|
||||
past_key_values_storage[i][1][:, :, size:size1, :] = \
|
||||
delta_past_value.to(torch.float32)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -372,10 +438,14 @@ def speculative_generate(self,
|
|||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||
if _enable_ipex:
|
||||
if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
|
||||
('llama' in self.config.model_type) or
|
||||
('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
|
||||
("mistral" in self.config.model_type)):
|
||||
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
|
||||
Llama, Baichuan2-13b and Mistral models currently.")
|
||||
if "chatglm" in self.config.model_type:
|
||||
global query_group_size
|
||||
query_group_size = draft_model.config.num_attention_heads // \
|
||||
draft_model.config.multi_query_group_num
|
||||
|
||||
tmp_matchness = 0
|
||||
e2e_tic = 0.0
|
||||
|
|
@ -437,7 +507,7 @@ def speculative_generate(self,
|
|||
# each iter cut off cur_len kv_cache from past_key_values1
|
||||
draft_past_key_values = \
|
||||
_prepare_draft_past_key_values_cpu(self, past_key_values,
|
||||
past_key_values_storage)
|
||||
past_key_values_storage, _enable_ipex)
|
||||
original_draft_past_key_values = draft_past_key_values
|
||||
else:
|
||||
draft_past_key_values = past_key_values
|
||||
|
|
@ -464,7 +534,10 @@ def speculative_generate(self,
|
|||
"use_cache": True,
|
||||
}
|
||||
if self.config.model_type == "chatglm":
|
||||
past_key_value_len = past_key_values[0][0].shape[0]
|
||||
if _enable_ipex:
|
||||
past_key_value_len = past_key_values[0][0].shape[1]
|
||||
else:
|
||||
past_key_value_len = past_key_values[0][0].shape[0]
|
||||
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||
forward_args["position_ids"] = position_ids
|
||||
elif self.config.model_type == "gptj":
|
||||
|
|
@ -533,6 +606,16 @@ def speculative_generate(self,
|
|||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
elif "chatglm" in self.config.model_type:
|
||||
past_key_value_len = past_key_values[0][0].shape[2]
|
||||
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
|
||||
device=drafted_input_ids.device).unsqueeze(0)
|
||||
position_ids = position_ids.repeat(1, 1) + past_key_value_len
|
||||
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||
attention_mask=cur_attention_mask,
|
||||
position_ids=position_ids,
|
||||
# return_last_logit=torch.tensor(False),
|
||||
past_key_values=past_key_values,)
|
||||
elif "mistral" in self.config.model_type:
|
||||
past_key_value_len = past_key_values[0][0].shape[2]
|
||||
seq_len = drafted_input_ids.shape[1]
|
||||
|
|
@ -591,7 +674,8 @@ def speculative_generate(self,
|
|||
self.verify_time.append(toc - tic)
|
||||
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
|
||||
|
||||
past_key_values = output['past_key_values']
|
||||
if past_key_values is None:
|
||||
past_key_values = output['past_key_values']
|
||||
|
||||
if generation_config.do_sample:
|
||||
draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
|
||||
|
|
|
|||
Loading…
Reference in a new issue