[NPU] dump prefill IR for further C++ solution (#12402)

* save prefill ir

* fix

* shorten convert time

* fix

* fix

* fix

* fix

* fix style

* dump config.json

* meet review

* small fix
This commit is contained in:
Ruonan Wang 2024-11-19 23:20:05 -08:00 committed by GitHub
parent 1bfcbc0640
commit 54c62feb74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 93 additions and 39 deletions

View file

@ -134,6 +134,8 @@ class _BaseAutoModelClass:
mixed_precision = kwargs.pop('mixed_precision', False)
quantization_group_size = kwargs.pop("quantization_group_size", 0)
mock_device = kwargs.pop('device', None) # For mock on CPU
compile_full_model = kwargs.pop('compile_full_model', False)
save_directory = kwargs.pop('save_directory', None)
invalidInputError(
quantization_group_size in [0, 32, 64, 128],
@ -199,7 +201,9 @@ class _BaseAutoModelClass:
"max_prompt_len": max_prompt_len,
"inter_pp": inter_pp,
"intra_pp": intra_pp,
"transpose_value_cache": transpose_value_cache
"transpose_value_cache": transpose_value_cache,
"compile_full_model": compile_full_model,
"save_directory": save_directory,
}
model = cls.optimize_npu_model(*args, **optimize_kwargs)
else:
@ -237,6 +241,8 @@ class _BaseAutoModelClass:
inter_pp = kwargs.pop("inter_pp", None)
intra_pp = kwargs.pop("intra_pp", None)
transpose_value_cache = kwargs.pop("transpose_value_cache", True)
compile_full_model = kwargs.pop('compile_full_model', False)
save_directory = kwargs.pop('save_directory', None)
if hasattr(model, "llm"):
llm = model.llm
@ -273,7 +279,9 @@ class _BaseAutoModelClass:
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size)
group_size=quantization_group_size,
compile_full_model=compile_full_model,
save_directory=save_directory)
model.save_low_bit = types.MethodType(save_low_bit, model)
return model

View file

@ -23,16 +23,19 @@ from intel_npu_acceleration_library.backend.factory import NNFactory
import numpy as np
def update_names_of_IR_and_export_blob(model, model_name, dir):
def update_names_of_IR_and_export_blob(model, model_name, dir, compile_blob=True, keep_ir=True):
xml_path = os.path.join(dir, model_name + ".xml")
bin_path = os.path.join(dir, model_name + ".bin")
model.save(xml_path)
new_ir_path = os.path.join(dir, model_name + "_new.xml")
new_bin_path = os.path.join(dir, model_name + "_new.bin")
blob_path = os.path.join(dir, model_name + ".blob")
core = Core()
core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
"compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
model = core.read_model(xml_path)
inputs = model.inputs
for idx, input in enumerate(inputs):
@ -46,14 +49,17 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):
if new_ir_path is not None:
serialize(model, new_ir_path)
if blob_path is not None:
if compile_blob:
compiledModel = core.compile_model(model, device_name="NPU")
model_stream = compiledModel.export_model()
with open(blob_path, 'wb') as f:
f.write(model_stream)
os.remove(xml_path)
os.remove(new_ir_path)
if not keep_ir:
os.remove(new_ir_path)
os.remove(new_bin_path)
return blob_path
@ -123,6 +129,7 @@ class LLMEmbedding(NNFactory):
embedding_weight,
padding_idx,
dtype, # fp16
input_length: int = 1,
device: str = "NPU",
):
super().__init__(False, device)
@ -133,7 +140,7 @@ class LLMEmbedding(NNFactory):
# define input
weight = self.constant(embedding_weight)
input = self.parameter((1, 1), dtype=np.int32)
input = self.parameter((1, input_length), dtype=np.int32)
if padding_idx == -1:
padding_idx += vocab_size

View file

