LLM: Enable BigDL IPEX optimization for int4 (#10319)
Enable BigDL IPEX optimization for int4
parent 5d7e044dbc
commit 0ded0b4b13

5 changed files with 276 additions and 36 deletions
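
For orientation, the sketch below shows how the int4 IPEX path added in this commit is exercised end to end; it mirrors the new run_bigdl_ipex_int4 benchmark further down. The model id is a placeholder, and unlike the benchmark it does not set model.config.token_latency, so generate() is assumed to return only the output ids.

import os
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

# Opt into the IPEX path; the flag is read back via get_enable_ipex() below.
os.environ["BIGDL_OPT_IPEX"] = "true"

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id
# load_in_low_bit='sym_int4' plus torchscript=True routes the model through the
# new weight-only int4 quantization branch in _ipex_optimize_model.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit='sym_int4',
                                             torch_dtype='auto',
                                             trust_remote_code=True,
                                             use_cache=True,
                                             torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
with torch.inference_mode(), torch.autocast("cpu"):
    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids)[0])
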
@@ -18,6 +18,8 @@ test_api:
   - "optimize_model"
   - "pytorch_autocast_bf16"
   # - "transformer_autocast_bf16"
+  # - "bigdl_ipex_bf16"
+  # - "bigdl_ipex_int4"
   # - "ipex_fp16_gpu" # on Intel GPU
   # - "bigdl_fp16_gpu" # on Intel GPU
   # - "transformer_int4_gpu"  # on Intel GPU
@@ -92,6 +92,10 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
         result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
     elif test_api == 'transformer_autocast_bf16':
         result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
+    elif test_api == 'bigdl_ipex_bf16':
+        result = run_bigdl_ipex_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
+    elif test_api == 'bigdl_ipex_int4':
+        result = run_bigdl_ipex_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
     elif test_api == 'deepspeed_optimize_model_gpu':
         result = run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
 
@@ -1079,6 +1083,148 @@ def run_transformer_autocast_bf16( repo_id,
                                           actual_in_len, actual_out_len, load_time])
     return result
 
 
+def run_bigdl_ipex_bf16(repo_id,
+                    local_model_hub,
+                    in_out_pairs,
+                    warm_up,
+                    num_trials,
+                    num_beams,
+                    batch_size):
+    from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+    from transformers import AutoTokenizer, LlamaTokenizer
+
+    os.environ["BIGDL_OPT_IPEX"] = "true"
+
+    model_path = get_model_path(repo_id, local_model_hub)
+    # Load model in bf16,
+    # which convert the relevant layers in the model into BF16 format
+    st = time.perf_counter()
+    if repo_id in CHATGLM_IDS:
+        model = AutoModel.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
+                                          use_cache=True, torchscript=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    elif repo_id in LLAMA_IDS:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
+                                                     use_cache=True, torchscript=True)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
+                                                     use_cache=True, torchscript=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if not hasattr(model.config, "token_latency"):
+        model.config.token_latency = True
+    end = time.perf_counter()
+    load_time = end - st
+    print(">> loading of model costs {}s".format(load_time))
+
+    result = {}
+    with torch.inference_mode(), torch.autocast("cpu"):
+        for in_out in in_out_pairs:
+            in_out_len = in_out.split("-")
+            in_len = int(in_out_len[0])
+            out_len = int(in_out_len[1])
+            # As different tokenizer has different encodings,
+            # in_len.txt maybe shorter than we need,
+            # use much longer context to make sure input length
+            test_length = min(in_len*2, 8192)
+            while test_length not in [32, 256, 1024, 2048, 8192]:
+                test_length = test_length * 2
+            input_str = open(f"prompt/{test_length}.txt", 'r').read()
+            # As different tokenizer has different encodings,
+            # slice the input_ids to ensure the prompt length is required length.
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            input_ids = input_ids[:, :in_len]
+            true_str = tokenizer.batch_decode(input_ids)[0]
+            input_list = [true_str] * batch_size
+            input_ids = tokenizer(input_list, return_tensors="pt").input_ids
+            actual_in_len = input_ids.shape[1]
+            result[in_out] = []
+            for i in range(num_trials + warm_up):
+                st = time.perf_counter()
+                output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                            num_beams=num_beams)
+                end = time.perf_counter()
+                print("model generate cost: " + str(end - st))
+                output = tokenizer.batch_decode(output_ids)
+                print(output[0])
+                actual_out_len = output_ids.shape[1] - actual_in_len
+                if i >= warm_up:
+                    result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
+                                          actual_in_len, actual_out_len, load_time])
+    return result
+
+
+def run_bigdl_ipex_int4(repo_id,
+                    local_model_hub,
+                    in_out_pairs,
+                    warm_up,
+                    num_trials,
+                    num_beams,
+                    batch_size):
+    from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+    from transformers import AutoTokenizer, LlamaTokenizer
+
+    os.environ["BIGDL_OPT_IPEX"] = "true"
+
+    model_path = get_model_path(repo_id, local_model_hub)
+
+    st = time.perf_counter()
+    if repo_id in CHATGLM_IDS:
+        model = AutoModel.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
+                                          use_cache=True, torchscript=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    elif repo_id in LLAMA_IDS:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
+                                                     use_cache=True, torchscript=True)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
+                                                     use_cache=True, torchscript=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if not hasattr(model.config, "token_latency"):
+        model.config.token_latency = True
+    end = time.perf_counter()
+    load_time = end - st
+    print(">> loading of model costs {}s".format(load_time))
+
+    result = {}
+    with torch.inference_mode(), torch.autocast("cpu"):
+        for in_out in in_out_pairs:
+            in_out_len = in_out.split("-")
+            in_len = int(in_out_len[0])
+            out_len = int(in_out_len[1])
+            # As different tokenizer has different encodings,
+            # in_len.txt maybe shorter than we need,
+            # use much longer context to make sure input length
+            test_length = min(in_len*2, 8192)
+            while test_length not in [32, 256, 1024, 2048, 8192]:
+                test_length = test_length * 2
+            input_str = open(f"prompt/{test_length}.txt", 'r').read()
+            # As different tokenizer has different encodings,
+            # slice the input_ids to ensure the prompt length is required length.
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            input_ids = input_ids[:, :in_len]
+            true_str = tokenizer.batch_decode(input_ids)[0]
+            input_list = [true_str] * batch_size
+            input_ids = tokenizer(input_list, return_tensors="pt").input_ids
+            actual_in_len = input_ids.shape[1]
+            result[in_out] = []
+            for i in range(num_trials + warm_up):
+                st = time.perf_counter()
+                output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
+                                            num_beams=num_beams)
+                end = time.perf_counter()
+                print("model generate cost: " + str(end - st))
+                output = tokenizer.batch_decode(output_ids)
+                print(output[0])
+                actual_out_len = output_ids.shape[1] - actual_in_len
+                if i >= warm_up:
+                    result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
+                                          actual_in_len, actual_out_len, load_time])
+    return result
+
+
 def run_deepspeed_optimize_model_gpu(repo_id,
                                      local_model_hub,
                                      in_out_pairs,
@@ -1192,6 +1338,7 @@ def run_deepspeed_optimize_model_gpu(repo_id,
     torch.xpu.empty_cache()
     return result
 
+
 if __name__ == '__main__':
     from omegaconf import OmegaConf
     conf = OmegaConf.load(f'{current_dir}/config.yaml')
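
A note on the result dicts the two new benchmark functions return: each recorded trial row holds [first-token latency, mean next-token latency, 0, actual input length, actual output length, load time], which appears to match the layout used by the other run_* helpers. A hypothetical post-processing helper (not part of this commit) might read it like this:

# Hypothetical consumer of the dict returned by run_bigdl_ipex_bf16 /
# run_bigdl_ipex_int4; the row layout is taken from the hunk above.
def summarize(result):
    for in_out, rows in result.items():
        first = sum(r[0] for r in rows) / len(rows)
        rest = sum(r[1] for r in rows) / len(rows)
        print(f"{in_out}: 1st token {first:.3f}s, next tokens {rest:.3f}s each")
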
@@ -611,13 +611,12 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
 
     # using ipex optimizer before changing to bigdl linear
-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
-    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
-    if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
+    _enable_ipex = get_enable_ipex()
+
+    if device == "cpu":
         logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
     if _enable_ipex:
-        model = _optimize_ipex(model)
+        model = _optimize_ipex(model, qtype)
         return model
 
     if optimize_model:
@@ -686,12 +685,19 @@ def replace_func(m, target_m, func_name, new_func):
         replace_func(sub_m, target_m, func_name, new_func)
 
 
-def _optimize_ipex(model):
+def get_enable_ipex():
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    return _enable_ipex
+
+
+def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
+    import intel_extension_for_pytorch as ipex
     from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_model, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35, convert_function, GLM_get_masks
+        _llama_model_forward_4_35, convert_function, GLM_get_masks,
     )
 
     model = model_convert_reference(model)
@@ -718,7 +724,7 @@ def _optimize_ipex(model):
         # baichuan2
         rms_classes.append(type(model.model.layers[0].input_layernorm))
 
-    _ipex_optimize_model(model, rms_classes)
+    model = _ipex_optimize_model(model, rms_classes, qtype)
     return _ipex_jit(model)
 
 
@@ -44,10 +44,14 @@ from intel_extension_for_pytorch.transformers.optimize import (
 from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
     _enable_tpp,
     _using_tpp,
+    _disable_tpp
 )
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import get_enable_ipex
+import os
 
 
-def _ipex_optimize_rmsnorm(_model, supported_classes):
+def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
     from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
     for supported_class in supported_classes:
         lowering_class_cpu(
@@ -55,12 +59,12 @@ def _ipex_optimize_rmsnorm(_model, supported_classes):
             supported_class,
             _IPEXRMSNorm,
             _model.config,
-            tpp=False,
-            woq=False,
+            tpp=is_tpp,
+            woq=is_woq,
         )
 
 
-def _ipex_optimize_decoder(model):
+def _ipex_optimize_decoder(model, is_tpp=False, is_woq=False):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
@@ -73,12 +77,12 @@ def _ipex_optimize_decoder(model):
             supported_mlp_class,
             _IPEXDecoderLayerCPU,
             model.config,
-            tpp=True if _using_tpp() else False,
-            woq=False,
+            tpp=is_tpp,
+            woq=is_woq,
         )
 
 
-def _ipex_optimize_attention(model):
+def _ipex_optimize_attention(model, is_tpp=False, is_woq=False):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -91,18 +95,47 @@ def _ipex_optimize_attention(model):
             supported_mha_class,
             _IPEXAttentionCPU,
             model.config,
-            tpp=True if _using_tpp() else False,
-            woq=False,
+            tpp=is_tpp,
+            woq=is_woq,
         )
 
 
-def _ipex_optimize_model(model, rms_classes):
-    _enable_tpp()
+def _ipex_optimize_model(model, rms_classes, qtype):
     import intel_extension_for_pytorch as ipex
-    ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
-    _ipex_optimize_rmsnorm(model, rms_classes)
-    _ipex_optimize_attention(model)
-    _ipex_optimize_decoder(model)
+    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
+    from intel_extension_for_pytorch.transformers.optimize import ipex_quantization_flow
+
+    is_woq = False
+    is_quantization = False
+    _disable_tpp()
+    if qtype == ggml_tensor_qtype["bf16"]:
+        _enable_tpp()
+        model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
+    elif qtype == ggml_tensor_qtype["sym_int4"]:
+        is_quantization = True
+        is_woq = True
+        act_quant_mode_dict = {
+            "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
+            "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
+            "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
+            "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
+        }
+        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+            weight_dtype=torch.quint4x2,  # INT4
+            lowp_mode=ipex.quantization.WoqLowpMode.INT8,
+            act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"],
+            group_size=-1,
+        )
+        model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None)
+
+    is_tpp = _using_tpp()
+
+    _ipex_optimize_rmsnorm(model, rms_classes, is_tpp=is_tpp, is_woq=is_woq)
+    _ipex_optimize_attention(model, is_tpp=is_tpp, is_woq=is_woq)
+    _ipex_optimize_decoder(model, is_tpp=is_tpp, is_woq=is_woq)
+
+    model.register_forward_hook(output_hook, with_kwargs=True)
+    return model
 
 
 def _ipex_jit(model):
@@ -152,9 +185,7 @@ def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
         else:  # discrete kv cache
             past_length = past_key_values[0][0].shape[-2]
 
-    import os
-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    _enable_ipex = get_enable_ipex()
     # always call for jit
     if past_length or _enable_ipex:
         full_attention_mask = torch.cat(
@@ -191,9 +222,7 @@ def _make_causal_mask(
     mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
 
-    import os
-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    _enable_ipex = get_enable_ipex()
     if _enable_ipex or past_key_values_length > 0:
         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa
 
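
The core of the int4 enablement is the sym_int4 branch of _ipex_optimize_model above: instead of the bf16 TPP path, it builds an IPEX weight-only-quantization (WOQ) recipe and runs ipex_quantization_flow. Pulled out on its own, the recipe looks like this (a sketch restating the hunk above, not additional code in the commit):

import torch
import intel_extension_for_pytorch as ipex

# int4 weights (quint4x2), int8 low-precision compute, per-IC-block activation
# quantization, no weight grouping -- the combination selected for sym_int4.
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.quint4x2,
    lowp_mode=ipex.quantization.WoqLowpMode.INT8,
    act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
    group_size=-1,
)
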
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
 from transformers import top_k_top_p_filtering, GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
 from bigdl.llm.utils.common import invalidInputError
+from transformers.modeling_outputs import CausalLMOutputWithPast
 
 # patch GenerationMixin.generate
 from transformers import GenerationMixin
@@ -533,15 +534,16 @@ def speculative_generate(self,
     past_key_values = None
     past_key_values_storage = []
 
-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    from bigdl.llm.transformers.convert import get_enable_ipex
+    _enable_ipex = get_enable_ipex()
+
     if _enable_ipex:
         if not ((self.config.model_type == 'baichuan') or
                 ('llama' in self.config.model_type) or
                 ("mistral" in self.config.model_type) or
                 ("qwen" in self.config.model_type) or
                 ("chatglm" in self.config.model_type)):
-            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
+            invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
                               Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
         if "chatglm" in self.config.model_type:
             global query_group_size
@@ -579,6 +581,11 @@ def speculative_generate(self,
                           attention_mask=attention_mask,
                           return_dict=True,
                           use_cache=True)
+            if _enable_ipex:
+                output = CausalLMOutputWithPast(
+                    logits=output[0],
+                    past_key_values=output[1],
+                )
             logits = output['logits']
             logits = logits[:, -1:]
             logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
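
The pattern in the hunk above (and again for the draft model below) exists because the TorchScript-traced model returns a plain (logits, past_key_values) tuple rather than a ModelOutput, as the output[0]/output[1] indexing in the diff shows; re-wrapping it keeps the dict-style access used by the rest of speculative_generate working. A minimal illustration with dummy values:

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

raw = (torch.zeros(1, 1, 32000), ())   # dummy (logits, past_key_values) tuple
output = CausalLMOutputWithPast(logits=raw[0], past_key_values=raw[1])
assert output['logits'] is raw[0]      # dict-style access works again
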
@@ -602,7 +609,7 @@ def speculative_generate(self,
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
 
-            if self.device.type == 'cpu':
+            if self.device.type == 'cpu' and (not _enable_ipex):
                 # init past_key_values_storage and assign initial fp32 value
                 if step == 1:
                     past_key_values_storage = \
@@ -652,7 +659,57 @@ def speculative_generate(self,
                     past_length = draft_past_key_values[0][0].size(2)
                     position_ids = torch.Tensor([[past_length]]).long().to(self.device)
                     forward_args["position_ids"] = position_ids
-                draft_output = draft_model(**forward_args)
+
+                if _enable_ipex:
+                    if any(keyword in self.config.model_type
+                            for keyword in ["llama", "chatglm", "mistral"]):
+                        past_key_value_len = draft_past_key_values[0][0].shape[2]
+                        position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
+                        position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
+                        draft_output = draft_model.trace_graph(
+                            input_ids=draft_current_input_ids,
+                            attention_mask=draft_attention_mask,
+                            position_ids=position_ids,
+                            past_key_values=draft_past_key_values,
+                        )
+                    elif self.config.model_type == "baichuan":
+                        if self.config.hidden_size == 4096:
+                            past_key_value_len = draft_past_key_values[0][0].shape[2]
+                            seq_len = draft_current_input_ids.shape[1]
+                            seq_len_with_past = seq_len + past_key_value_len
+                            position_ids = torch.arange(past_key_value_len,
+                                                        seq_len_with_past,
+                                                        dtype=torch.long,
+                                                        device=draft_current_input_ids.device)
+                            position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+                            draft_output = draft_model.trace_graph(
+                                input_ids=draft_current_input_ids,
+                                attention_mask=draft_attention_mask,
+                                position_ids=position_ids,
+                                past_key_values=draft_past_key_values,
+                            )
+                        elif self.config.hidden_size == 5120:
+                            draft_output = draft_model.trace_graph(
+                                input_ids=draft_current_input_ids,
+                                attention_mask=draft_attention_mask,
+                                past_key_values=draft_past_key_values,
+                            )
+                    elif "qwen" in self.config.model_type:
+                        draft_output = draft_model.trace_graph(
+                            input_ids=draft_current_input_ids,
+                            attention_mask=draft_attention_mask,
+                            past_key_values=draft_past_key_values,
+                        )
+                    else:
+                        invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
+                              Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
+
+                    draft_output = CausalLMOutputWithPast(
+                        logits=draft_output[0],
+                        past_key_values=draft_output[1],
+                    )
+                else:
+                    draft_output = draft_model(**forward_args)
                 temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
                                             draft_generate_ids[:, 1:step_draft+1]), dim=-1)
                 logits = draft_output['logits']
@@ -848,7 +905,6 @@ def speculative_generate(self,
             # Clean up target model KV cache
             if max_of_max_matched != max_matched:
                 output_ids = output_ids[:, :max_matched]
-                # For Qwen
                 if _enable_ipex:
                     cur_len = past_key_values[0][0].size(1)
                     delta = max_of_max_matched - max_matched
@@ -890,7 +946,7 @@ def speculative_generate(self,
                         ]
 
             # Each iter assign new_matched kv_cache to past_key_values1
-            if self.device.type == 'cpu':
+            if self.device.type == 'cpu' and (not _enable_ipex):
                 _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
                                                     original_draft_past_key_values,
                                                     _enable_ipex)