Qwen layernorm as input (#12309)

* qwen layernorm as input

* add group size
Author: Kai Huang, 2024-11-04 09:51:15 +08:00 (committed by GitHub)
Parent: 94ce447794
Commit: c8679ad592
2 changed files with 13 additions and 7 deletions

@@ -47,6 +47,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
                         help='Load in low bit to use')
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
@@ -62,6 +63,7 @@ if __name__ == "__main__":
         load_in_low_bit=args.load_in_low_bit,
         max_context_len=args.max_context_len,
         max_prompt_len=args.max_prompt_len,
+        quantization_group_size=args.quantization_group_size,
         torch_dtype=torch.float16,
         attn_implementation="eager",
         transpose_value_cache=not args.disable_transpose_value_cache,
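For reference, the new flag simply rides along from argparse into the model loading call shown above. Below is a minimal, hypothetical sketch of that plumbing; `load_model` is a placeholder for the example script's real `from_pretrained(...)` call, and the reading of 0 as "backend default grouping" is an assumption suggested by the default value, not something stated in the diff.

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    # Same defaults as in the example script above.
    parser.add_argument("--quantization_group_size", type=int, default=0)
    parser.add_argument("--load_in_low_bit", type=str, default="sym_int4")
    return parser.parse_args()

def load_model(low_bit: str, group_size: int):
    # Placeholder for the real from_pretrained(..., load_in_low_bit=...,
    # quantization_group_size=...) call; only the argument flow is shown.
    print(f"loading with {low_bit}, group size {group_size or 'backend default'}")

if __name__ == "__main__":
    args = parse_args()
    load_model(args.load_in_low_bit, args.quantization_group_size)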

@@ -149,8 +149,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     single_decoder = LowBitQwenMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=None,
+        post_attn_layernorm_weights=None,
         q_biases=None,
         k_biases=None,
         v_biases=None,
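The change in this hunk: when concrete tensors are passed, the layernorm weights end up baked into the compiled decoder as constants; passing None leaves them as runtime inputs, which is what allows the converter to serialize them to input_3.bin / input_4.bin in the next hunk. The class below is a toy illustration of that constant-vs-runtime-input switch, not the real LowBitQwenMultiDecoderlayer constructor.

import numpy as np

class ToyDecoder:
    """Illustrative stand-in for the constant-vs-runtime-input decision."""
    def __init__(self, input_layernorm_weights=None, post_attn_layernorm_weights=None):
        self.constants = {}
        self.runtime_inputs = []
        for name, value in [("input_layernorm", input_layernorm_weights),
                            ("post_attn_layernorm", post_attn_layernorm_weights)]:
            if value is None:
                # weight is fed as an extra graph input on every run (this commit's behaviour)
                self.runtime_inputs.append(name)
            else:
                # weight is frozen into the graph at build time (previous behaviour)
                self.constants[name] = np.asarray(value[0])

baked = ToyDecoder(input_layernorm_weights=[np.ones(4, dtype=np.float32)],
                   post_attn_layernorm_weights=[np.ones(4, dtype=np.float32)])
as_inputs = ToyDecoder()  # both None, as in this commit
print(sorted(baked.constants), as_inputs.runtime_inputs)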
@@ -174,17 +174,21 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                            temp_dir)
     # 0, 1, 2 are input_embed/attention_mask/position_id
-    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+    input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+    post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+    layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+    layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_6.bin")
+    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_7.bin")
     q_bias.data.numpy().tofile(q_bias_bin_file)
     k_bias.data.numpy().tofile(k_bias_bin_file)
     v_bias.data.numpy().tofile(v_bias_bin_file)
     # 6, 7 are past k/v
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{8+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{8+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
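Taken together, the per-layer input slot layout after this commit appears to be: slots 0-2 for input_embed / attention_mask / position_id, 3-4 for the two layernorm weights, 5-7 for the q/k/v biases, then presumably 8-9 for past k/v (the surviving "# 6, 7 are past k/v" comment predates the shift), with quantized weight/scale pairs starting at slot 10. The helper below reconstructs that file layout as inferred from the diff, not from any documented spec.

import os

def layer_input_bin_files(weight_dir: str, layer_idx: int, n_weight_groups: int) -> dict:
    """Map input-slot index -> .bin path, as implied by the converter above.

    Slots 0-2 (embeddings / attention mask / position ids) and the past k/v
    slots (presumably 8 and 9 after this change) are runtime inputs, so no
    serialized file exists for them.
    """
    def path(n):
        return os.path.join(weight_dir, f"model_{layer_idx}_input_{n}.bin")

    files = {
        3: path(3),  # input layernorm weight
        4: path(4),  # post-attention layernorm weight
        5: path(5),  # q bias
        6: path(6),  # k bias
        7: path(7),  # v bias
    }
    for idx in range(n_weight_groups):
        files[10 + idx * 2] = path(10 + idx * 2)          # quantized weight
        files[10 + idx * 2 + 1] = path(10 + idx * 2 + 1)  # scale
    return files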