From 99bddd3ab483cadfd85b3be7142857cf0eeeb29b Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Thu, 28 Dec 2023 13:30:13 +0800 Subject: [PATCH] LLM: better FP16 support for Intel GPUs (#9791) * initial support * fix * fix style * fix * limi esimd usage condition * refactor code * fix style * small fix * meet code review * small fix --- .../llm/src/bigdl/llm/transformers/convert.py | 63 ++++----- .../bigdl/llm/transformers/low_bit_linear.py | 133 +++++++++++++----- .../bigdl/llm/transformers/models/llama.py | 6 +- .../llm/src/bigdl/llm/transformers/utils.py | 26 ++++ 4 files changed, 153 insertions(+), 75 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index aed54457..a270905b 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -200,8 +200,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, bias=has_bias, mp_group=mp_group, ) - device_type = module.qweight.data.device.type - invalidInputError(device_type != "meta", + device = module.qweight.data.device + invalidInputError(device.type != "meta", "converting from meta device is not supported") # Copy the weights paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq), @@ -209,11 +209,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, quantized=True, _shape=(out_features, in_features), convert_shape_only=convert_shape_only, - qtype=qtype).to(device_type) + qtype=qtype).to(device) new_linear._parameters['weight'] = paramsLowBit if has_bias: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ - .to(device_type) + .to(device) elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]: new_linear = LowBitLinear( in_features, @@ -223,44 +223,39 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, mp_group=mp_group, ) - device_type = module.weight.data.device.type + device = module.weight.data.device # Copy the weights paramsLowBit = FP4Params(data=module.weight.data, requires_grad=False, quantized=False, _shape=None, convert_shape_only=convert_shape_only, - qtype=qtype).to(device_type) + qtype=qtype).to(device) new_linear._parameters['weight'] = paramsLowBit if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ - .to(device_type) + .to(device) elif qtype == ggml_tensor_qtype["fp16"]: - # only support two size now - # may generalize to other sizes - if module.in_features in [4096, 11008]: - # esimd fp16 path - new_linear = FP16Linear( - in_features, - out_features, - qtype, - module.bias is not None, - mp_group=mp_group, - ) - device_type = module.weight.data.device.type - - # convert here - m, n = module.weight.data.shape - if module.in_features == 11008: - trans_weight = module.weight.data.reshape(m//8, 8, n) - trans_weight = trans_weight.transpose(1, 2).contiguous() - elif module.in_features == 4096: - trans_weight = module.weight.data.reshape(m//16, 16, n) - trans_weight = trans_weight.transpose(1, 2).contiguous() - new_linear._parameters['weight'] = nn.Parameter(trans_weight) - if module.bias is not None: - new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ - .to(device_type) + module.to(torch.float16) + new_linear = FP16Linear( + in_features, + out_features, + module.bias is not None, + mp_group=mp_group, + ) + device = module.weight.data.device + from bigdl.llm.transformers.utils import get_ipex_version + if get_ipex_version() < 
"2.1.10+xpu": + new_linear._parameters['weight'] = nn.Parameter(module.weight) + else: + # only from 2.1, ipex provides matmul_bias_out + # so we need to transpose weight + new_weight = module.weight.transpose(0, 1).contiguous() + new_linear._parameters['weight'] = nn.Parameter(new_weight) + new_linear.weight_type = 2 + if module.bias is not None: + new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ + .to(device) elif qtype == ggml_tensor_qtype["bf16"]: module.to(torch.bfloat16) new_linear = BF16Linear( @@ -269,12 +264,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, module.bias is not None, mp_group=mp_group, ) - device_type = module.weight.data.device.type + device = module.weight.data.device # convert here new_linear._parameters['weight'] = nn.Parameter(module.weight) if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ - .to(device_type) + .to(device) if new_linear is not None: if not module.training: diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 1aaf9e61..c63fd1e9 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -50,7 +50,8 @@ from torch import Tensor, device, dtype, nn from operator import mul from functools import reduce from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd -from bigdl.llm.transformers.utils import get_autocast_dtype +from bigdl.llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \ + get_ipex_version T = TypeVar("T", bound="torch.nn.Module") @@ -538,57 +539,111 @@ class LowBitLinear(nn.Linear): class FP16Linear(nn.Linear): - def __init__(self, input_features, output_features, qtype, bias=True, - conver_to_half=True, mp_group=None): + def __init__(self, input_features, output_features, bias=True, + mp_group=None, weight_type=1): super().__init__(input_features, output_features, bias) self.in_len = input_features self.out_len = output_features self.weight_shape = (self.out_len, self.in_len) self.weight_length = self.out_len * self.in_len - self.qtype = qtype - self.conver_to_half = conver_to_half + self.qtype = ggml_tensor_qtype["fp16"] self.mp_group = mp_group + # weigh_type = 1 means original weight + # weigh_type = 2 means weight has been transposed + # weigh_type = 3 means weight has been transposed by esimd method + self.weight_type = 1 def forward(self, x: torch.Tensor): - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - x_shape = x.shape - x_2d = x.view(-1, x_shape[-1]) - - x0 = self.weight.data # only work for GPU - invalidInputError(x0.device.type == "xpu", - "FP16 only works for GPU") - try: - import intel_extension_for_pytorch - import linear_fp16_esimd - except ModuleNotFoundError: - invalidInputError(False, - "Please `pip install bigdl_core_xe` first.") + invalidInputError(x.device.type == "xpu", + "FP16Linear only works for Intel GPUs") + x = x.to(torch.float16) + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + if self.weight is not None and self.weight.dtype != x.dtype: + self.weight.data = self.weight.data.to(x.dtype) - if x_2d.is_contiguous() is False: - x_2d = x_2d.contiguous() - - if x_2d.shape[0] > 1: - # first token or batch size > 1, re-convert weight - original_weight = self.weight.data.transpose(1, 2) - original_weight = 
original_weight.reshape(self.out_len, self.in_len) - result = F.linear(x_2d, original_weight.contiguous()) - del original_weight + if not self.use_esimd_kernel(x): + if get_ipex_version() < "2.1.10+xpu": + if self.weight_type == 2: + self.weight = self.weight.transpose(0, 1).contiguous() + self.weight_type = 1 + return F.linear(x, self.weight, self.bias) + else: + if self.weight_type == 1: + self.weight = self.weight.transpose(0, 1).contiguous() + self.weight_type = 2 + return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias) else: - # rest token, use esimd optimization - result = linear_fp16_esimd.forward(x_2d, self.weight.data) + if self.weight_type != 3: + # convert weight first to use esimd fp16 kernel + self.convert_weight_for_esimd_kernel() + # esimd fp16 kernel for inference + x_shape = x.shape + x_2d = x.view(-1, x_shape[-1]) + if x_2d.is_contiguous() is False: + x_2d = x_2d.contiguous() - new_shape = x_shape[:-1] + (self.out_len,) - result = result.view(new_shape) - if self.mp_group is not None: - from deepspeed import comm as dist - dist.inference_all_reduce(result, group=self.mp_group) - if self.bias is not None: - result += self.bias + try: + import intel_extension_for_pytorch + import linear_fp16_esimd + except ModuleNotFoundError: + invalidInputError(False, + "Please `pip install bigdl_core_xe_esimd` first.") - return result.to(x.dtype) + if x_2d.shape[0] > 1: + # first token or batch size > 1, re-convert weight + original_weight = self.weight.data.transpose(1, 2) + original_weight = original_weight.reshape(self.out_len, self.in_len) + result = F.linear(x_2d, original_weight.contiguous()) + del original_weight + else: + # rest token, use esimd optimization + result = linear_fp16_esimd.forward(x_2d, self.weight.data) + + new_shape = x_shape[:-1] + (self.out_len,) + result = result.view(new_shape) + if self.mp_group is not None: + from deepspeed import comm as dist + dist.inference_all_reduce(result, group=self.mp_group) + if self.bias is not None: + result += self.bias + + return result.to(x.dtype) + + def use_esimd_kernel(self, x): + gpu_type = get_xpu_device_type(x) + # esimd kernel can only be used for Arc and Flex + if gpu_type not in ["arc", "flex"]: + return False + # now esimd kernel can only be used for specific cases (llama2-7b shape) + if self.in_len == 11008 and self.out_features == 4096: + return True + if self.in_len == 4096 and self.out_features in [4096, 11008]: + # seems has some issue with Mistral, + # need a further look to check whether can be used for other out features + return True + return False + + def convert_weight_for_esimd_kernel(self): + m, n = self.out_len, self.in_len + if self.in_len == 11008: + if self.weight_type == 2: + trans_weight = self.weight.data.transpose(0, 1) + else: + trans_weight = self.weight.data + trans_weight = trans_weight.data.reshape(m//8, 8, n) + trans_weight = trans_weight.transpose(1, 2).contiguous() + self.weight.data = trans_weight + elif self.in_len == 4096: + if self.weight_type == 2: + trans_weight = self.weight.data.transpose(0, 1) + else: + trans_weight = self.weight.data + trans_weight = trans_weight.data.reshape(m//16, 16, n) + trans_weight = trans_weight.transpose(1, 2).contiguous() + self.weight.data = trans_weight + self.weight_type = 3 class BF16Linear(nn.Linear): diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 538cf1b0..2300ea18 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ 
b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -100,8 +100,9 @@ def llama_mlp_forward( x: torch.Tensor, ) -> torch.Tensor: x_2d = x.view(-1, x.shape[-1]) + qtype = getattr(self.gate_proj, "qtype", None) if x_2d.shape[0] == 1 and x.device.type == 'xpu' \ - and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \ + and qtype == ggml_tensor_qtype["sym_int4"] \ and not (self.training and x.requires_grad): import linear_q4_0 if not x_2d.is_contiguous(): @@ -147,7 +148,8 @@ def llama_attention_forward_4_31( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) - is_q4_0 = self.q_proj.qtype == SYM_INT4 + qtype = getattr(self.q_proj, "qtype", None) + is_q4_0 = qtype == SYM_INT4 no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and enough_kv_room and bsz * q_len == 1) diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py index e50f196f..4aabb485 100644 --- a/python/llm/src/bigdl/llm/transformers/utils.py +++ b/python/llm/src/bigdl/llm/transformers/utils.py @@ -149,3 +149,29 @@ def get_autocast_dtype(x): else: invalidInputError(False, f"Device {x.device} is not supported.") + + +_ipex_version = None + + +def get_ipex_version(): + + global _ipex_version + if _ipex_version is not None: + return _ipex_version + + import intel_extension_for_pytorch as ipex + _ipex_version = ipex.__version__ + return _ipex_version + + +def get_xpu_device_type(x): + name = torch.xpu.get_device_name(x.device.index) + if name.startswith("Intel(R) Arc(TM) A"): + return "arc" + elif name.startswith("Intel(R) Data Center GPU Flex"): + return "flex" + elif name.startswith("Intel(R) Data Center GPU Max"): + return "pvc" + else: + return "others"
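
For reference, a minimal standalone sketch (not part of the patch) of the IPEX-version gate that the new fp16 path applies in convert.py and FP16Linear.forward: on IPEX older than 2.1.10+xpu the weight keeps the usual (out_features, in_features) layout and goes through F.linear; from 2.1 on, the weight is stored transposed (weight_type = 2) so torch.ops.torch_ipex.matmul_bias_out can be used. The helper name fp16_linear_forward is illustrative only, and unlike FP16Linear it does not cache the transposed weight between calls.

import torch
import torch.nn.functional as F

def fp16_linear_forward(x, weight, bias, ipex_version):
    # x: (..., in_features); weight: (out_features, in_features) as stored by nn.Linear
    x = x.to(torch.float16)
    if ipex_version < "2.1.10+xpu":
        # Older IPEX: plain linear kernel, weight stays in its original layout
        # (weight_type == 1 in the patch).
        return F.linear(x, weight, bias)
    # IPEX >= 2.1 provides matmul_bias_out, which takes the weight pre-transposed
    # to (in_features, out_features) (weight_type == 2 in the patch).
    w_t = weight.transpose(0, 1).contiguous()
    return torch.ops.torch_ipex.matmul_bias_out(x, w_t, bias)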
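
Also for reference, a sketch (the helper name esimd_weight_layout is hypothetical) of the weight layout that convert_weight_for_esimd_kernel produces for the esimd decode kernel. That kernel is only selected on Arc/Flex GPUs and only for the llama2-7b projection shapes the patch whitelists (in_len 11008 with out 4096, and in_len 4096 with out 4096 or 11008); the 8/16 block factors below come straight from the patch.

import torch

def esimd_weight_layout(weight: torch.Tensor) -> torch.Tensor:
    # weight: (out_len, in_len), float16, already on the XPU device.
    # Only the 11008/4096 in_len values whitelisted by use_esimd_kernel reach here.
    out_len, in_len = weight.shape
    block = 8 if in_len == 11008 else 16   # 16 for the in_len == 4096 case
    w = weight.reshape(out_len // block, block, in_len)
    # (out_len // block, in_len, block): the layout linear_fp16_esimd.forward consumes
    return w.transpose(1, 2).contiguous()

At inference time this path is only taken for single-token decode (x_2d.shape[0] == 1); for the first token or batch size > 1, FP16Linear.forward undoes the layout and falls back to F.linear, as shown in the patch.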