LLM: fix AttributeError of FP16Linear (#10740)
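
The change is the same in both classes: the constructor gains an enable_xetla keyword argument (defaulting to False, so existing call sites keep working) and stores it as self.enable_xetla. Presumably code elsewhere in the module reads self.enable_xetla on these layers, which previously raised an AttributeError because the attribute was never set.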
commit 70ed9397f9
parent 1256a2cc4e

1 changed file with 4 additions and 2 deletions
@@ -729,7 +729,7 @@ class LowBitLinear(nn.Linear):

 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, weight_type=1,
+                 mp_group=None, weight_type=1, enable_xetla=False,
                  optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
@@ -743,6 +743,7 @@ class FP16Linear(nn.Linear):
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
         self.optimize_lm_head = optimize_lm_head
+        self.enable_xetla = enable_xetla

     def forward(self, x: torch.Tensor):
         # only work for GPU
@@ -849,7 +850,7 @@ class FP16Linear(nn.Linear):

 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, compute_dtype=None,
+                 mp_group=None, compute_dtype=None, enable_xetla=False,
                  optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
@@ -860,6 +861,7 @@ class BF16Linear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
         self.optimize_lm_head = optimize_lm_head
+        self.enable_xetla = enable_xetla

     def forward(self, x: torch.Tensor):
         if self.optimize_lm_head:
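For context, here is a minimal, hypothetical sketch of the failure mode and the fix. The class names below are illustrative, not the real ipex-llm code, which carries more constructor arguments and a GPU-specific forward path; the exact line that reads self.enable_xetla sits outside the hunks shown above.

import torch
import torch.nn as nn

# Before the fix: forward() reads an attribute that __init__ never sets.
class BrokenLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        # enable_xetla is neither accepted nor stored here.

    def forward(self, x: torch.Tensor):
        if self.enable_xetla:  # raises AttributeError on every call
            pass
        return super().forward(x)

# After the fix: accept the flag (defaulting to False) and store it.
class FixedLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True,
                 enable_xetla=False):
        super().__init__(input_features, output_features, bias)
        self.enable_xetla = enable_xetla

    def forward(self, x: torch.Tensor):
        if self.enable_xetla:
            pass  # an XeTLA-specific path would go here
        return super().forward(x)

x = torch.randn(2, 4)
try:
    BrokenLinear(4, 8)(x)
except AttributeError as e:
    print("before:", e)  # 'BrokenLinear' object has no attribute 'enable_xetla'
print("after:", FixedLinear(4, 8)(x).shape)  # torch.Size([2, 8])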