diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 850f9009..2c96b06e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -134,6 +134,8 @@ class _BaseAutoModelClass:
         mixed_precision = kwargs.pop('mixed_precision', False)
         quantization_group_size = kwargs.pop("quantization_group_size", 0)
         mock_device = kwargs.pop('device', None)  # For mock on CPU
+        compile_full_model = kwargs.pop('compile_full_model', False)
+        save_directory = kwargs.pop('save_directory', None)
 
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
@@ -199,7 +201,9 @@ class _BaseAutoModelClass:
                 "max_prompt_len": max_prompt_len,
                 "inter_pp": inter_pp,
                 "intra_pp": intra_pp,
-                "transpose_value_cache": transpose_value_cache
+                "transpose_value_cache": transpose_value_cache,
+                "compile_full_model": compile_full_model,
+                "save_directory": save_directory,
             }
             model = cls.optimize_npu_model(*args, **optimize_kwargs)
         else:
@@ -237,6 +241,8 @@ class _BaseAutoModelClass:
         inter_pp = kwargs.pop("inter_pp", None)
         intra_pp = kwargs.pop("intra_pp", None)
         transpose_value_cache = kwargs.pop("transpose_value_cache", True)
+        compile_full_model = kwargs.pop('compile_full_model', False)
+        save_directory = kwargs.pop('save_directory', None)
 
         if hasattr(model, "llm"):
             llm = model.llm
@@ -273,7 +279,9 @@ class _BaseAutoModelClass:
                         kv_len=max_context_len,
                         max_prompt_len=max_prompt_len,
                         transpose_value_cache=transpose_value_cache,
-                        group_size=quantization_group_size)
+                        group_size=quantization_group_size,
+                        compile_full_model=compile_full_model,
+                        save_directory=save_directory)
 
         model.save_low_bit = types.MethodType(save_low_bit, model)
         return model
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
index 3a9f81e2..a5aa5791 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
@@ -23,16 +23,19 @@ from intel_npu_acceleration_library.backend.factory import NNFactory
 import numpy as np
 
 
-def update_names_of_IR_and_export_blob(model, model_name, dir):
+def update_names_of_IR_and_export_blob(model, model_name, dir, compile_blob=True, keep_ir=True):
     xml_path = os.path.join(dir, model_name + ".xml")
+    bin_path = os.path.join(dir, model_name + ".bin")
     model.save(xml_path)
     new_ir_path = os.path.join(dir, model_name + "_new.xml")
+    new_bin_path = os.path.join(dir, model_name + "_new.bin")
     blob_path = os.path.join(dir, model_name + ".blob")
 
     core = Core()
     core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
                               "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
     core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
+
     model = core.read_model(xml_path)
     inputs = model.inputs
     for idx, input in enumerate(inputs):
@@ -46,14 +49,17 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):
     if new_ir_path is not None:
         serialize(model, new_ir_path)
 
-    if blob_path is not None:
+    if compile_blob:
         compiledModel = core.compile_model(model, device_name="NPU")
         model_stream = compiledModel.export_model()
         with open(blob_path, 'wb') as f:
             f.write(model_stream)
 
     os.remove(xml_path)
-    os.remove(new_ir_path)
+
+    if not keep_ir:
+        os.remove(new_ir_path)
+        os.remove(new_bin_path)
 
     return blob_path
 
@@ -123,6 +129,7 @@ class LLMEmbedding(NNFactory):
         embedding_weight,
         padding_idx,
         dtype,  # fp16
+        input_length: int = 1,
         device: str = "NPU",
     ):
         super().__init__(False, device)
@@ -133,7 +140,7 @@ class LLMEmbedding(NNFactory):
 
         # define input
         weight = self.constant(embedding_weight)
-        input = self.parameter((1, 1), dtype=np.int32)
+        input = self.parameter((1, input_length), dtype=np.int32)
 
         if padding_idx == -1:
             padding_idx += vocab_size
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 47f6732a..343d79ee 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -192,7 +192,9 @@ def convert_llm(model: torch.nn.Module,
                 kv_len: int,
                 max_prompt_len: int,
                 transpose_value_cache: bool,
-                group_size: int):
+                group_size: int,
+                compile_full_model: bool=False,
+                save_directory: str=None):
     # whether to set layernorm weight as const
     layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
     if group_size == 0:
@@ -329,12 +331,16 @@ def convert_llm(model: torch.nn.Module,
     elif model.config.model_type == "qwen2":
         layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "0") == "1"
         with tempfile.TemporaryDirectory() as temp_dir:
+            if save_directory is not None:
+                temp_dir = save_directory
+                os.mkdir(temp_dir)
             weight_dir = os.path.join(temp_dir, "model_weights")
             os.mkdir(weight_dir)
             layer_num = len(model.model.layers)
             from .qwen import convert_qwen_layer, convert_lm_head_and_embedding
             first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
-                                                                            temp_dir, weight_dir)
+                                                                            temp_dir, weight_dir,
+                                                                            compile_full_model)
 
             param_list = []
             for layer_idx in range(0, layer_num):
@@ -344,6 +350,11 @@ def convert_llm(model: torch.nn.Module,
             with Pool() as pool:
                 result = pool.starmap(convert_qwen_layer, param_list)
 
+            if compile_full_model:
+                convert_qwen_layer(model, 0, n_splits_linear, n_splits_down_proj,
+                                   temp_dir, weight_dir, transpose_value_cache, max_prompt_len,
+                                   group_size, layernorm_const, "prefill")
+
             # Prefill Runner
             from ipex_llm.transformers.npu_models.convert_mp import convert_qwen
             convert_qwen(model,
@@ -360,6 +371,16 @@ def convert_llm(model: torch.nn.Module,
             model.transpose_value_cache = transpose_value_cache
             model.vocab_size = model.config.vocab_size
 
+            if save_directory is not None:
+                update_dict = {"kv_len": kv_len, "num_head": model.num_head,
+                               "head_dim": model.head_dim,
+                               "transpose_value_cache": transpose_value_cache,
+                               "max_prompt_len": max_prompt_len,
+                               "layernorm_const": layernorm_const,
+                               "group_size": group_size}
+                model.config.update(update_dict)
+                model.config.save_pretrained(save_directory)
+
             try:
                 res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
index 645ad830..8d3966b4 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -22,7 +22,8 @@ from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 
 
-def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
+def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
+                                  compile_full_model=False):
     num_heads = model.model.layers[0].self_attn.num_heads
     head_dim = model.model.layers[0].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
@@ -57,7 +58,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         vocab_size=vocab_size,
         n_splits=n_splits_linear
     )
-    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
+
+    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, f"lm_head",
+                                                        temp_dir, True, True)
 
     # save weights bins files
     if not isinstance(lm_head, SlicedLMHead):
@@ -78,15 +81,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
         padding_idx=model.config.pad_token_id,
         dtype=np.float16,
+        input_length=1,
     )
-    first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                         temp_dir)
+    first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
+                                                         temp_dir, True, keep_ir=True)
+    if compile_full_model:
+        bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+        embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
     return first_blob_path, last_blob_path
 
 
 def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                       layernorm_const):
+                       layernorm_const, mode="decode"):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -123,10 +130,20 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     else:  # FP16 Linear
         np_dtype = np.float16
 
+    if mode == "decode":
+        input_len = 1
+        decoder_name = f"decoder_layer_{layer_idx}"
+        compile = True
+        keep_ir = True
+    else:
+        input_len = kv_len
+        decoder_name = "decoder_layer_prefill"
+        compile = False
+        keep_ir = True
+
     single_decoder = LowBitQwenMultiDecoderlayer(
-        [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
-        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
+        [1, input_len, num_heads * head_dim],
+        input_layernorm_weights=None,
+        post_attn_layernorm_weights=None,
         q_biases=None,
         k_biases=None,
         v_biases=None,
@@ -138,7 +155,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         max_seq_len=kv_len,
         rms_norm_eps=rms_norm_eps,
         intermediate_size=intermediate_size,
-        mode="decode",
+        mode=mode,
         transpose_value=transpose_value_cache,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
@@ -146,29 +163,30 @@
         group_size=group_size
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
-                                                        f"decoder_layer_{layer_idx}",
-                                                        temp_dir)
+                                                        decoder_name,
+                                                        temp_dir, compile, keep_ir)
     # 0, 1, 2 are input_embed/attention_mask/position_id
-    if layernorm_const:
-        st_idx = 3
-    else:
-        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-        st_idx = 5
-    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
-    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
-    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
-    q_bias.data.numpy().tofile(q_bias_bin_file)
-    k_bias.data.numpy().tofile(k_bias_bin_file)
-    v_bias.data.numpy().tofile(v_bias_bin_file)
-    # 6, 7 are past k/v
-    for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
-        weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
-        scale.numpy().tofile(bin_file)
+    if mode == "decode":
+        if layernorm_const:
+            st_idx = 3
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 5
+        q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
+        k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
+        v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
+        q_bias.data.numpy().tofile(q_bias_bin_file)
+        k_bias.data.numpy().tofile(k_bias_bin_file)
+        v_bias.data.numpy().tofile(v_bias_bin_file)
+        # 6, 7 are past k/v
+        for idx, (weight, scale) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
+            scale.numpy().tofile(bin_file)
 
     del single_decoder
 
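Below is a usage sketch, not part of the patch, showing how the two new keyword arguments are expected to be passed from user code. The checkpoint path and length limits are placeholders, and the companion flags (optimize_model, pipeline, load_in_low_bit) are assumed to follow the usual ipex-llm NPU pipeline examples; only compile_full_model and save_directory are the arguments added by this patch, while max_context_len, max_prompt_len and transpose_value_cache already exist in the patched code.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",        # placeholder checkpoint; any qwen2 model handled by convert_llm()
    optimize_model=True,             # assumed companion flags that route into the
    pipeline=True,                   # NPU pipeline path patched above
    load_in_low_bit="sym_int4",
    max_context_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
    compile_full_model=True,         # new: also convert a "prefill" decoder layer and keep its IR
    save_directory="./qwen2-npu-converted",  # new: persist converted weights/IR and updated config here
)                                    # note: save_directory must not already exist (convert_llm calls os.mkdir)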