fix fp16 linear (#12250)
This commit is contained in:
		
							parent
							
								
									e8cf7f32f5
								
							
						
					
					
						commit
						88dc120a4c
					
				
					 1 changed files with 2 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -886,7 +886,8 @@ class FP16Linear(nn.Linear):
 | 
			
		|||
                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
 | 
			
		||||
                                                     requires_grad=False)
 | 
			
		||||
                    self.weight_type = 2
 | 
			
		||||
                result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
			
		||||
                result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
 | 
			
		||||
                                                              self.weight, self.bias)
 | 
			
		||||
            if self.mp_group is not None:
 | 
			
		||||
                if get_use_vllm():
 | 
			
		||||
                    result = self.mp_group.all_reduce(result)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue