Fix AttributeError of qwen2-1.5B (#11990)

This commit is contained in:
binbin Deng 2024-09-02 17:55:10 +08:00 committed by GitHub
parent c48817bd43
commit a40ea7038d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -568,16 +568,28 @@ def run_decode(
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp
weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
(mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
]
if model.config.intermediate_size == 8960:
# for qwen2-1.5b
weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
]
elif model.config.intermediate_size == 18944:
# for qwen2-7b
weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
(mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
]
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)