Fix shape error when running qwen1.5-14b using DeepSpeed AutoTP (#11420)
This commit is contained in:
parent 3b23de684a
commit aacc1fd8c0

1 changed file with 8 additions and 5 deletions
@@ -361,8 +361,8 @@ def merge_qkv(module: torch.nn.Module):
 def padding_mlp(module: torch.nn.Module):
     # for qwen 1.5 14B
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.hidden_size
-        intermediate_size = module.intermediate_size
+        hidden_size = module.gate_proj.weight.shape[1]
+        intermediate_size = module.gate_proj.weight.shape[0]
         padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
         if intermediate_size % 256 == 0:
             return
@@ -371,21 +371,24 @@ def padding_mlp(module: torch.nn.Module):
         new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size],
                                       dtype=gate_weight.dtype, device=gate_weight.device)
         new_gate_weight[:intermediate_size, :] = gate_weight
-        module.gate_proj.out_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.gate_proj.out_features = padding_intermediate_size
         module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False)
 
         up_weight = module.up_proj.weight.data
         new_up_weight = torch.zeros([padding_intermediate_size, hidden_size],
                                     dtype=up_weight.dtype, device=up_weight.device)
         new_up_weight[:intermediate_size, :] = up_weight
-        module.up_proj.out_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.up_proj.out_features = padding_intermediate_size
         module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False)
 
         down_weight = module.down_proj.weight.data
         new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                       dtype=down_weight.dtype, device=down_weight.device)
         new_down_weight[:, :intermediate_size] = down_weight
-        module.down_proj.in_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.down_proj.in_features = padding_intermediate_size
         module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
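
Under DeepSpeed AutoTP each rank holds only a shard of the MLP weights, so module.intermediate_size still reports the full-model config value while the local gate_proj.weight shard is smaller; reading hidden_size and intermediate_size from the weight shard itself keeps the padding consistent with what the rank actually stores, which is what fixes the shape error. A minimal sketch of the padding arithmetic, assuming Qwen1.5-14B's config values (hidden_size 5120, intermediate_size 13696) and a hypothetical 2-way shard:

import torch

# Assumed Qwen1.5-14B sizes and a hypothetical 2-way AutoTP shard.
hidden_size = 5120
full_intermediate = 13696
shard_intermediate = full_intermediate // 2        # 6848 rows on this rank

# Round the shard's intermediate size up to the next multiple of 256.
padding_intermediate_size = (shard_intermediate + 256 - 1) // 256 * 256   # 6912

# Zero-pad the gate projection shard the same way padding_mlp does.
gate_weight = torch.zeros(shard_intermediate, hidden_size)   # stand-in shard
new_gate_weight = torch.zeros(padding_intermediate_size, hidden_size,
                              dtype=gate_weight.dtype, device=gate_weight.device)
new_gate_weight[:shard_intermediate, :] = gate_weight        # tail rows stay zero

print(shard_intermediate, padding_intermediate_size)         # 6848 6912

The padding is numerically a no-op: the zero rows added to gate_proj and up_proj produce zero activations (silu(0) * 0 == 0), and the matching zero columns in down_proj contribute nothing to the output, so only the shapes change.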
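
The hasattr guards account for DeepSpeed AutoTP replacing the MLP's torch.nn.Linear layers with its own wrapper modules (LinearAllreduce and similar), which hold the sharded weight but do not necessarily define nn.Linear's out_features / in_features bookkeeping; the guard updates that bookkeeping only where it exists, while the weight swap via torch.nn.Parameter works for both kinds of module. A small illustration, using a hypothetical stand-in for such a wrapper:

import torch

class LinearAllreduceStub(torch.nn.Module):
    # Hypothetical stand-in for a DeepSpeed AutoTP wrapper: it carries a
    # weight shard but, unlike torch.nn.Linear, defines no out_features.
    def __init__(self, weight):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

sharded = LinearAllreduceStub(torch.zeros(6848, 5120))
plain = torch.nn.Linear(5120, 6848, bias=False)

print(hasattr(sharded, 'out_features'))   # False -> guarded assignment skipped
print(hasattr(plain, 'out_features'))     # True  -> bookkeeping updated as before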
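
A hypothetical end-to-end wiring, assuming padding_mlp from the diff above is in scope: the padding must run after DeepSpeed has sharded the model, so that the per-rank shard shapes are the ones being read and padded.

import deepspeed
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-14B-Chat",
                                             torch_dtype=torch.float16)
engine = deepspeed.init_inference(model, tensor_parallel={"tp_size": 2},
                                  dtype=torch.float16)
engine.module.apply(padding_mlp)   # no-op for modules that are not Qwen2MLP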