diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index fef2166c..92d82f6a 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, A, weight, input_seq_size):
+    def forward(ctx, A, weight, output_size):
         ctx.is_empty = False
         import xe_linear
         if weight.qtype == NF4:
             result = xe_linear.forward_new(A, weight.data.view(torch.uint8), weight.qtype,
-                                           input_seq_size)
+                                           output_size)
         else:
-            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
@@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
         if self.optimize_lm_head:
             x = reshape_lm_head_input(x)
-        # [batch, input_num, in_len]
-        # input_num == token num for Transformer
-        x_shape = x.shape
-        # Output shape, e.g., [batch, input_num, out_len]
-        new_shape = x_shape[:-1] + (self.out_len,)
+        # [batch, seq_len, in_len] -> [batch, seq_len, out_len]
+        new_shape = x.shape[:-1] + (self.out_len,)
+
         # Activation is empty tensor, e.g., [1, 0, 4096]
-        if 0 in x_shape:
+        if 0 in x.shape:
             # return empty tensor with output shape, x.dtype and x.device
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 
-        x_2d = x.contiguous().view(-1, x_shape[-1])
-
         if self.act_order:
-            x_2d = x_2d[:, self.g_idx_map]
-        # x0 for weight
-        x0 = self.weight.data
+            x = x[..., self.g_idx_map]
 
-        if x0.device.type == "xpu":
-            # GPU logic
-            try:
-                import xe_linear
-                from ipex_llm.transformers.models.utils import use_xmx
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe` first.")
+        x_2d = x.contiguous().view(-1, x.shape[-1])
 
-            if x_2d.is_contiguous() is False:
-                x_2d = x_2d.contiguous()
-
-            if len(x_shape) == 3:
-                input_seq_size = x_shape[1]
-            elif len(x_shape) < 3:
-                input_seq_size = 1
-
-            if is_training:
-                # training path
-                if x_2d.requires_grad:
-                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
-                else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype,
-                                                       input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data,
-                                                       self.weight.qtype,
-                                                       input_seq_size)
+        if self.weight.device.type == "xpu":
+            if is_training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
-                # inference path
-                # current workaround to reduce first token latency of fp32 input
-                # sometimes fp16 cause nan and training instability
-                # disable the conversion when training
-                # TODO: may modify the input length condition for empty cache.
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
+                if self.qtype == NF4:
+                    w = self.weight.data.view(torch.uint8)
+                else:
+                    w = self.weight.data
+
                 if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                     import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
-                elif (
-                    self.conver_to_half
-                    and x_2d.shape[0] > 1
-                    and x_2d.dtype == torch.float32
-                    and not use_xmx(x_2d, self.weight.qtype)
-                ):
+                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
+                elif not is_training and self.conver_to_half \
+                        and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
+                    import xe_linear
                     x_2d = x_2d.half()
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype, input_seq_size)
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                     result = result.to(x.dtype)
                 else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype, input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                       self.weight.qtype, input_seq_size)
+                    import xe_linear
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
+            result = result.view(new_shape)
+
             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
@@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
                     dist.inference_all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
+
             if self.bias is not None:
                 result += self.bias
         else:
@@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
                 result = MatMulLowBitCPU.apply(x, self.weight)
             else:
                 from ipex_llm.utils.isa_checker import is_server, is_spr
-
+                x0 = self.weight.data
                 # convert if necessary, and compute a linear result
                 if is_server() and (not is_spr()) and \
                         self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 9cde0af4..0dbea2e9 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
     return True
 
 
-def use_xmx(x: torch.Tensor, qtype: int):
-    device = get_xpu_device_name(x.device)
-    return (
-        device in ["arc", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
-        and (
-            (device == "pvc" and 1 < x.size(0) <= 16)
-            or
-            (device != "pvc" and 1 < x.size(0) <= 64)
-        )
-    )
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
diff --git a/python/llm/src/ipex_llm/transformers/xpu_ops.py b/python/llm/src/ipex_llm/transformers/xpu_ops.py
index 9b9c740e..6369c594 100644
--- a/python/llm/src/ipex_llm/transformers/xpu_ops.py
+++ b/python/llm/src/ipex_llm/transformers/xpu_ops.py
@@ -20,9 +20,10 @@ import xe_batch
 import xe_addons
 
 
-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, output_size):
+    return torch.empty([x.size(0), output_size],
+                       dtype=x.dtype, device=x.device)
 
 
 # @torch.library.register_fake("ipex_llm::dequant")
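
Note on the NF4 handling in `LowBitLinear.forward`: the refactor hoists the NF4 special case into a single `w = self.weight.data.view(torch.uint8)` before dispatch. `Tensor.view(dtype)` reinterprets the existing storage under another element type without copying, so the packed 4-bit payload reaches the kernel as raw bytes. A minimal sketch of that semantics (the int8 source tensor is illustrative only, not the actual NF4 storage dtype):

```python
import torch

packed = torch.tensor([-1, 0, 127], dtype=torch.int8)  # pretend-packed weight bytes
raw = packed.view(torch.uint8)  # reinterpret the same buffer, no copy

print(raw)                                   # tensor([255,   0, 127], dtype=torch.uint8)
print(raw.data_ptr() == packed.data_ptr())   # True: both tensors share one storage
```

Because int8 and uint8 have the same element size, the shape is unchanged; only the interpretation of the bytes differs.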
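Note on the `xpu_ops.py` change: `torch.library.register_fake` supplies a shape/dtype-only ("meta") implementation of a custom op, which `torch.compile` and FakeTensor tracing use to propagate shapes without running the real kernel. This is presumably why the trailing kernel argument changes from `input_seq_size` to `output_size` throughout the patch: the fake kernel must derive the output shape from the op's arguments alone. A minimal, self-contained sketch of the same pattern, using a hypothetical op `mylib::low_bit_linear` (the op name, signature, and eager body are illustrative assumptions, not the actual `ipex_llm::forward_new` binding; requires PyTorch >= 2.4):

```python
import torch

# Hypothetical stand-in for a fused low-bit linear kernel: a 2D activation
# [rows, in_len] times a packed weight, producing [rows, output_size].
@torch.library.custom_op("mylib::low_bit_linear", mutates_args=())
def low_bit_linear(x: torch.Tensor, packed_weight: torch.Tensor,
                   qtype: int, output_size: int) -> torch.Tensor:
    # Eager placeholder: a real build would dispatch to the fused kernel here.
    return torch.zeros(x.size(0), output_size, dtype=x.dtype, device=x.device)

@torch.library.register_fake("mylib::low_bit_linear")
def _(x, packed_weight, qtype, output_size):
    # The fake kernel computes nothing; it only reports output metadata
    # (shape, dtype, device), which is all that tracing needs.
    return torch.empty(x.size(0), output_size, dtype=x.dtype, device=x.device)

x = torch.randn(4, 16)
w = torch.zeros(64, dtype=torch.uint8)  # pretend-packed weight buffer
fn = torch.compile(lambda a: torch.ops.mylib.low_bit_linear(a, w, 2, 8))
print(fn(x).shape)  # torch.Size([4, 8]) -- shape inferred via the fake kernel
```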