LLM: update fp16 Linear on ARC/FLEX (#10023)

parent a5c9dfdf91
commit ccf8f613fb

1 changed file with 15 additions and 12 deletions
@@ -575,7 +575,8 @@ class FP16Linear(nn.Linear):
                     self.weight_type = 2
                 return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
         else:
-            if self.weight_type != 3:
+            if self.in_len == 4096 and self.weight_type != 3 or \
+                    self.in_len == 11008 and self.weight_type != 1:
                 # convert weight first to use esimd fp16 kernel
                 self.convert_weight_for_esimd_kernel()
             # esimd fp16 kernel for inference
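Read alongside the conversion hunks further down, the new guard converts the weight only when its stored layout does not match what the kernel expects for that shape: weight_type 3 appears to be the 16-row blocked layout used when in_len == 4096, and weight_type 1 the plain contiguous layout now used when in_len == 11008. Because `and` binds tighter than `or` in Python, the unparenthesized condition groups as in this minimal sketch (the standalone helper is hypothetical, not part of the diff):

def needs_esimd_conversion(in_len: int, weight_type: int) -> bool:
    # Equivalent grouping of the diff's condition: `and` binds
    # tighter than `or`, so no parentheses are needed in the source.
    return (in_len == 4096 and weight_type != 3) or \
           (in_len == 11008 and weight_type != 1)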
@@ -591,14 +592,17 @@ class FP16Linear(nn.Linear):
                 invalidInputError(False,
                                   "Please `pip install bigdl_core_xe_esimd` first.")
 
-            if x_2d.shape[0] > 1:
-                # first token or batch size > 1, re-convert weight
-                original_weight = self.weight.data.transpose(1, 2)
-                original_weight = original_weight.reshape(self.out_len, self.in_len)
-                result = F.linear(x_2d, original_weight.contiguous())
-                del original_weight
+            if x_2d.shape[0] > 8:
+                # first token or batch size > 8, re-convert weight
+                if self.weight_type == 3:
+                    original_weight = self.weight.data.transpose(1, 2)
+                    original_weight = original_weight.reshape(self.out_len, self.in_len)
+                    result = F.linear(x_2d, original_weight.contiguous())
+                    del original_weight
+                else:
+                    result = F.linear(x_2d, self.weight)
             else:
-                # rest token, use esimd optimization
+                # batch size <= 8, use esimd optimization
                 result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 
             new_shape = x_shape[:-1] + (self.out_len,)
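This hunk raises the F.linear fallback threshold from batch size 1 to 8 and makes the fallback layout-aware: only a weight stored in the blocked layout (weight_type 3) has to be flattened back to (out_len, in_len) first, while batches of at most 8 rows stay on the esimd kernel. A minimal sketch of that dispatch, assuming the shapes implied by the diff (fp16_mm is a hypothetical standalone version; linear_fp16_esimd is the extension module the surrounding code already imports):

import torch
import torch.nn.functional as F

def fp16_mm(x_2d, weight, weight_type, out_len, in_len):
    if x_2d.shape[0] > 8:
        # first token or batch size > 8: fall back to stock F.linear
        if weight_type == 3:
            # blocked (out_len//16, in_len, 16) -> flat (out_len, in_len)
            w = weight.data.transpose(1, 2).reshape(out_len, in_len)
            return F.linear(x_2d, w.contiguous())
        return F.linear(x_2d, weight)
    # batch size <= 8: esimd kernel (requires bigdl_core_xe_esimd)
    return linear_fp16_esimd.forward(x_2d, weight.data)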
@@ -632,9 +636,8 @@ class FP16Linear(nn.Linear):
                 trans_weight = self.weight.data.transpose(0, 1)
             else:
                 trans_weight = self.weight.data
-            trans_weight = trans_weight.data.reshape(m//8, 8, n)
-            trans_weight = trans_weight.transpose(1, 2).contiguous()
-            self.weight.data = trans_weight
+            self.weight.data = trans_weight.contiguous()
+            self.weight_type = 1
         elif self.in_len == 4096:
             if self.weight_type == 2:
                 trans_weight = self.weight.data.transpose(0, 1)
@@ -643,7 +646,7 @@ class FP16Linear(nn.Linear):
             trans_weight = trans_weight.data.reshape(m//16, 16, n)
             trans_weight = trans_weight.transpose(1, 2).contiguous()
             self.weight.data = trans_weight
-        self.weight_type = 3
+            self.weight_type = 3
 
 
 class BF16Linear(nn.Linear):
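The last two hunks rework convert_weight_for_esimd_kernel: the in_len == 11008 branch now only makes the weight contiguous and tags it weight_type = 1 (its old 8-row blocking is dropped), and the weight_type = 3 tag moves inside the in_len == 4096 branch so it no longer overwrites the label for the 11008 case. A toy round-trip with stand-in sizes, checking that the 16-row blocked layout is exactly what the large-batch fallback above undoes:

import torch

m, n = 32, 8  # stand-in sizes; real weights are out_len x 4096
w = torch.randn(m, n).half()

# conversion, in_len == 4096 branch: 16-row blocks, then interleave
blocked = w.reshape(m // 16, 16, n).transpose(1, 2).contiguous()  # (m//16, n, 16)

# forward's large-batch fallback restores the flat weight
restored = blocked.transpose(1, 2).reshape(m, n)
assert torch.equal(restored, w)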