LLM: fix abnormal output of fp16 deepspeed autotp (#10558)
This commit is contained in:
		
							parent
							
								
									e619142a16
								
							
						
					
					
						commit
						92dfed77be
					
				
					 1 changed file with 6 additions and 2 deletions
				
			
		| 
						 | 
					@ -702,12 +702,16 @@ class FP16Linear(nn.Linear):
 | 
				
			||||||
                if self.weight_type == 2:
 | 
					                if self.weight_type == 2:
 | 
				
			||||||
                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
					                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
				
			||||||
                    self.weight_type = 1
 | 
					                    self.weight_type = 1
 | 
				
			||||||
                return F.linear(x, self.weight, self.bias)
 | 
					                result = F.linear(x, self.weight, self.bias)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                if self.weight_type == 1:
 | 
					                if self.weight_type == 1:
 | 
				
			||||||
                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
					                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
				
			||||||
                    self.weight_type = 2
 | 
					                    self.weight_type = 2
 | 
				
			||||||
                return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
					                result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
				
			||||||
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
 | 
					                from deepspeed import comm as dist
 | 
				
			||||||
 | 
					                dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					            return result
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if self.in_len == 4096 and self.weight_type != 3 or \
 | 
					            if self.in_len == 4096 and self.weight_type != 3 or \
 | 
				
			||||||
                    self.in_len == 11008 and self.weight_type != 1:
 | 
					                    self.in_len == 11008 and self.weight_type != 1:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue