LLM: fix qwen AutoTP (#10766)
This commit is contained in:
parent 3e2662c87e
commit 0a62933d36
1 changed file with 2 additions and 1 deletion
@@ -617,7 +617,8 @@ def _optimize_pre(model):
         if "QWenAttention" in module.__class__.__name__:
             c_attn_weight = module.c_attn.weight.data
             c_attn_bias = module.c_attn.bias.data
-            projection_size = module.projection_size
+            # Compatible with AutoTP case
+            projection_size = c_attn_weight.shape[0] // 3
             hid_size = module.hidden_size
             with init_empty_weights():
                 q_proj = torch.nn.Linear(hid_size, projection_size)
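
Why the fix works, as a minimal sketch (the sizes and the tp_world_size value below are illustrative assumptions, not taken from the repo): Qwen's c_attn fuses the q/k/v projections into a single Linear whose weight has 3 * projection_size output rows. When DeepSpeed AutoTP shards that Linear across ranks, module.projection_size still reports the full-model value, but the sharded weight only holds this rank's rows, so deriving the size from c_attn_weight.shape[0] // 3 gives the correct per-rank projection size for the split q/k/v layers.

import torch

# Illustrative sizes only (assumptions, not values from the commit).
hidden_size = 4096        # model hidden size
projection_size = 4096    # full (unsharded) q/k/v projection size
tp_world_size = 2         # AutoTP tensor-parallel degree

# Qwen's c_attn fuses q, k, v: weight shape is [3 * projection_size, hidden_size].
full_c_attn = torch.nn.Linear(hidden_size, 3 * projection_size)

# AutoTP splits the fused output rows across ranks; keep rank 0's slice.
rows_per_rank = full_c_attn.weight.shape[0] // tp_world_size
c_attn_weight = full_c_attn.weight.data[:rows_per_rank]

# The module attribute would still claim the unsharded size (4096) ...
stale_size = projection_size
# ... while the weight-derived value matches what this rank actually holds.
per_rank_projection_size = c_attn_weight.shape[0] // 3
assert per_rank_projection_size == projection_size // tp_world_size  # 2048

In the single-process case tp_world_size is 1 and both expressions agree, which is why the weight-shape form is safe to use unconditionally.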