LLM: enable chatglm3-6b target_model ipex (#10085)

* init

* always make casual_mask

* not return last tensor

* update

* optimize_model = False

* enable optimized=False

* enable optimized_model=true

* speed_up ipex target_model

* remove if True

* use group_size

* update python style

* update

* update
This commit is contained in:
Wang, Jian4 2024-02-19 13:38:32 +08:00 committed by GitHub
parent 177273c1a4
commit f2417e083c
3 changed files with 204 additions and 71 deletions

View file

@ -518,6 +518,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
f"format......") f"format......")
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
# using ipex optimizer before changing to bigdl linear
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
if _enable_ipex:
model = _optimize_ipex(model)
return model
if optimize_model: if optimize_model:
model = _optimize_pre(model) model = _optimize_pre(model)
@ -543,14 +553,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
# Do nothing here for weights are empty. # Do nothing here for weights are empty.
pass pass
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
if _enable_ipex:
model = _optimize_ipex(model)
return model
if optimize_model: if optimize_model:
model = _optimize_post(model, lightweight_bmm) model = _optimize_post(model, lightweight_bmm)
return model return model
@ -590,13 +592,17 @@ def _optimize_ipex(model):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from bigdl.llm.transformers.convert_ipex import ( from bigdl.llm.transformers.convert_ipex import (
_ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask, _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
_ipex_optimize_rmsnorm, _llama_model_forward_4_35 _ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
) )
AttentionMaskConverter._make_causal_mask = _make_causal_mask AttentionMaskConverter._make_causal_mask = _make_causal_mask
convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
_llama_model_forward_4_35) _llama_model_forward_4_35)
model = model_convert_reference(model) model = model_convert_reference(model)
if model.config.architectures is not None \
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
convert_function(model.transformer, "get_masks", GLM_get_masks)
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
_ipex_optimize_rmsnorm(model) _ipex_optimize_rmsnorm(model)
_ipex_optimize_attention(model) _ipex_optimize_attention(model)
_ipex_optimize_decoder(model) _ipex_optimize_decoder(model)

View file

