[LLM] Add support for empty activation (#9664)
* Add support for empty activation, e.g., [0, 4096]. Empty activation is allowed by PyTorch. * Add comments.
This commit is contained in:
		
							parent
							
								
									284e7697b1
								
							
						
					
					
						commit
						5b0e7e308c
					
				
					 1 changed files with 10 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -448,9 +448,18 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
			
		||||
            self.bias.data = self.bias.data.to(x.dtype)
 | 
			
		||||
 | 
			
		||||
        # [batch, input_num, in_len]
 | 
			
		||||
        # input_num == token num for Transformer
 | 
			
		||||
        x_shape = x.shape
 | 
			
		||||
        x_2d = x.view(-1, x_shape[-1])
 | 
			
		||||
        # Output shape, e.g., [batch, input_num, out_len]
 | 
			
		||||
        new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
        # Activation is empty tensor, e.g., [1, 0, 4096]
 | 
			
		||||
        if 0 in x_shape:
 | 
			
		||||
            # return empty tensor with output shape, x.dtype and x.device
 | 
			
		||||
            return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 | 
			
		||||
 | 
			
		||||
        x_2d = x.view(-1, x_shape[-1])
 | 
			
		||||
        # x0 for weight
 | 
			
		||||
        x0 = self.weight.data
 | 
			
		||||
 | 
			
		||||
        if x0.device.type == "xpu":
 | 
			
		||||
| 
						 | 
				
			
			@ -489,7 +498,6 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
                else:
 | 
			
		||||
                    result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
 | 
			
		||||
                                                     input_seq_size)
 | 
			
		||||
            new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
            result = result.view(new_shape)
 | 
			
		||||
            if self.mp_group is not None:
 | 
			
		||||
                from deepspeed import comm as dist
 | 
			
		||||
| 
						 | 
				
			
			@ -514,7 +522,6 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
                else:
 | 
			
		||||
                    # Weight does not need a convert
 | 
			
		||||
                    result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
 | 
			
		||||
                    new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
                    result = result.view(new_shape)
 | 
			
		||||
            # allreduce to combine partial results and add bias if necessary
 | 
			
		||||
            if self.mp_group is not None:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue