diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index f299cb78..bc886e3c 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -518,6 +518,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          f"format......")
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
 
+    # using ipex optimizer before changing to bigdl linear
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
+    if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
+        logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
+        if _enable_ipex:
+            model = _optimize_ipex(model)
+            return model
+
     if optimize_model:
         model = _optimize_pre(model)
 
@@ -543,14 +553,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
             # Do nothing here for weights are empty.
             pass
 
-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
-    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
-    if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
-        logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
-        if _enable_ipex:
-            model = _optimize_ipex(model)
-            return model
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
     return model
@@ -590,13 +592,17 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
     )
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
     convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
                     _llama_model_forward_4_35)
     model = model_convert_reference(model)
+    if model.config.architectures is not None \
+       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        convert_function(model.transformer, "get_masks", GLM_get_masks)
+
     model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_rmsnorm(model)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)
diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
index cd534fbe..5ece6835 100644
--- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -142,6 +142,8 @@ def _ipex_jit(model):
     sample_inputs = (
         get_dummy_input(model, return_dict=True)
     )
+    if "return_last_logit" in sample_inputs:
+        del sample_inputs["return_last_logit"]
     with torch.no_grad(), torch.cpu.amp.autocast(
         enabled=True
     ):
@@ -159,6 +161,47 @@ def _ipex_jit(model):
     return model.eval()
 
 
+def convert_function(m, func_name, new_function):
+    bound_method = new_function.__get__(m, m.__class__)
+    setattr(m, func_name, bound_method)
+
+
+def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
+    batch_size, seq_length = input_ids.shape
+    full_attention_mask = torch.ones(
+        batch_size, seq_length, seq_length, device=input_ids.device
+    )
+    full_attention_mask.tril_()
+    past_length = 0
+    if past_key_values:
+        if len(past_key_values[0]) != 4:  # not discrete kv cache
+            past_length = past_key_values[0][0].shape[0]
+        else:  # discrete kv cache
+            past_length = past_key_values[0][0].shape[-2]
+
+    import os
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    # always call for jit
+    if past_length or _enable_ipex:
+        full_attention_mask = torch.cat(
+            (
+                torch.ones(
+                    batch_size, seq_length, past_length, device=input_ids.device
+                ),
+                full_attention_mask,
+            ),
+            dim=-1,
+        )
+    if padding_mask is not None:
+        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+    # if not past_length and padding_mask is not None:
+    #     full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+    full_attention_mask = (full_attention_mask < 0.5).bool()
+    full_attention_mask.unsqueeze_(1)
+    return full_attention_mask
+
+
 @staticmethod
 def _make_causal_mask(
     input_ids_shape: torch.Size,
diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index 60501e0b..103a6237 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -34,7 +34,7 @@ from bigdl.llm.utils.common import invalidInputError
 # patch GenerationMixin.generate
 from transformers import GenerationMixin
 original_generate = GenerationMixin.generate
-
+query_group_size = 16
 logger = logging.getLogger("bigdl.llm.speculative")
 
 
@@ -131,62 +131,98 @@ def clear_benchmarks(self):
 def _prepare_past_key_values_storage_cpu(self, past_key_values,
                                          max_new_tokens, _enable_ipex=False):
     past_key_values_storage = []
+    # init ipex_past_key_values
     if _enable_ipex:
         ipex_past_key_values = []
         cur_len = past_key_values[0][0].size(1)
-        ipex_past_key_values = [
-            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
-             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
-            for pkv in past_key_values
-        ]
-
-    for i in range(len(past_key_values)):
-        if not _enable_ipex:
-            len0 = past_key_values[i][0].size(0)
-            len1 = past_key_values[i][0].size(1)
-            len2 = past_key_values[i][0].size(2)
-            len3 = past_key_values[i][0].size(3)
+        if self.config.model_type == "chatglm":
+            len0 = past_key_values[0][1].size(0)  # seq max length
+            len1 = past_key_values[0][1].size(1)
+            len2 = past_key_values[0][1].size(2)
+            len3 = past_key_values[0][1].size(3)
+            for pkv in past_key_values:
+                key = pkv[1]
+                value = pkv[2]
+                key = key.permute(1, 2, 0, 3).unsqueeze(-3)
+                key = key.expand(-1, -1, query_group_size, -1, -1)
+                key = key.contiguous().view(len1, len2 * query_group_size,
+                                            len0, len3).permute(2, 0, 1, 3)
+                value = value.permute(1, 2, 0, 3).unsqueeze(-3)
+                value = value.expand(-1, -1, query_group_size, -1, -1)
+                value = value.contiguous().view(len1, len2 * query_group_size,
+                                                len0, len3).permute(2, 0, 1, 3)
+                list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
+                ipex_past_key_values.append(list)
         else:
-            len0 = past_key_values[i][1].size(1)
-            len1 = past_key_values[i][1].size(2)
-            len2 = past_key_values[i][0].size(2)  # seq length
-            len3 = past_key_values[i][1].size(3)
-        if self.config.model_type == "qwen":
-            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            k0 = k0.transpose(1, 2)
-            v0 = v0.transpose(1, 2)
-            past_key_values_storage.append((k0, v0))
-            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
-                torch.float32)
-            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
-                torch.float32)
self.config.model_type == "chatglm": - k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, - dtype=torch.float32) - v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, - dtype=torch.float32) - k0 = k0.permute(2, 0, 1, 3) - v0 = v0.permute(2, 0, 1, 3) - past_key_values_storage.append((k0, v0)) - past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to( - torch.float32) - past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to( - torch.float32) - else: - k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, - dtype=torch.float32) - v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, - dtype=torch.float32) - past_key_values_storage.append((k0, v0)) - if not _enable_ipex: + ipex_past_key_values = [ + [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :], + pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]] + for pkv in past_key_values + ] + if not _enable_ipex: + len0 = past_key_values[0][0].size(0) + len1 = past_key_values[0][0].size(1) + len2 = past_key_values[0][0].size(2) + len3 = past_key_values[0][0].size(3) + for i in range(len(past_key_values)): + if self.config.model_type == "qwen": + k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3, + dtype=torch.float32) + v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3, + dtype=torch.float32) + k0 = k0.transpose(1, 2) + v0 = v0.transpose(1, 2) + past_key_values_storage.append((k0, v0)) + past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to( + torch.float32) + past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to( + torch.float32) + elif self.config.model_type == "chatglm": + k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, + dtype=torch.float32) + v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3, + dtype=torch.float32) + k0 = k0.permute(2, 0, 1, 3) + v0 = v0.permute(2, 0, 1, 3) + past_key_values_storage.append((k0, v0)) + past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to( + torch.float32) + past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to( + torch.float32) + else: + k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, + dtype=torch.float32) + v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, + dtype=torch.float32) + past_key_values_storage.append((k0, v0)) past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to( torch.float32) past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to( torch.float32) + else: + len0 = past_key_values[0][1].size(1) + len1 = past_key_values[0][1].size(2) + len2 = past_key_values[0][0].size(2) # seq length + len3 = past_key_values[0][1].size(3) + for i in range(len(past_key_values)): + if self.config.model_type == "chatglm": + k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3, + dtype=torch.float32) + v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3, + dtype=torch.float32) + k0 = k0.permute(2, 0, 1, 3) + v0 = v0.permute(2, 0, 1, 3) + past_key_values_storage.append((k0, v0)) + past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to( + torch.float32) + past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to( + torch.float32) else: + k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, + dtype=torch.float32) + v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, + dtype=torch.float32) + past_key_values_storage.append((k0, v0)) past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to( torch.float32) 
                 past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
@@ -195,7 +231,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
     return past_key_values_storage
 
 
-def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
+def _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                       past_key_values_storage, _enable_ipex):
     tmp_past_key_values = []
     for i in range(len(past_key_values)):
         if self.config.model_type == "qwen":
@@ -204,7 +241,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_st
             v0 = past_key_values_storage[i][1][:, :len1, :, :]
             tmp_past_key_values.append((k0, v0))
         elif self.config.model_type == "chatglm":
-            len0 = past_key_values[0][0].size(0)
+            if not _enable_ipex:
+                len0 = past_key_values[0][0].size(0)
+            else:
+                len0 = past_key_values[0][0].size(1)
             k0 = past_key_values_storage[i][0][:len0, :, :, :]
             v0 = past_key_values_storage[i][1][:len0, :, :, :]
             tmp_past_key_values.append((k0, v0))
@@ -244,15 +284,41 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
         else:
             size = original_draft_past_key_values[i][0].size(2)
             size1 = past_key_values[i][0].size(1)
-            delta_past_key = \
-                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
-            delta_past_value = \
-                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
+            if self.config.model_type == "chatglm":
+                size = original_draft_past_key_values[0][0].size(0)
+                size1 = past_key_values[0][0].size(1)
+                len0 = past_key_values[0][1].size(0)  # seq max_length
+                len1 = past_key_values[0][1].size(1)
+                len2 = past_key_values[0][1].size(2)
+                len3 = past_key_values[0][1].size(3)
+                key0 = torch.ones(size1-size, len1, len2, len3,
+                                  dtype=torch.float32)
+                value0 = torch.ones(size1-size, len1, len2, len3,
+                                    dtype=torch.float32)
+                key0 = past_key_values[i][1][size:size1, :, :, :]
+                value0 = past_key_values[i][2][size:size1, :, :, :]
+                key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
+                key = key.expand(-1, -1, query_group_size, -1, -1)
+                key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
+                key = key.permute(2, 0, 1, 3)
+                value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
+                value = value.expand(-1, -1, query_group_size, -1, -1)
+                value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
+                value = value.permute(2, 0, 1, 3)
+                past_key_values_storage[i][0][size:size1, :, :, :] = \
+                    key.to(torch.float32)
+                past_key_values_storage[i][1][size:size1, :, :, :] = \
+                    value.to(torch.float32)
+            else:
+                delta_past_key = \
+                    past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
+                delta_past_value = \
+                    past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
 
-            past_key_values_storage[i][0][:, :, size:size1, :] = \
-                delta_past_key.to(torch.float32)
-            past_key_values_storage[i][1][:, :, size:size1, :] = \
-                delta_past_value.to(torch.float32)
+                past_key_values_storage[i][0][:, :, size:size1, :] = \
+                    delta_past_key.to(torch.float32)
+                past_key_values_storage[i][1][:, :, size:size1, :] = \
+                    delta_past_value.to(torch.float32)
 
 
 @torch.no_grad()
@@ -372,10 +438,14 @@ def speculative_generate(self,
     _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
     if _enable_ipex:
         if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
-                ('llama' in self.config.model_type) or
+                ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
                 ("mistral" in self.config.model_type)):
            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
                              Llama, Baichuan2-13b and Mistral models currently.")
+        if "chatglm" in self.config.model_type:
+            global query_group_size
+            query_group_size = draft_model.config.num_attention_heads // \
+                draft_model.config.multi_query_group_num
 
     tmp_matchness = 0
     e2e_tic = 0.0
@@ -437,7 +507,7 @@ def speculative_generate(self,
                 # each iter cut off cur_len kv_cache from past_key_values1
                 draft_past_key_values = \
                     _prepare_draft_past_key_values_cpu(self, past_key_values,
-                                                       past_key_values_storage)
+                                                       past_key_values_storage, _enable_ipex)
                 original_draft_past_key_values = draft_past_key_values
             else:
                 draft_past_key_values = past_key_values
@@ -464,7 +534,10 @@ def speculative_generate(self,
                     "use_cache": True,
                 }
                 if self.config.model_type == "chatglm":
-                    past_key_value_len = past_key_values[0][0].shape[0]
+                    if _enable_ipex:
+                        past_key_value_len = past_key_values[0][0].shape[1]
+                    else:
+                        past_key_value_len = past_key_values[0][0].shape[0]
                     position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                     forward_args["position_ids"] = position_ids
                 elif self.config.model_type == "gptj":
@@ -533,6 +606,16 @@ def speculative_generate(self,
                                                   position_ids=position_ids,
                                                   past_key_values=past_key_values,
                                                   )
+                    elif "chatglm" in self.config.model_type:
+                        past_key_value_len = past_key_values[0][0].shape[2]
+                        position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                    device=drafted_input_ids.device).unsqueeze(0)
+                        position_ids = position_ids.repeat(1, 1) + past_key_value_len
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  position_ids=position_ids,
+                                                  # return_last_logit=torch.tensor(False),
+                                                  past_key_values=past_key_values,)
                     elif "mistral" in self.config.model_type:
                         past_key_value_len = past_key_values[0][0].shape[2]
                         seq_len = drafted_input_ids.shape[1]
@@ -591,7 +674,8 @@ def speculative_generate(self,
                 self.verify_time.append(toc - tic)
                 self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
 
-                past_key_values = output['past_key_values']
+                if past_key_values is None:
+                    past_key_values = output['past_key_values']
 
                 if generation_config.do_sample:
                     draft_tokens = drafted_input_ids[:, 1:].squeeze(0)
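
Usage note (not part of the patch): the sketch below shows how the IPEX BF16 path added above would typically be exercised for a ChatGLM checkpoint. It is a minimal sketch under stated assumptions: the checkpoint path and generation arguments are placeholders, and the load_in_low_bit="bf16" / speculative=True keyword arguments follow the existing bigdl-llm self-speculative-decoding examples rather than anything defined in this diff.

# Minimal sketch, assuming the usual bigdl-llm speculative-decoding entry points.
import os
import torch

# BIGDL_OPT_IPEX is read in ggml_convert_low_bit() and speculative_generate(),
# so it has to be set before the model is loaded.
os.environ["BIGDL_OPT_IPEX"] = "true"

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "THUDM/chatglm3-6b"  # placeholder ChatGLM checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="bf16",     # matches the qtype check against ggml_tensor_qtype["bf16"]
    optimize_model=True,
    torch_dtype=torch.bfloat16,
    speculative=True,           # assumption: routes generate() through speculative_generate()
    trust_remote_code=True,
    use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is speculative decoding?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=64, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

For chatglm models, speculative_generate() derives query_group_size from draft_model.config.num_attention_heads // draft_model.config.multi_query_group_num, so the draft model's config must expose both fields.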