@ -142,6 +142,8 @@ def _ipex_jit(model):
sample_inputs = ( sample_inputs = (
get_dummy_input(model, return_dict=True) get_dummy_input(model, return_dict=True)
) )
if "return_last_logit" in sample_inputs:
del sample_inputs["return_last_logit"]
with torch.no_grad(), torch.cpu.amp.autocast( with torch.no_grad(), torch.cpu.amp.autocast(
enabled=True enabled=True
): ):
@ -159,6 +161,47 @@ def _ipex_jit(model):
return model.eval() return model.eval()
def convert_function(m, func_name, new_function):
bound_method = new_function.__get__(m, m.__class__)
setattr(m, func_name, bound_method)
def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(
batch_size, seq_length, seq_length, device=input_ids.device
)
full_attention_mask.tril_()
past_length = 0
if past_key_values:
if len(past_key_values[0]) != 4: # not discrete kv cache
past_length = past_key_values[0][0].shape[0]
else: # discrete kv cache
past_length = past_key_values[0][0].shape[-2]
import os
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
# always call for jit
if past_length or _enable_ipex:
full_attention_mask = torch.cat(
(
torch.ones(
batch_size, seq_length, past_length, device=input_ids.device
),
full_attention_mask,
),
dim=-1,
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
# if not past_length and padding_mask is not None:
# full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
@staticmethod @staticmethod
def _make_causal_mask( def _make_causal_mask(
input_ids_shape: torch.Size, input_ids_shape: torch.Size,

View file

@ -34,7 +34,7 @@ from bigdl.llm.utils.common import invalidInputError
# patch GenerationMixin.generate # patch GenerationMixin.generate
from transformers import GenerationMixin from transformers import GenerationMixin
original_generate = GenerationMixin.generate original_generate = GenerationMixin.generate
query_group_size = 16
logger = logging.getLogger("bigdl.llm.speculative") logger = logging.getLogger("bigdl.llm.speculative")
@ -131,62 +131,98 @@ def clear_benchmarks(self):
def _prepare_past_key_values_storage_cpu(self, past_key_values, def _prepare_past_key_values_storage_cpu(self, past_key_values,
max_new_tokens, _enable_ipex=False): max_new_tokens, _enable_ipex=False):
past_key_values_storage = [] past_key_values_storage = []
# init ipex_past_key_values
if _enable_ipex: if _enable_ipex:
ipex_past_key_values = [] ipex_past_key_values = []
cur_len = past_key_values[0][0].size(1) cur_len = past_key_values[0][0].size(1)
ipex_past_key_values = [ if self.config.model_type == "chatglm":
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :], len0 = past_key_values[0][1].size(0) # seq max length
pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]] len1 = past_key_values[0][1].size(1)
for pkv in past_key_values len2 = past_key_values[0][1].size(2)
] len3 = past_key_values[0][1].size(3)
for pkv in past_key_values:
for i in range(len(past_key_values)): key = pkv[1]
if not _enable_ipex: value = pkv[2]
len0 = past_key_values[i][0].size(0) key = key.permute(1, 2, 0, 3).unsqueeze(-3)
len1 = past_key_values[i][0].size(1) key = key.expand(-1, -1, query_group_size, -1, -1)
len2 = past_key_values[i][0].size(2) key = key.contiguous().view(len1, len2 * query_group_size,
len3 = past_key_values[i][0].size(3) len0, len3).permute(2, 0, 1, 3)
value = value.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size,
len0, len3).permute(2, 0, 1, 3)
list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
ipex_past_key_values.append(list)
else: else:
len0 = past_key_values[i][1].size(1) ipex_past_key_values = [
len1 = past_key_values[i][1].size(2) [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
len2 = past_key_values[i][0].size(2) # seq length pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
len3 = past_key_values[i][1].size(3) for pkv in past_key_values
if self.config.model_type == "qwen": ]
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3, if not _enable_ipex:
dtype=torch.float32) len0 = past_key_values[0][0].size(0)
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3, len1 = past_key_values[0][0].size(1)
dtype=torch.float32) len2 = past_key_values[0][0].size(2)
k0 = k0.transpose(1, 2) len3 = past_key_values[0][0].size(3)
v0 = v0.transpose(1, 2) for i in range(len(past_key_values)):
past_key_values_storage.append((k0, v0)) if self.config.model_type == "qwen":
past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to( k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
torch.float32) dtype=torch.float32)
past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to( v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
torch.float32) dtype=torch.float32)
elif self.config.model_type == "chatglm": k0 = k0.transpose(1, 2)
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, v0 = v0.transpose(1, 2)
dtype=torch.float32) past_key_values_storage.append((k0, v0))
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
dtype=torch.float32) torch.float32)
k0 = k0.permute(2, 0, 1, 3) past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
v0 = v0.permute(2, 0, 1, 3) torch.float32)
past_key_values_storage.append((k0, v0)) elif self.config.model_type == "chatglm":
past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to( k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
torch.float32) dtype=torch.float32)
past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to( v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
torch.float32) dtype=torch.float32)
else: k0 = k0.permute(2, 0, 1, 3)
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, v0 = v0.permute(2, 0, 1, 3)
dtype=torch.float32) past_key_values_storage.append((k0, v0))
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
dtype=torch.float32) torch.float32)
past_key_values_storage.append((k0, v0)) past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
if not _enable_ipex: torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to( past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
torch.float32) torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to( past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
torch.float32) torch.float32)
else:
len0 = past_key_values[0][1].size(1)
len1 = past_key_values[0][1].size(2)
len2 = past_key_values[0][0].size(2) # seq length
len3 = past_key_values[0][1].size(3)
for i in range(len(past_key_values)):
if self.config.model_type == "chatglm":
k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(2, 0, 1, 3)
v0 = v0.permute(2, 0, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
else: else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to( past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
torch.float32) torch.float32)
past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to( past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
@ -195,7 +231,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
return past_key_values_storage return past_key_values_storage
def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage): def _prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage, _enable_ipex):
tmp_past_key_values = [] tmp_past_key_values = []
for i in range(len(past_key_values)): for i in range(len(past_key_values)):
if self.config.model_type == "qwen": if self.config.model_type == "qwen":
@ -204,7 +241,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_st
v0 = past_key_values_storage[i][1][:, :len1, :, :] v0 = past_key_values_storage[i][1][:, :len1, :, :]
tmp_past_key_values.append((k0, v0)) tmp_past_key_values.append((k0, v0))
elif self.config.model_type == "chatglm": elif self.config.model_type == "chatglm":
len0 = past_key_values[0][0].size(0) if not _enable_ipex:
len0 = past_key_values[0][0].size(0)
else:
len0 = past_key_values[0][0].size(1)
k0 = past_key_values_storage[i][0][:len0, :, :, :] k0 = past_key_values_storage[i][0][:len0, :, :, :]
v0 = past_key_values_storage[i][1][:len0, :, :, :] v0 = past_key_values_storage[i][1][:len0, :, :, :]
tmp_past_key_values.append((k0, v0)) tmp_past_key_values.append((k0, v0))
@ -244,15 +284,41 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
else: else:
size = original_draft_past_key_values[i][0].size(2) size = original_draft_past_key_values[i][0].size(2)
size1 = past_key_values[i][0].size(1) size1 = past_key_values[i][0].size(1)
delta_past_key = \ if self.config.model_type == "chatglm":
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3) size = original_draft_past_key_values[0][0].size(0)
delta_past_value = \ size1 = past_key_values[0][0].size(1)
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3) len0 = past_key_values[0][1].size(0) # seq max_length
len1 = past_key_values[0][1].size(1)
len2 = past_key_values[0][1].size(2)
len3 = past_key_values[0][1].size(3)
key0 = torch.ones(size1-size, len1, len2, len3,
dtype=torch.float32)
value0 = torch.ones(size1-size, len1, len2, len3,
dtype=torch.float32)
key0 = past_key_values[i][1][size:size1, :, :, :]
value0 = past_key_values[i][2][size:size1, :, :, :]
key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
key = key.expand(-1, -1, query_group_size, -1, -1)
key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
key = key.permute(2, 0, 1, 3)
value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
value = value.expand(-1, -1, query_group_size, -1, -1)
value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
value = value.permute(2, 0, 1, 3)
past_key_values_storage[i][0][size:size1, :, :, :] = \
key.to(torch.float32)
past_key_values_storage[i][1][size:size1, :, :, :] = \
value.to(torch.float32)
else:
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
past_key_values_storage[i][0][:, :, size:size1, :] = \ past_key_values_storage[i][0][:, :, size:size1, :] = \
delta_past_key.to(torch.float32) delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, :, size:size1, :] = \ past_key_values_storage[i][1][:, :, size:size1, :] = \
delta_past_value.to(torch.float32) delta_past_value.to(torch.float32)
@torch.no_grad() @torch.no_grad()
@ -372,10 +438,14 @@ def speculative_generate(self,
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
if _enable_ipex: if _enable_ipex:
if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
('llama' in self.config.model_type) or ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
("mistral" in self.config.model_type)): ("mistral" in self.config.model_type)):
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \ invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
Llama, Baichuan2-13b and Mistral models currently.") Llama, Baichuan2-13b and Mistral models currently.")
if "chatglm" in self.config.model_type:
global query_group_size
query_group_size = draft_model.config.num_attention_heads // \
draft_model.config.multi_query_group_num
tmp_matchness = 0 tmp_matchness = 0
e2e_tic = 0.0 e2e_tic = 0.0
@ -437,7 +507,7 @@ def speculative_generate(self,
# each iter cut off cur_len kv_cache from past_key_values1 # each iter cut off cur_len kv_cache from past_key_values1
draft_past_key_values = \ draft_past_key_values = \
_prepare_draft_past_key_values_cpu(self, past_key_values, _prepare_draft_past_key_values_cpu(self, past_key_values,
past_key_values_storage) past_key_values_storage, _enable_ipex)
original_draft_past_key_values = draft_past_key_values original_draft_past_key_values = draft_past_key_values
else: else:
draft_past_key_values = past_key_values draft_past_key_values = past_key_values
@ -464,7 +534,10 @@ def speculative_generate(self,
"use_cache": True, "use_cache": True,
} }
if self.config.model_type == "chatglm": if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0] if _enable_ipex:
past_key_value_len = past_key_values[0][0].shape[1]
else:
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long() position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
forward_args["position_ids"] = position_ids forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj": elif self.config.model_type == "gptj":
@ -533,6 +606,16 @@ def speculative_generate(self,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
) )
elif "chatglm" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device).unsqueeze(0)
position_ids = position_ids.repeat(1, 1) + past_key_value_len
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
position_ids=position_ids,
# return_last_logit=torch.tensor(False),
past_key_values=past_key_values,)
elif "mistral" in self.config.model_type: elif "mistral" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2] past_key_value_len = past_key_values[0][0].shape[2]
seq_len = drafted_input_ids.shape[1] seq_len = drafted_input_ids.shape[1]
@ -591,7 +674,8 @@ def speculative_generate(self,
self.verify_time.append(toc - tic) self.verify_time.append(toc - tic)
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1]) self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
past_key_values = output['past_key_values'] if past_key_values is None:
past_key_values = output['past_key_values']
if generation_config.do_sample: if generation_config.do_sample:
draft_tokens = drafted_input_ids[:, 1:].squeeze(0) draft_tokens = drafted_input_ids[:, 1:].squeeze(0)