@ -192,7 +192,9 @@ def convert_llm(model: torch.nn.Module,
kv_len: int,
max_prompt_len: int,
transpose_value_cache: bool,
group_size: int):
group_size: int,
compile_full_model: bool=False,
save_directory: str=None):
# whether to set layernorm weight as const
layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
if group_size == 0:
@ -329,12 +331,16 @@ def convert_llm(model: torch.nn.Module,
elif model.config.model_type == "qwen2":
layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "0") == "1"
with tempfile.TemporaryDirectory() as temp_dir:
if save_directory is not None:
temp_dir = save_directory
os.mkdir(temp_dir)
weight_dir = os.path.join(temp_dir, "model_weights")
os.mkdir(weight_dir)
layer_num = len(model.model.layers)
from .qwen import convert_qwen_layer, convert_lm_head_and_embedding
first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
temp_dir, weight_dir)
temp_dir, weight_dir,
compile_full_model)
param_list = []
for layer_idx in range(0, layer_num):
@ -344,6 +350,11 @@ def convert_llm(model: torch.nn.Module,
with Pool() as pool:
result = pool.starmap(convert_qwen_layer, param_list)
if compile_full_model:
convert_qwen_layer(model, 0, n_splits_linear, n_splits_down_proj,
temp_dir, weight_dir, transpose_value_cache, max_prompt_len,
group_size, layernorm_const, "prefill")
# Prefill Runner
from ipex_llm.transformers.npu_models.convert_mp import convert_qwen
convert_qwen(model,
@ -360,6 +371,16 @@ def convert_llm(model: torch.nn.Module,
model.transpose_value_cache = transpose_value_cache
model.vocab_size = model.config.vocab_size
if save_directory is not None:
update_dict = {"kv_len": kv_len, "num_head": model.num_head,
"head_dim": model.head_dim,
"transpose_value_cache": transpose_value_cache,
"max_prompt_len": max_prompt_len,
"layernorm_const": layernorm_const,
"group_size": group_size}
model.config.update(update_dict)
model.config.save_pretrained(save_directory)
try:
res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
model.vocab_size, weight_dir, "model",

View file

@ -22,7 +22,8 @@ from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLML
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
compile_full_model=False):
num_heads = model.model.layers[0].self_attn.num_heads
head_dim = model.model.layers[0].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps
@ -57,7 +58,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
vocab_size=vocab_size,
n_splits=n_splits_linear
)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, f"lm_head",
temp_dir, True, True)
# save weights bins files
if not isinstance(lm_head, SlicedLMHead):
@ -78,15 +81,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
padding_idx=model.config.pad_token_id,
dtype=np.float16,
input_length=1,
)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
temp_dir, True, keep_ir=True)
if compile_full_model:
bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
return first_blob_path, last_blob_path
def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const):
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
@ -123,10 +130,20 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
else: # FP16 Linear
np_dtype = np.float16
if mode == "decode":
input_len = 1
decoder_name = f"decoder_layer_{layer_idx}"
compile = True
keep_ir = True
else:
input_len = kv_len
decoder_name = "decoder_layer_prefill"
compile = False
keep_ir = True
single_decoder = LowBitQwenMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
[1, input_len, num_heads * head_dim],
input_layernorm_weights=None,
post_attn_layernorm_weights=None,
q_biases=None,
k_biases=None,
v_biases=None,
@ -138,7 +155,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
mode="decode",
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
@ -146,29 +163,30 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
group_size=group_size
)
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
f"decoder_layer_{layer_idx}",
temp_dir)
decoder_name,
temp_dir, compile, keep_ir)
# 0, 1, 2 are input_embed/attention_mask/position_id
if layernorm_const:
st_idx = 3
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 5
q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
q_bias.data.numpy().tofile(q_bias_bin_file)
k_bias.data.numpy().tofile(k_bias_bin_file)
v_bias.data.numpy().tofile(v_bias_bin_file)
# 6, 7 are past k/v
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
if mode == "decode":
if layernorm_const:
st_idx = 3
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 5
q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
q_bias.data.numpy().tofile(q_bias_bin_file)
k_bias.data.numpy().tofile(k_bias_bin_file)
v_bias.data.numpy().tofile(v_bias_bin_file)
# 6, 7 are past k/v
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
del single_decoder