[NPU] Support non-const parameter for decoder layers when keep_ir=True (#12789)

* support layernorm=False for decoder layers

* rename to meet review

* fix style

* rename to const_parameter

* fix rebase error

* fix rebase error
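
In short, the conversion paths now derive a single const_parameter flag (previously layernorm_const, controlled by IPEX_LLM_NPU_LAYERNORM_CONST, now by IPEX_LLM_NPU_CONST_PARAMETER), and convert_llm_for_deploy forces it to False whenever keep_ir=True so that layernorm weights (and, for Qwen, qkv biases) stay runtime inputs in the saved IR. A minimal sketch of the resulting flag resolution, using names from the diff below; the helper function itself is illustrative and not part of ipex-llm:

    import os

    def resolve_const_parameter(model_type: str, keep_ir: bool, for_deploy: bool) -> bool:
        # Illustrative helper (not an ipex-llm API); mirrors the logic in the diff below.
        # In convert_llm, qwen2 defaults to non-const parameters; other models default to const.
        default = "0" if (not for_deploy and model_type == "qwen2") else "1"
        const_parameter = os.environ.get("IPEX_LLM_NPU_CONST_PARAMETER", default) == "1"
        # New in this commit: keeping the IR in the deploy path forces non-const parameters.
        if for_deploy and keep_ir:
            const_parameter = False
        return const_parameter

    # keep_ir=True in the deploy path always yields non-const parameters:
    assert resolve_const_parameter("qwen2", keep_ir=True, for_deploy=True) is False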
Author: Ruonan Wang
Date: 2025-02-08 09:58:42 +08:00 (committed by GitHub)
Commit: e90a9ad196
Parent: 8aea5319bb
5 changed files with 65 additions and 47 deletions

View file

@@ -58,6 +58,16 @@ if __name__ == "__main__":
     model_path = args.repo_id_or_model_path
     save_dir = args.save_directory
 
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        tokenizer_json = os.path.join(model_path, "tokenizer.json")
+        dst_path = os.path.join(save_dir, "tokenizer.json")
+        shutil.copy(tokenizer_json, dst_path)
+    else:
+        tokenizer.save_pretrained(save_dir)
+
     t0 = time.perf_counter()
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  optimize_model=True,
@@ -73,15 +83,6 @@ if __name__ == "__main__":
                                                  compile_blob=not args.disable_compile_blob)
     t1 = time.perf_counter()
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    trans_version = transformers.__version__
-    if version.parse(trans_version) >= version.parse("4.45.0"):
-        tokenizer_json = os.path.join(model_path, "tokenizer.json")
-        dst_path = os.path.join(save_dir, "tokenizer.json")
-        shutil.copy(tokenizer_json, dst_path)
-    else:
-        tokenizer.save_pretrained(save_dir)
-
     print("-" * 80)
     print(f"Convert model cost {t1 - t0}s.")

View file

@@ -201,7 +201,7 @@ def convert_llm(model: torch.nn.Module,
                 keep_ir: bool=False,
                 compile_blob: bool=True):
     # whether to set layernorm weight as const
-    layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "1") == "1"
+    const_parameter = os.environ.get("IPEX_LLM_NPU_CONST_PARAMETER", "1") == "1"
     if group_size == 0:
         n_splits_linear = 1
         if qtype in ["sym_int8_rtn", "asym_int4_rtn"]:
@@ -240,7 +240,7 @@ def convert_llm(model: torch.nn.Module,
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
                                    temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                                   layernorm_const))
+                                   const_parameter))
             with Pool() as pool:
                 result = pool.starmap(convert_llama_layer, param_list)
@@ -267,7 +267,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline(model_type, kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
+                                      os.path.join(temp_dir, "decoder_layer"), const_parameter)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
@@ -284,7 +284,7 @@ def convert_llm(model: torch.nn.Module,
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
                                    temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                                   layernorm_const))
+                                   const_parameter))
             with Pool() as pool:
                 result = pool.starmap(convert_baichuan_layer, param_list)
@@ -308,7 +308,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("baichuan", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
+                                      os.path.join(temp_dir, "decoder_layer"), const_parameter)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
@@ -325,7 +325,7 @@ def convert_llm(model: torch.nn.Module,
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
                                    temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                                   layernorm_const))
+                                   const_parameter))
             with Pool() as pool:
                 result = pool.starmap(convert_minicpm_layer, param_list)
@@ -348,12 +348,12 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("minicpm", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
+                                      os.path.join(temp_dir, "decoder_layer"), const_parameter)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
     elif model.config.model_type == "qwen2":
-        layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "0") == "1"
+        const_parameter = os.environ.get("IPEX_LLM_NPU_CONST_PARAMETER", "0") == "1"
         with tempfile.TemporaryDirectory() as temp_dir:
             if save_directory is not None:
                 temp_dir = save_directory
@@ -371,7 +371,7 @@ def convert_llm(model: torch.nn.Module,
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
                                    temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                                   layernorm_const))
+                                   const_parameter))
             with Pool() as pool:
                 result = pool.starmap(convert_qwen_layer, param_list)
@@ -396,7 +396,7 @@ def convert_llm(model: torch.nn.Module,
                            "head_dim": model.head_dim,
                            "transpose_value_cache": transpose_value_cache,
                            "max_prompt_len": max_prompt_len,
-                           "layernorm_const": layernorm_const,
+                           "const_parameter": const_parameter,
                            "group_size": group_size}
             model.config.update(update_dict)
             model.config.save_pretrained(save_directory)
@@ -405,7 +405,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
+                                      os.path.join(temp_dir, "decoder_layer"), const_parameter)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
@@ -441,7 +441,9 @@ def convert_llm_for_deploy(model: torch.nn.Module,
     weight_dir = os.path.join(save_directory, "model_weights")
     if not os.path.exists(weight_dir):
         os.mkdir(weight_dir)
-    layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "1") == "1"
+    const_parameter = os.environ.get("IPEX_LLM_NPU_CONST_PARAMETER", "1") == "1"
+    if keep_ir:
+        const_parameter = False
     lm_head_low_bit = getattr(model.config, "bigdl_transformers_low_bit", "sym_int4_rtn")
 
     if hasattr(model, "lm_head") and not isinstance(model.lm_head, SlicedLMHead):
@@ -472,7 +474,7 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "head_dim": model.model.layers[0].self_attn.head_dim,
                        "transpose_value_cache": transpose_value_cache,
                        "max_prompt_len": max_prompt_len,
-                       "layernorm_const": layernorm_const,
+                       "const_parameter": const_parameter,
                        "group_size": group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": True,
@@ -490,12 +492,12 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         # save fused_layers blobs of fused decoder layers
         convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                  save_directory, weight_dir, transpose_value_cache, kv_len,
-                                 group_size, layernorm_const, "decode",
+                                 group_size, const_parameter, "decode",
                                  keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of single prefill layer
         convert_qwen_layer(model, 0, n_splits_linear, n_splits_down_proj,
                            save_directory, weight_dir, transpose_value_cache, max_prompt_len,
-                           group_size, layernorm_const, "prefill",
+                           group_size, const_parameter, "prefill",
                            keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of lmhead and bin of embedding
         convert_lm_head_and_embedding(model, save_directory, weight_dir, convert_model=True,
@@ -535,7 +537,7 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "head_dim": model.model.layers[0].self_attn.head_dim,
                        "transpose_value_cache": transpose_value_cache,
                        "max_prompt_len": max_prompt_len,
-                       "layernorm_const": layernorm_const,
+                       "const_parameter": const_parameter,
                        "group_size": group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": False,
@@ -559,12 +561,12 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         # save fused_layers blobs of fused decoder layers
         convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                   save_directory, weight_dir, transpose_value_cache, kv_len,
-                                  group_size, layernorm_const, "decode",
+                                  group_size, const_parameter, "decode",
                                   keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of single prefill layer
         convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj,
                             save_directory, weight_dir, transpose_value_cache, max_prompt_len,
-                            group_size, layernorm_const, "prefill",
+                            group_size, const_parameter, "prefill",
                             keep_ir=keep_ir, compile_blob=compile_blob)
     elif model.config.model_type == "minicpm":
         if group_size == 0:
@@ -576,7 +578,7 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "head_dim": model.model.layers[0].self_attn.head_dim,
                        "transpose_value_cache": transpose_value_cache,
                        "max_prompt_len": max_prompt_len,
-                       "layernorm_const": layernorm_const,
+                       "const_parameter": const_parameter,
                        "group_size": group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": False,
@@ -594,12 +596,12 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         # save fused_layers blobs of fused decoder layers
         convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                     save_directory, weight_dir, transpose_value_cache, kv_len,
-                                    group_size, layernorm_const, "decode",
+                                    group_size, const_parameter, "decode",
                                     keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of single prefill layer
        convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
                              save_directory, weight_dir, transpose_value_cache, max_prompt_len,
-                             group_size, layernorm_const, "prefill",
+                             group_size, const_parameter, "prefill",
                              keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of lmhead and bin of embedding and embedding_post
         convert_lm_head_and_embedding(model, n_splits_linear,
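
To opt out of constant parameters without keeping the IR, the renamed environment variable read by convert_llm and convert_llm_for_deploy can be set before conversion; a hypothetical usage sketch (the variable name and its defaults come from the hunks above):

    import os

    # "0" makes convert_llm / convert_llm_for_deploy treat layernorm weights (and, for
    # qwen2, qkv biases) as runtime inputs; the default is "1" except for qwen2 in
    # convert_llm, which defaults to "0". keep_ir=True in the deploy path now overrides
    # the variable and always disables the const path.
    os.environ["IPEX_LLM_NPU_CONST_PARAMETER"] = "0"
    # ... then run the conversion (e.g. the example script above) in this environment.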

View file

@@ -107,7 +107,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
 
 def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                         temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                        layernorm_const, mode="decode",
+                        const_parameter, mode="decode",
                         keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -145,14 +145,14 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
-        layernorm_const = False
+        const_parameter = False
         keep_position_ids = False
         npu_dpu_groups = 6
 
     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
-        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
+        input_layernorm_weights=[layer_norm_0] if const_parameter else None,
+        post_attn_layernorm_weights=[layer_norm_1] if const_parameter else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -182,7 +182,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
             # llama-2-7B & llama-3-8B
-            if layernorm_const:
+            if const_parameter:
                 st_idx = 5
             else:
                 input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
@@ -192,7 +192,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             st_idx = 7
         else:
             # llama-3.2-3B & llama-3.2-1B
-            if layernorm_const:
+            if const_parameter:
                 st_idx = 6
             else:
                 input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
@@ -223,7 +223,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
 def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                               save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                              layernorm_const, mode="decode",
+                              const_parameter, mode="decode",
                               keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -294,6 +294,10 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
             else:  # FP16 Linear
                 np_dtype = np.float16
 
+        if not const_parameter:
+            input_layer_norm_weights = None
+            post_attn_layernorm_weights = None
+
         fused_decoder = LowBitLlamaMultiDecoderlayer(
             [1, 1, num_heads * head_dim],
             input_layernorm_weights=input_layer_norm_weights,
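
The two constructor paths above capture the whole contract: with const_parameter=True the layernorm weights are baked into the compiled decoder, with False they are passed as None and supplied at runtime (dumped to the model_{layer_idx}_input_*.bin files referenced in the hunks). A framework-free toy illustrating that contract; the class below is purely illustrative and not part of ipex-llm:

    from typing import List, Optional

    import numpy as np

    class ToyDecoder:
        """Illustrative only: mimics the const vs. runtime-input contract used above."""

        def __init__(self, input_layernorm_weights: Optional[List[np.ndarray]] = None):
            # const_parameter=True path: the weight is frozen into the graph at build time.
            # const_parameter=False path: None here, so the weight must be fed per call.
            self.const_ln = None if input_layernorm_weights is None else input_layernorm_weights[0]

        def forward(self, x: np.ndarray, ln_weight: Optional[np.ndarray] = None) -> np.ndarray:
            w = self.const_ln if self.const_ln is not None else ln_weight
            # RMSNorm-style scaling, just enough to make the example runnable.
            return w * x / np.sqrt((x * x).mean(-1, keepdims=True) + 1e-6)

    x = np.ones((1, 4), dtype=np.float32)
    w = np.full(4, 0.5, dtype=np.float32)
    const_dec = ToyDecoder(input_layernorm_weights=[w])      # like const_parameter=True
    runtime_dec = ToyDecoder(input_layernorm_weights=None)   # like const_parameter=False
    assert np.allclose(const_dec.forward(x), runtime_dec.forward(x, ln_weight=w))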

View file

@@ -301,7 +301,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
 
 def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                           temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                          layernorm_const, mode="decode",
+                          const_parameter, mode="decode",
                           keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -333,12 +333,12 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
-        layernorm_const = False
+        const_parameter = False
 
     single_decoder = LowBitMinicpmMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
-        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
+        input_layernorm_weights=[layer_norm_0] if const_parameter else None,
+        post_attn_layernorm_weights=[layer_norm_1] if const_parameter else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -364,7 +364,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         os.remove(os.path.join(temp_dir, decoder_name + ".bin"))
 
     if mode == "decode":
-        if layernorm_const:
+        if const_parameter:
             st_idx = 5
         else:
             input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
@@ -394,7 +394,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
 def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                 save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                                layernorm_const, mode="decode",
+                                const_parameter, mode="decode",
                                 keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -461,6 +461,10 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
             else:  # FP16 Linear
                 np_dtype = np.float16
 
+        if not const_parameter:
+            input_layer_norm_weights = None
+            post_attn_layernorm_weights = None
+
         fused_decoder = LowBitMinicpmMultiDecoderlayer(
             [1, 1, num_heads * head_dim],
             input_layernorm_weights=input_layer_norm_weights,

View file

@@ -117,7 +117,7 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
 
 def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                       layernorm_const, mode="decode",
+                       const_parameter, mode="decode",
                        keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -193,7 +193,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     # 0, 1, 2 are input_embed/attention_mask/position_id
     if mode == "decode":
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
-            if layernorm_const:
+            if const_parameter:
                 st_idx = 3
             else:
                 input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
@@ -203,7 +203,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             st_idx = 5
         else:
             # transformers >= 4.45.0
-            if layernorm_const:
+            if const_parameter:
                 st_idx = 4
             else:
                 input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
@@ -241,7 +241,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
 def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                              save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                             layernorm_const, mode="decode",
+                             const_parameter, mode="decode",
                              keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
@@ -325,6 +325,13 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
             else:  # FP16 Linear
                 np_dtype = np.float16
 
+        if not const_parameter:
+            input_layer_norm_weights = None
+            post_attn_layernorm_weights = None
+            q_biases = None
+            k_biases = None
+            v_biases = None
+
         fused_decoder = LowBitQwenMultiDecoderlayer(
             [1, 1, num_heads * head_dim],
             input_layernorm_weights=input_layer_norm_weights,