refactor xpu linear forward (#12768)
parent 413d6c2b66
commit 0237ffb302

3 changed files with 32 additions and 82 deletions
@@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, A, weight, input_seq_size):
+    def forward(ctx, A, weight, output_size):
         ctx.is_empty = False
         import xe_linear
         if weight.qtype == NF4:
             result = xe_linear.forward_new(A,
                                            weight.data.view(torch.uint8),
                                            weight.qtype,
-                                           input_seq_size)
+                                           output_size)
         else:
-            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
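The hunk above renames the third argument of MatMulLowBit.forward from input_seq_size to output_size: the kernel is now told the output feature count instead of the sequence length. A minimal sketch of the same torch.autograd.Function pattern, with a dense matmul standing in for xe_linear.forward_new (names here are illustrative, not ipex-llm APIs):

import torch

class MatMulSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, weight, output_size):
        # output_size fixes the trailing result dimension: [A.size(0), output_size]
        ctx.save_for_backward(A, weight)
        return A @ weight.t()  # stand-in for the low-bit xe_linear kernel

    @staticmethod
    def backward(ctx, grad_output):
        A, weight = ctx.saved_tensors
        # trailing None: no gradient for the integer output_size argument
        return grad_output @ weight, grad_output.t() @ A, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8)
y = MatMulSketch.apply(x, w, 16)  # third argument mirrors the new signature
y.sum().backward()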
@@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
         if self.optimize_lm_head:
             x = reshape_lm_head_input(x)

-        # [batch, input_num, in_len]
-        # input_num == token num for Transformer
-        x_shape = x.shape
-        # Output shape, e.g., [batch, input_num, out_len]
-        new_shape = x_shape[:-1] + (self.out_len,)
+        # [batch, seq_len, in_len] -> [batch, seq_len, out_len]
+        new_shape = x.shape[:-1] + (self.out_len,)

         # Activation is empty tensor, e.g., [1, 0, 4096]
-        if 0 in x_shape:
+        if 0 in x.shape:
             # return empty tensor with output shape, x.dtype and x.device
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)

-        x_2d = x.contiguous().view(-1, x_shape[-1])
-
         if self.act_order:
-            x_2d = x_2d[:, self.g_idx_map]
-        # x0 for weight
-        x0 = self.weight.data
+            x = x[..., self.g_idx_map]

-        if x0.device.type == "xpu":
-            # GPU logic
-            try:
-                import xe_linear
-                from ipex_llm.transformers.models.utils import use_xmx
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe` first.")
-
-            if x_2d.is_contiguous() is False:
-                x_2d = x_2d.contiguous()
-
-            if len(x_shape) == 3:
-                input_seq_size = x_shape[1]
-            elif len(x_shape) < 3:
-                input_seq_size = 1
-
-            if is_training:
-                # training path
-                if x_2d.requires_grad:
-                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
-                else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype,
-                                                       input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data,
-                                                       self.weight.qtype,
-                                                       input_seq_size)
-            else:
-                # inference path
-                # current workaround to reduce first token latency of fp32 input
-                # sometimes fp16 cause nan and training instability
-                # disable the conversion when training
-                # TODO: may modify the input length condition for empty cache.
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+
+        if self.weight.device.type == "xpu":
+            if is_training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
+            else:
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
                 if do_empty_cache:
                     torch.xpu.empty_cache()

+                if self.qtype == NF4:
+                    w = self.weight.data.view(torch.uint8)
+                else:
+                    w = self.weight.data
+
                 if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                     import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
-                elif (
-                    self.conver_to_half
-                    and x_2d.shape[0] > 1
-                    and x_2d.dtype == torch.float32
-                    and not use_xmx(x_2d, self.weight.qtype)
-                ):
+                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
+                elif not is_training and self.conver_to_half \
+                        and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
+                    import xe_linear
                     x_2d = x_2d.half()
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype, input_seq_size)
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                     result = result.to(x.dtype)
                 else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype, input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                       self.weight.qtype, input_seq_size)
+                    import xe_linear
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)

                 if do_empty_cache:
                     torch.xpu.empty_cache()

             result = result.view(new_shape)

             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
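Taken together, the hunk above replaces the per-branch weight handling (try/except import, input_seq_size computation, use_xmx gating) with a single w chosen once and the fixed self.out_len. A runnable sketch of the refactored control flow with dense stand-ins (forward_sketch and its arguments are illustrative, not ipex-llm APIs; bfloat16 stands in for the XPU fp16 path so the sketch runs on CPU):

import torch

def forward_sketch(x, weight, out_len, conver_to_half, is_training):
    new_shape = x.shape[:-1] + (out_len,)
    x_2d = x.contiguous().view(-1, x.shape[-1])

    # pick the weight view once, instead of once per branch
    w = weight  # real code: weight.data.view(torch.uint8) when qtype is NF4

    if not is_training and conver_to_half and x_2d.shape[0] > 1 \
            and x_2d.dtype == torch.float:
        # low-precision conversion only at inference time, to cut fp32
        # first-token latency; disabled in training to avoid instability
        result = (x_2d.bfloat16() @ w.t().bfloat16()).to(x.dtype)
    else:
        result = x_2d @ w.t()
    return result.view(new_shape)

out = forward_sketch(torch.randn(2, 5, 8), torch.randn(16, 8), 16,
                     conver_to_half=True, is_training=False)
print(out.shape)  # torch.Size([2, 5, 16])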
@@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
                     dist.inference_all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
+
             if self.bias is not None:
                 result += self.bias
         else:
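The all-reduce exists because with mp_group set the layer is tensor-parallel: each rank holds a slice of the weight along the input dimension, computes a partial product, and the partials sum to the full result. A single-process illustration of that identity (plain PyTorch, no distributed setup):

import torch

x = torch.randn(2, 8)
w = torch.randn(16, 8)  # full weight, out_len = 16

full = x @ w.t()
# split in_len across two "ranks"
part0 = x[:, :4] @ w[:, :4].t()
part1 = x[:, 4:] @ w[:, 4:].t()
# the all-reduce is exactly this sum of partial results
assert torch.allclose(full, part0 + part1, atol=1e-5)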
@@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
                 result = MatMulLowBitCPU.apply(x, self.weight)
             else:
                 from ipex_llm.utils.isa_checker import is_server, is_spr
-
+                x0 = self.weight.data
                 # convert if necessary, and compute a linear result
                 if is_server() and (not is_spr()) and \
                         self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
@@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
     return True


-def use_xmx(x: torch.Tensor, qtype: int):
-    device = get_xpu_device_name(x.device)
-    return (
-        device in ["arc", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
-        and (
-            (device == "pvc" and 1 < x.size(0) <= 16)
-            or
-            (device != "pvc" and 1 < x.size(0) <= 64)
-        )
-    )
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
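The deleted use_xmx helper gated the fp16-conversion branch on device name, qtype, and a batch-size window; after the refactor that branch keys only on is_training, conver_to_half, shape, and dtype. For reference, the removed predicate's behaviour as a standalone sketch (the qtype ids below are stand-in constants; the real ones and get_xpu_device_name live in ipex_llm):

SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4 = range(5)  # stand-in qtype ids

def use_xmx_sketch(batch_size: int, device: str, qtype: int) -> bool:
    # XMX was used only on arc/pvc, for these qtypes, in a small batch window
    return (
        device in ["arc", "pvc"]
        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
        and ((device == "pvc" and 1 < batch_size <= 16)
             or (device != "pvc" and 1 < batch_size <= 64))
    )

assert use_xmx_sketch(8, "arc", SYM_INT4)
assert not use_xmx_sketch(1, "arc", SYM_INT4)   # single-token decode
assert not use_xmx_sketch(32, "pvc", SYM_INT4)  # above the pvc batch window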
@@ -20,9 +20,10 @@ import xe_batch
 import xe_addons


-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, output_size):
+    return torch.empty([x.size(0), output_size],
+                       dtype=x.dtype, device=x.device)


 # @torch.library.register_fake("ipex_llm::dequant")
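The now-enabled fake kernel lets torch.compile trace forward_new without running the real XPU kernel: during tracing only the output shape, dtype, and device matter, which is why it returns torch.empty. A self-contained sketch of the same register_fake pattern (PyTorch >= 2.4; the mylib::scale op is illustrative, not part of ipex-llm):

import torch

@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@torch.library.register_fake("mylib::scale")
def _(x, factor):
    # runs on meta tensors during compilation; must only produce an empty
    # tensor with the right shape/dtype/device, like the fake forward_new above
    return torch.empty_like(x)

@torch.compile
def f(x):
    return scale(x, 2.0)

print(f(torch.randn(4)))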