From 0ded0b4b131673e05dd50d397dc912d3e10cb914 Mon Sep 17 00:00:00 2001 From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:08:50 +0800 Subject: [PATCH] LLM: Enable BigDL IPEX optimization for int4 (#10319) Enable BigDL IPEX optimization for int4 --- .../llm/dev/benchmark/all-in-one/config.yaml | 2 + python/llm/dev/benchmark/all-in-one/run.py | 147 ++++++++++++++++++ .../llm/src/bigdl/llm/transformers/convert.py | 22 ++- .../bigdl/llm/transformers/convert_ipex.py | 71 ++++++--- .../src/bigdl/llm/transformers/speculative.py | 70 ++++++++- 5 files changed, 276 insertions(+), 36 deletions(-) diff --git a/python/llm/dev/benchmark/all-in-one/config.yaml b/python/llm/dev/benchmark/all-in-one/config.yaml index 00cdc62a..53d9c90f 100644 --- a/python/llm/dev/benchmark/all-in-one/config.yaml +++ b/python/llm/dev/benchmark/all-in-one/config.yaml @@ -18,6 +18,8 @@ test_api: - "optimize_model" - "pytorch_autocast_bf16" # - "transformer_autocast_bf16" + # - "bigdl_ipex_bf16" + # - "bigdl_ipex_int4" # - "ipex_fp16_gpu" # on Intel GPU # - "bigdl_fp16_gpu" # on Intel GPU # - "transformer_int4_gpu" # on Intel GPU diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py index 72539202..2acbc251 100644 --- a/python/llm/dev/benchmark/all-in-one/run.py +++ b/python/llm/dev/benchmark/all-in-one/run.py @@ -92,6 +92,10 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming) elif test_api == 'transformer_autocast_bf16': result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size) + elif test_api == 'bigdl_ipex_bf16': + result = run_bigdl_ipex_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size) + elif test_api == 'bigdl_ipex_int4': + result = run_bigdl_ipex_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size) elif test_api == 'deepspeed_optimize_model_gpu': result = run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size) @@ -1079,6 +1083,148 @@ def run_transformer_autocast_bf16( repo_id, actual_in_len, actual_out_len, load_time]) return result + +def run_bigdl_ipex_bf16(repo_id, + local_model_hub, + in_out_pairs, + warm_up, + num_trials, + num_beams, + batch_size): + from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM + from transformers import AutoTokenizer, LlamaTokenizer + + os.environ["BIGDL_OPT_IPEX"] = "true" + + model_path = get_model_path(repo_id, local_model_hub) + # Load model in bf16, + # which convert the relevant layers in the model into BF16 format + st = time.perf_counter() + if repo_id in CHATGLM_IDS: + model = AutoModel.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16, + use_cache=True, torchscript=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + elif repo_id in LLAMA_IDS: + model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16, + use_cache=True, torchscript=True) + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16, + use_cache=True, torchscript=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if not hasattr(model.config, "token_latency"): + model.config.token_latency = True + end = time.perf_counter() + load_time = end - st + print(">> loading of model costs {}s".format(load_time)) + + result = {} + with torch.inference_mode(), torch.autocast("cpu"): + for in_out in in_out_pairs: + in_out_len = in_out.split("-") + in_len = int(in_out_len[0]) + out_len = int(in_out_len[1]) + # As different tokenizer has different encodings, + # in_len.txt maybe shorter than we need, + # use much longer context to make sure input length + test_length = min(in_len*2, 8192) + while test_length not in [32, 256, 1024, 2048, 8192]: + test_length = test_length * 2 + input_str = open(f"prompt/{test_length}.txt", 'r').read() + # As different tokenizer has different encodings, + # slice the input_ids to ensure the prompt length is required length. + input_ids = tokenizer.encode(input_str, return_tensors="pt") + input_ids = input_ids[:, :in_len] + true_str = tokenizer.batch_decode(input_ids)[0] + input_list = [true_str] * batch_size + input_ids = tokenizer(input_list, return_tensors="pt").input_ids + actual_in_len = input_ids.shape[1] + result[in_out] = [] + for i in range(num_trials + warm_up): + st = time.perf_counter() + output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len, + num_beams=num_beams) + end = time.perf_counter() + print("model generate cost: " + str(end - st)) + output = tokenizer.batch_decode(output_ids) + print(output[0]) + actual_out_len = output_ids.shape[1] - actual_in_len + if i >= warm_up: + result[in_out].append([total_list[0], np.mean(total_list[1:]), 0, + actual_in_len, actual_out_len, load_time]) + return result + + +def run_bigdl_ipex_int4(repo_id, + local_model_hub, + in_out_pairs, + warm_up, + num_trials, + num_beams, + batch_size): + from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM + from transformers import AutoTokenizer, LlamaTokenizer + + os.environ["BIGDL_OPT_IPEX"] = "true" + + model_path = get_model_path(repo_id, local_model_hub) + + st = time.perf_counter() + if repo_id in CHATGLM_IDS: + model = AutoModel.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto', + use_cache=True, torchscript=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + elif repo_id in LLAMA_IDS: + model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto', + use_cache=True, torchscript=True) + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto', + use_cache=True, torchscript=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if not hasattr(model.config, "token_latency"): + model.config.token_latency = True + end = time.perf_counter() + load_time = end - st + print(">> loading of model costs {}s".format(load_time)) + + result = {} + with torch.inference_mode(), torch.autocast("cpu"): + for in_out in in_out_pairs: + in_out_len = in_out.split("-") + in_len = int(in_out_len[0]) + out_len = int(in_out_len[1]) + # As different tokenizer has different encodings, + # in_len.txt maybe shorter than we need, + # use much longer context to make sure input length + test_length = min(in_len*2, 8192) + while test_length not in [32, 256, 1024, 2048, 8192]: + test_length = test_length * 2 + input_str = open(f"prompt/{test_length}.txt", 'r').read() + # As different tokenizer has different encodings, + # slice the input_ids to ensure the prompt length is required length. + input_ids = tokenizer.encode(input_str, return_tensors="pt") + input_ids = input_ids[:, :in_len] + true_str = tokenizer.batch_decode(input_ids)[0] + input_list = [true_str] * batch_size + input_ids = tokenizer(input_list, return_tensors="pt").input_ids + actual_in_len = input_ids.shape[1] + result[in_out] = [] + for i in range(num_trials + warm_up): + st = time.perf_counter() + output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len, + num_beams=num_beams) + end = time.perf_counter() + print("model generate cost: " + str(end - st)) + output = tokenizer.batch_decode(output_ids) + print(output[0]) + actual_out_len = output_ids.shape[1] - actual_in_len + if i >= warm_up: + result[in_out].append([total_list[0], np.mean(total_list[1:]), 0, + actual_in_len, actual_out_len, load_time]) + return result + + def run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, @@ -1192,6 +1338,7 @@ def run_deepspeed_optimize_model_gpu(repo_id, torch.xpu.empty_cache() return result + if __name__ == '__main__': from omegaconf import OmegaConf conf = OmegaConf.load(f'{current_dir}/config.yaml') diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index a64ea4a9..6470e6bb 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -611,13 +611,12 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert # using ipex optimizer before changing to bigdl linear - _enable_ipex = os.getenv("BIGDL_OPT_IPEX") - _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") - _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"]) - if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]): + _enable_ipex = get_enable_ipex() + + if device == "cpu": logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}") if _enable_ipex: - model = _optimize_ipex(model) + model = _optimize_ipex(model, qtype) return model if optimize_model: @@ -686,12 +685,19 @@ def replace_func(m, target_m, func_name, new_func): replace_func(sub_m, target_m, func_name, new_func) -def _optimize_ipex(model): +def get_enable_ipex(): + _enable_ipex = os.getenv("BIGDL_OPT_IPEX") + _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") + return _enable_ipex + + +def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]): + import intel_extension_for_pytorch as ipex from intel_extension_for_pytorch.transformers.optimize import model_convert_reference from transformers.modeling_attn_mask_utils import AttentionMaskConverter from bigdl.llm.transformers.convert_ipex import ( _ipex_optimize_model, _ipex_jit, _make_causal_mask, - _llama_model_forward_4_35, convert_function, GLM_get_masks + _llama_model_forward_4_35, convert_function, GLM_get_masks, ) model = model_convert_reference(model) @@ -718,7 +724,7 @@ def _optimize_ipex(model): # baichuan2 rms_classes.append(type(model.model.layers[0].input_layernorm)) - _ipex_optimize_model(model, rms_classes) + model = _ipex_optimize_model(model, rms_classes, qtype) return _ipex_jit(model) diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py index 7b2ef0c4..5e5421d1 100644 --- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py +++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py @@ -44,10 +44,14 @@ from intel_extension_for_pytorch.transformers.optimize import ( from intel_extension_for_pytorch.cpu._auto_kernel_selection import ( _enable_tpp, _using_tpp, + _disable_tpp ) +from bigdl.llm.ggml.quantize import ggml_tensor_qtype +from bigdl.llm.transformers.convert import get_enable_ipex +import os -def _ipex_optimize_rmsnorm(_model, supported_classes): +def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False): from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm for supported_class in supported_classes: lowering_class_cpu( @@ -55,12 +59,12 @@ def _ipex_optimize_rmsnorm(_model, supported_classes): supported_class, _IPEXRMSNorm, _model.config, - tpp=False, - woq=False, + tpp=is_tpp, + woq=is_woq, ) -def _ipex_optimize_decoder(model): +def _ipex_optimize_decoder(model, is_tpp=False, is_woq=False): from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import ( _IPEXDecoderLayerRef ) @@ -73,12 +77,12 @@ def _ipex_optimize_decoder(model): supported_mlp_class, _IPEXDecoderLayerCPU, model.config, - tpp=True if _using_tpp() else False, - woq=False, + tpp=is_tpp, + woq=is_woq, ) -def _ipex_optimize_attention(model): +def _ipex_optimize_attention(model, is_tpp=False, is_woq=False): from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import ( _IPEXAttentionRef ) @@ -91,18 +95,47 @@ def _ipex_optimize_attention(model): supported_mha_class, _IPEXAttentionCPU, model.config, - tpp=True if _using_tpp() else False, - woq=False, + tpp=is_tpp, + woq=is_woq, ) -def _ipex_optimize_model(model, rms_classes): - _enable_tpp() +def _ipex_optimize_model(model, rms_classes, qtype): import intel_extension_for_pytorch as ipex - ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval() - _ipex_optimize_rmsnorm(model, rms_classes) - _ipex_optimize_attention(model) - _ipex_optimize_decoder(model) + from intel_extension_for_pytorch.transformers.models.reference.models import output_hook + from intel_extension_for_pytorch.transformers.optimize import ipex_quantization_flow + + is_woq = False + is_quantization = False + _disable_tpp() + if qtype == ggml_tensor_qtype["bf16"]: + _enable_tpp() + model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval() + elif qtype == ggml_tensor_qtype["sym_int4"]: + is_quantization = True + is_woq = True + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=torch.quint4x2, # INT4 + lowp_mode=ipex.quantization.WoqLowpMode.INT8, + act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"], + group_size=-1, + ) + model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None) + + is_tpp = _using_tpp() + + _ipex_optimize_rmsnorm(model, rms_classes, is_tpp=is_tpp, is_woq=is_woq) + _ipex_optimize_attention(model, is_tpp=is_tpp, is_woq=is_woq) + _ipex_optimize_decoder(model, is_tpp=is_tpp, is_woq=is_woq) + + model.register_forward_hook(output_hook, with_kwargs=True) + return model def _ipex_jit(model): @@ -152,9 +185,7 @@ def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None): else: # discrete kv cache past_length = past_key_values[0][0].shape[-2] - import os - _enable_ipex = os.getenv("BIGDL_OPT_IPEX") - _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") + _enable_ipex = get_enable_ipex() # always call for jit if past_length or _enable_ipex: full_attention_mask = torch.cat( @@ -191,9 +222,7 @@ def _make_causal_mask( mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - import os - _enable_ipex = os.getenv("BIGDL_OPT_IPEX") - _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") + _enable_ipex = get_enable_ipex() if _enable_ipex or past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) # noqa diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index 1e3f0029..d37833bb 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un from transformers import top_k_top_p_filtering, GenerationConfig, \ LogitsProcessorList, StoppingCriteriaList from bigdl.llm.utils.common import invalidInputError +from transformers.modeling_outputs import CausalLMOutputWithPast # patch GenerationMixin.generate from transformers import GenerationMixin @@ -533,15 +534,16 @@ def speculative_generate(self, past_key_values = None past_key_values_storage = [] - _enable_ipex = os.getenv("BIGDL_OPT_IPEX") - _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") + from bigdl.llm.transformers.convert import get_enable_ipex + _enable_ipex = get_enable_ipex() + if _enable_ipex: if not ((self.config.model_type == 'baichuan') or ('llama' in self.config.model_type) or ("mistral" in self.config.model_type) or ("qwen" in self.config.model_type) or ("chatglm" in self.config.model_type)): - invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \ + invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \ Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.") if "chatglm" in self.config.model_type: global query_group_size @@ -579,6 +581,11 @@ def speculative_generate(self, attention_mask=attention_mask, return_dict=True, use_cache=True) + if _enable_ipex: + output = CausalLMOutputWithPast( + logits=output[0], + past_key_values=output[1], + ) logits = output['logits'] logits = logits[:, -1:] logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :]) @@ -602,7 +609,7 @@ def speculative_generate(self, draft_current_input_ids = current_input_ids # Target model KV cache to draft model - if self.device.type == 'cpu': + if self.device.type == 'cpu' and (not _enable_ipex): # init past_key_values_storage and assign initial fp32 value if step == 1: past_key_values_storage = \ @@ -652,7 +659,57 @@ def speculative_generate(self, past_length = draft_past_key_values[0][0].size(2) position_ids = torch.Tensor([[past_length]]).long().to(self.device) forward_args["position_ids"] = position_ids - draft_output = draft_model(**forward_args) + + if _enable_ipex: + if any(keyword in self.config.model_type + for keyword in ["llama", "chatglm", "mistral"]): + past_key_value_len = draft_past_key_values[0][0].shape[2] + position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long() + position_ids = position_ids[:, :-draft_current_input_ids.size(0)] + draft_output = draft_model.trace_graph( + input_ids=draft_current_input_ids, + attention_mask=draft_attention_mask, + position_ids=position_ids, + past_key_values=draft_past_key_values, + ) + elif self.config.model_type == "baichuan": + if self.config.hidden_size == 4096: + past_key_value_len = draft_past_key_values[0][0].shape[2] + seq_len = draft_current_input_ids.shape[1] + seq_len_with_past = seq_len + past_key_value_len + position_ids = torch.arange(past_key_value_len, + seq_len_with_past, + dtype=torch.long, + device=draft_current_input_ids.device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_len) + draft_output = draft_model.trace_graph( + input_ids=draft_current_input_ids, + attention_mask=draft_attention_mask, + position_ids=position_ids, + past_key_values=draft_past_key_values, + ) + elif self.config.hidden_size == 5120: + draft_output = draft_model.trace_graph( + input_ids=draft_current_input_ids, + attention_mask=draft_attention_mask, + past_key_values=draft_past_key_values, + ) + elif "qwen" in self.config.model_type: + draft_output = draft_model.trace_graph( + input_ids=draft_current_input_ids, + attention_mask=draft_attention_mask, + past_key_values=draft_past_key_values, + ) + else: + invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \ + Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.") + + draft_output = CausalLMOutputWithPast( + logits=draft_output[0], + past_key_values=draft_output[1], + ) + else: + draft_output = draft_model(**forward_args) temp_input_ids = torch.cat((input_ids, generate_ids[:, :step], draft_generate_ids[:, 1:step_draft+1]), dim=-1) logits = draft_output['logits'] @@ -848,7 +905,6 @@ def speculative_generate(self, # Clean up target model KV cache if max_of_max_matched != max_matched: output_ids = output_ids[:, :max_matched] - # For Qwen if _enable_ipex: cur_len = past_key_values[0][0].size(1) delta = max_of_max_matched - max_matched @@ -890,7 +946,7 @@ def speculative_generate(self, ] # Each iter assign new_matched kv_cache to past_key_values1 - if self.device.type == 'cpu': + if self.device.type == 'cpu' and (not _enable_ipex): _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage, original_draft_past_key_values, _enable_ipex)