LLM: better FP16 support for Intel GPUs (#9791)
* initial support * fix * fix style * fix * limi esimd usage condition * refactor code * fix style * small fix * meet code review * small fix
This commit is contained in:
parent
7d9f6c6efc
commit
99bddd3ab4
4 changed files with 153 additions and 75 deletions
|
|
@ -200,8 +200,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
bias=has_bias,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
device_type = module.qweight.data.device.type
|
||||
invalidInputError(device_type != "meta",
|
||||
device = module.qweight.data.device
|
||||
invalidInputError(device.type != "meta",
|
||||
"converting from meta device is not supported")
|
||||
# Copy the weights
|
||||
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
|
||||
|
|
@ -209,11 +209,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
quantized=True,
|
||||
_shape=(out_features, in_features),
|
||||
convert_shape_only=convert_shape_only,
|
||||
qtype=qtype).to(device_type)
|
||||
qtype=qtype).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if has_bias:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
.to(device_type)
|
||||
.to(device)
|
||||
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
|
||||
new_linear = LowBitLinear(
|
||||
in_features,
|
||||
|
|
@ -223,44 +223,39 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
mp_group=mp_group,
|
||||
)
|
||||
|
||||
device_type = module.weight.data.device.type
|
||||
device = module.weight.data.device
|
||||
# Copy the weights
|
||||
paramsLowBit = FP4Params(data=module.weight.data,
|
||||
requires_grad=False,
|
||||
quantized=False,
|
||||
_shape=None,
|
||||
convert_shape_only=convert_shape_only,
|
||||
qtype=qtype).to(device_type)
|
||||
qtype=qtype).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if module.bias is not None:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
.to(device_type)
|
||||
.to(device)
|
||||
elif qtype == ggml_tensor_qtype["fp16"]:
|
||||
# only support two size now
|
||||
# may generalize to other sizes
|
||||
if module.in_features in [4096, 11008]:
|
||||
# esimd fp16 path
|
||||
new_linear = FP16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
device_type = module.weight.data.device.type
|
||||
|
||||
# convert here
|
||||
m, n = module.weight.data.shape
|
||||
if module.in_features == 11008:
|
||||
trans_weight = module.weight.data.reshape(m//8, 8, n)
|
||||
trans_weight = trans_weight.transpose(1, 2).contiguous()
|
||||
elif module.in_features == 4096:
|
||||
trans_weight = module.weight.data.reshape(m//16, 16, n)
|
||||
trans_weight = trans_weight.transpose(1, 2).contiguous()
|
||||
new_linear._parameters['weight'] = nn.Parameter(trans_weight)
|
||||
if module.bias is not None:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
.to(device_type)
|
||||
module.to(torch.float16)
|
||||
new_linear = FP16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
device = module.weight.data.device
|
||||
from bigdl.llm.transformers.utils import get_ipex_version
|
||||
if get_ipex_version() < "2.1.10+xpu":
|
||||
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
||||
else:
|
||||
# only from 2.1, ipex provides matmul_bias_out
|
||||
# so we need to transpose weight
|
||||
new_weight = module.weight.transpose(0, 1).contiguous()
|
||||
new_linear._parameters['weight'] = nn.Parameter(new_weight)
|
||||
new_linear.weight_type = 2
|
||||
if module.bias is not None:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
.to(device)
|
||||
elif qtype == ggml_tensor_qtype["bf16"]:
|
||||
module.to(torch.bfloat16)
|
||||
new_linear = BF16Linear(
|
||||
|
|
@ -269,12 +264,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
device_type = module.weight.data.device.type
|
||||
device = module.weight.data.device
|
||||
# convert here
|
||||
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
||||
if module.bias is not None:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
.to(device_type)
|
||||
.to(device)
|
||||
|
||||
if new_linear is not None:
|
||||
if not module.training:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ from torch import Tensor, device, dtype, nn
|
|||
from operator import mul
|
||||
from functools import reduce
|
||||
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
|
||||
from bigdl.llm.transformers.utils import get_autocast_dtype
|
||||
from bigdl.llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
|
||||
get_ipex_version
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
|
@ -538,57 +539,111 @@ class LowBitLinear(nn.Linear):
|
|||
|
||||
|
||||
class FP16Linear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, qtype, bias=True,
|
||||
conver_to_half=True, mp_group=None):
|
||||
def __init__(self, input_features, output_features, bias=True,
|
||||
mp_group=None, weight_type=1):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
self.weight_shape = (self.out_len, self.in_len)
|
||||
self.weight_length = self.out_len * self.in_len
|
||||
self.qtype = qtype
|
||||
self.conver_to_half = conver_to_half
|
||||
self.qtype = ggml_tensor_qtype["fp16"]
|
||||
self.mp_group = mp_group
|
||||
# weigh_type = 1 means original weight
|
||||
# weigh_type = 2 means weight has been transposed
|
||||
# weigh_type = 3 means weight has been transposed by esimd method
|
||||
self.weight_type = 1
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
x_shape = x.shape
|
||||
x_2d = x.view(-1, x_shape[-1])
|
||||
|
||||
x0 = self.weight.data
|
||||
# only work for GPU
|
||||
invalidInputError(x0.device.type == "xpu",
|
||||
"FP16 only works for GPU")
|
||||
try:
|
||||
import intel_extension_for_pytorch
|
||||
import linear_fp16_esimd
|
||||
except ModuleNotFoundError:
|
||||
invalidInputError(False,
|
||||
"Please `pip install bigdl_core_xe` first.")
|
||||
invalidInputError(x.device.type == "xpu",
|
||||
"FP16Linear only works for Intel GPUs")
|
||||
x = x.to(torch.float16)
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
if self.weight is not None and self.weight.dtype != x.dtype:
|
||||
self.weight.data = self.weight.data.to(x.dtype)
|
||||
|
||||
if x_2d.is_contiguous() is False:
|
||||
x_2d = x_2d.contiguous()
|
||||
|
||||
if x_2d.shape[0] > 1:
|
||||
# first token or batch size > 1, re-convert weight
|
||||
original_weight = self.weight.data.transpose(1, 2)
|
||||
original_weight = original_weight.reshape(self.out_len, self.in_len)
|
||||
result = F.linear(x_2d, original_weight.contiguous())
|
||||
del original_weight
|
||||
if not self.use_esimd_kernel(x):
|
||||
if get_ipex_version() < "2.1.10+xpu":
|
||||
if self.weight_type == 2:
|
||||
self.weight = self.weight.transpose(0, 1).contiguous()
|
||||
self.weight_type = 1
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
else:
|
||||
if self.weight_type == 1:
|
||||
self.weight = self.weight.transpose(0, 1).contiguous()
|
||||
self.weight_type = 2
|
||||
return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
|
||||
else:
|
||||
# rest token, use esimd optimization
|
||||
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
|
||||
if self.weight_type != 3:
|
||||
# convert weight first to use esimd fp16 kernel
|
||||
self.convert_weight_for_esimd_kernel()
|
||||
# esimd fp16 kernel for inference
|
||||
x_shape = x.shape
|
||||
x_2d = x.view(-1, x_shape[-1])
|
||||
if x_2d.is_contiguous() is False:
|
||||
x_2d = x_2d.contiguous()
|
||||
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
result = result.view(new_shape)
|
||||
if self.mp_group is not None:
|
||||
from deepspeed import comm as dist
|
||||
dist.inference_all_reduce(result, group=self.mp_group)
|
||||
if self.bias is not None:
|
||||
result += self.bias
|
||||
try:
|
||||
import intel_extension_for_pytorch
|
||||
import linear_fp16_esimd
|
||||
except ModuleNotFoundError:
|
||||
invalidInputError(False,
|
||||
"Please `pip install bigdl_core_xe_esimd` first.")
|
||||
|
||||
return result.to(x.dtype)
|
||||
if x_2d.shape[0] > 1:
|
||||
# first token or batch size > 1, re-convert weight
|
||||
original_weight = self.weight.data.transpose(1, 2)
|
||||
original_weight = original_weight.reshape(self.out_len, self.in_len)
|
||||
result = F.linear(x_2d, original_weight.contiguous())
|
||||
del original_weight
|
||||
else:
|
||||
# rest token, use esimd optimization
|
||||
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
|
||||
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
result = result.view(new_shape)
|
||||
if self.mp_group is not None:
|
||||
from deepspeed import comm as dist
|
||||
dist.inference_all_reduce(result, group=self.mp_group)
|
||||
if self.bias is not None:
|
||||
result += self.bias
|
||||
|
||||
return result.to(x.dtype)
|
||||
|
||||
def use_esimd_kernel(self, x):
|
||||
gpu_type = get_xpu_device_type(x)
|
||||
# esimd kernel can only be used for Arc and Flex
|
||||
if gpu_type not in ["arc", "flex"]:
|
||||
return False
|
||||
# now esimd kernel can only be used for specific cases (llama2-7b shape)
|
||||
if self.in_len == 11008 and self.out_features == 4096:
|
||||
return True
|
||||
if self.in_len == 4096 and self.out_features in [4096, 11008]:
|
||||
# seems has some issue with Mistral,
|
||||
# need a further look to check whether can be used for other out features
|
||||
return True
|
||||
return False
|
||||
|
||||
def convert_weight_for_esimd_kernel(self):
|
||||
m, n = self.out_len, self.in_len
|
||||
if self.in_len == 11008:
|
||||
if self.weight_type == 2:
|
||||
trans_weight = self.weight.data.transpose(0, 1)
|
||||
else:
|
||||
trans_weight = self.weight.data
|
||||
trans_weight = trans_weight.data.reshape(m//8, 8, n)
|
||||
trans_weight = trans_weight.transpose(1, 2).contiguous()
|
||||
self.weight.data = trans_weight
|
||||
elif self.in_len == 4096:
|
||||
if self.weight_type == 2:
|
||||
trans_weight = self.weight.data.transpose(0, 1)
|
||||
else:
|
||||
trans_weight = self.weight.data
|
||||
trans_weight = trans_weight.data.reshape(m//16, 16, n)
|
||||
trans_weight = trans_weight.transpose(1, 2).contiguous()
|
||||
self.weight.data = trans_weight
|
||||
self.weight_type = 3
|
||||
|
||||
|
||||
class BF16Linear(nn.Linear):
|
||||
|
|
|
|||
|
|
@ -100,8 +100,9 @@ def llama_mlp_forward(
|
|||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
qtype = getattr(self.gate_proj, "qtype", None)
|
||||
if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
|
||||
and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
|
||||
and qtype == ggml_tensor_qtype["sym_int4"] \
|
||||
and not (self.training and x.requires_grad):
|
||||
import linear_q4_0
|
||||
if not x_2d.is_contiguous():
|
||||
|
|
@ -147,7 +148,8 @@ def llama_attention_forward_4_31(
|
|||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
||||
is_q4_0 = self.q_proj.qtype == SYM_INT4
|
||||
qtype = getattr(self.q_proj, "qtype", None)
|
||||
is_q4_0 = qtype == SYM_INT4
|
||||
no_tp = not self.config.pretraining_tp > 1
|
||||
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
||||
enough_kv_room and bsz * q_len == 1)
|
||||
|
|
|
|||
|
|
@ -149,3 +149,29 @@ def get_autocast_dtype(x):
|
|||
else:
|
||||
invalidInputError(False,
|
||||
f"Device {x.device} is not supported.")
|
||||
|
||||
|
||||
_ipex_version = None
|
||||
|
||||
|
||||
def get_ipex_version():
|
||||
|
||||
global _ipex_version
|
||||
if _ipex_version is not None:
|
||||
return _ipex_version
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
_ipex_version = ipex.__version__
|
||||
return _ipex_version
|
||||
|
||||
|
||||
def get_xpu_device_type(x):
|
||||
name = torch.xpu.get_device_name(x.device.index)
|
||||
if name.startswith("Intel(R) Arc(TM) A"):
|
||||
return "arc"
|
||||
elif name.startswith("Intel(R) Data Center GPU Flex"):
|
||||
return "flex"
|
||||
elif name.startswith("Intel(R) Data Center GPU Max"):
|
||||
return "pvc"
|
||||
else:
|
||||
return "others"
|
||||
|
|
|
|||
Loading…
Reference in a new issue