update api usage of xe_batch & fp16 (#11164)
* update api usage
* update setup.py
parent e29e2f1c78
commit 9bfbf78bf4
2 changed files with 6 additions and 14 deletions
@@ -299,8 +299,7 @@ def setup_package():
                         "intel_extension_for_pytorch==2.1.10+xpu",
                         "bigdl-core-xe-21==" + CORE_XE_VERSION,
-                        "bigdl-core-xe-batch-21==" + CORE_XE_VERSION,
-                        "bigdl-core-xe-addons-21==" + CORE_XE_VERSION,
-                        "bigdl-core-xe-esimd-21==" + CORE_XE_VERSION]
+                        "bigdl-core-xe-addons-21==" + CORE_XE_VERSION]
     xpu_21_requires += oneapi_2024_0_requires
     # default to ipex 2.1 for linux and windows
     xpu_requires = copy.deepcopy(xpu_21_requires)
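With this change the XPU 2.1 install no longer pulls in the separate bigdl-core-xe-batch-21 and bigdl-core-xe-esimd-21 wheels, presumably because the xe_batch entry point used below now ships with the remaining bigdl-core packages. A minimal sketch of the resulting requirement list, showing only the entries visible in the hunk and using placeholder values for names defined elsewhere in setup.py:

# Sketch only: the post-change slice of the requirement list in setup_package().
# CORE_XE_VERSION and oneapi_2024_0_requires are placeholders here; the real
# values are defined elsewhere in setup.py.
import copy

CORE_XE_VERSION = "0.0.0"        # placeholder for illustration
oneapi_2024_0_requires = []      # placeholder for illustration

xpu_21_requires = ["intel_extension_for_pytorch==2.1.10+xpu",
                   "bigdl-core-xe-21==" + CORE_XE_VERSION,
                   "bigdl-core-xe-addons-21==" + CORE_XE_VERSION]
xpu_21_requires += oneapi_2024_0_requires
# default to ipex 2.1 for linux and windows
xpu_requires = copy.deepcopy(xpu_21_requires)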
@@ -720,8 +720,7 @@ class LowBitLinear(nn.Linear):
                     if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                         import xe_batch
                         result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                        self.weight.qtype,
-                                                        input_seq_size)
+                                                        self.weight.qtype)
                     else:
                         result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                                        input_seq_size)
@@ -730,8 +729,7 @@ class LowBitLinear(nn.Linear):
                     if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                         import xe_batch
                         result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                        self.weight.qtype,
-                                                        input_seq_size)
+                                                        self.weight.qtype)
                     else:
                         result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                                        input_seq_size)
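Both LowBitLinear hunks above make the same change: xe_batch.batch_forward is now called without the trailing input_seq_size argument, while the xe_linear.forward_new fallback keeps its original signature. A minimal sketch of the updated dispatch, assuming xe_batch and xe_linear come from the bigdl-core XPU wheels and that use_batch_forward keeps its existing meaning (the helper name _low_bit_matmul is only for illustration):

# Illustrative sketch of the dispatch inside LowBitLinear.forward after this change.
def _low_bit_matmul(x_2d, weight, out_len, input_seq_size):
    if use_batch_forward(x_2d, weight.qtype, out_len):   # small-batch fast path
        import xe_batch
        # updated API: input_seq_size is no longer passed
        result = xe_batch.batch_forward(x_2d, weight.data, weight.qtype)
    else:
        import xe_linear
        # fallback keeps the original signature
        result = xe_linear.forward_new(x_2d, weight.data, weight.qtype, input_seq_size)
    return result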
@@ -843,13 +841,6 @@ class FP16Linear(nn.Linear):
             if x_2d.is_contiguous() is False:
                 x_2d = x_2d.contiguous()
 
-            try:
-                import intel_extension_for_pytorch
-                import linear_fp16_esimd
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe_esimd` first.")
-
             if x_2d.shape[0] > 8:
                 # first token or batch size > 8, re-convert weight
                 if self.weight_type == 3:
@@ -861,7 +852,9 @@ class FP16Linear(nn.Linear):
                     result = F.linear(x_2d, self.weight)
             else:
                 # batch size <= 8, use esimd optimization
-                result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+                import xe_batch
+                result = xe_batch.batch_forward(x_2d, self.weight.data,
+                                                self.qtype)
 
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
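Taken together, the last two hunks drop the linear_fp16_esimd dependency from FP16Linear: the import guard is removed and the batch-size-at-most-8 path now calls the same xe_batch.batch_forward entry point, this time with self.qtype. A rough sketch of the resulting branch (the helper name _fp16_matmul and the simplified large-batch branch are illustrative only):

# Illustrative sketch of the FP16Linear branch after this change.
import torch.nn.functional as F

def _fp16_matmul(x_2d, weight, qtype, out_len, x_shape):
    if x_2d.shape[0] > 8:
        # first token or batch size > 8: plain fp16 GEMM (weight re-conversion omitted here)
        result = F.linear(x_2d, weight)
    else:
        # batch size <= 8: batched XPU kernel replaces linear_fp16_esimd.forward
        import xe_batch
        result = xe_batch.batch_forward(x_2d, weight.data, qtype)
    new_shape = x_shape[:-1] + (out_len,)
    return result.view(new_shape)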