LLM: better FP16 support for Intel GPUs (#9791)

* initial support

* fix

* fix style

* fix

* limit esimd usage condition

* refactor code

* fix style

* small fix

* meet code review

* small fix
Ruonan Wang 2023-12-28 13:30:13 +08:00 committed by GitHub
parent 7d9f6c6efc
commit 99bddd3ab4
4 changed files with 153 additions and 75 deletions


@@ -200,8 +200,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
bias=has_bias,
mp_group=mp_group,
)
device_type = module.qweight.data.device.type
invalidInputError(device_type != "meta",
device = module.qweight.data.device
invalidInputError(device.type != "meta",
"converting from meta device is not supported")
# Copy the weights
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
@@ -209,11 +209,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
quantized=True,
_shape=(out_features, in_features),
convert_shape_only=convert_shape_only,
qtype=qtype).to(device_type)
qtype=qtype).to(device)
new_linear._parameters['weight'] = paramsLowBit
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device_type)
.to(device)
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
new_linear = LowBitLinear(
in_features,
@@ -223,44 +223,39 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
mp_group=mp_group,
)
device_type = module.weight.data.device.type
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=qtype).to(device_type)
qtype=qtype).to(device)
new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device_type)
.to(device)
elif qtype == ggml_tensor_qtype["fp16"]:
# only support two size now
# may generalize to other sizes
if module.in_features in [4096, 11008]:
# esimd fp16 path
new_linear = FP16Linear(
in_features,
out_features,
qtype,
module.bias is not None,
mp_group=mp_group,
)
device_type = module.weight.data.device.type
# convert here
m, n = module.weight.data.shape
if module.in_features == 11008:
trans_weight = module.weight.data.reshape(m//8, 8, n)
trans_weight = trans_weight.transpose(1, 2).contiguous()
elif module.in_features == 4096:
trans_weight = module.weight.data.reshape(m//16, 16, n)
trans_weight = trans_weight.transpose(1, 2).contiguous()
new_linear._parameters['weight'] = nn.Parameter(trans_weight)
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device_type)
module.to(torch.float16)
new_linear = FP16Linear(
in_features,
out_features,
module.bias is not None,
mp_group=mp_group,
)
device = module.weight.data.device
from bigdl.llm.transformers.utils import get_ipex_version
if get_ipex_version() < "2.1.10+xpu":
new_linear._parameters['weight'] = nn.Parameter(module.weight)
else:
# only from 2.1 does ipex provide matmul_bias_out,
# so we need to transpose the weight
new_weight = module.weight.transpose(0, 1).contiguous()
new_linear._parameters['weight'] = nn.Parameter(new_weight)
new_linear.weight_type = 2
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif qtype == ggml_tensor_qtype["bf16"]:
module.to(torch.bfloat16)
new_linear = BF16Linear(
@@ -269,12 +264,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
mp_group=mp_group,
)
device_type = module.weight.data.device.type
device = module.weight.data.device
# convert here
new_linear._parameters['weight'] = nn.Parameter(module.weight)
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device_type)
.to(device)
if new_linear is not None:
if not module.training:
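
Two things change in this file: device handling now keeps the full torch.device (including its index) instead of only the type string, and the fp16 branch decides the weight layout up front based on the installed ipex version. Below is a minimal, hedged sketch of the device point only; the helper names are illustrative and not part of the patch:

import torch

def to_same_device(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Keeps the device index: a parameter living on "xpu:1" stays on "xpu:1".
    return dst.to(src.device)

def to_same_device_type(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Keeps only the device *type*: "xpu:1" collapses to "xpu", i.e. the
    # current default card, which is wrong on multi-GPU machines.
    return dst.to(src.device.type)

This is why the hunks above consistently replace `.to(device_type)` with `.to(device)`.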


@@ -50,7 +50,8 @@ from torch import Tensor, device, dtype, nn
from operator import mul
from functools import reduce
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
from bigdl.llm.transformers.utils import get_autocast_dtype
from bigdl.llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
get_ipex_version
T = TypeVar("T", bound="torch.nn.Module")
@@ -538,57 +539,111 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None):
def __init__(self, input_features, output_features, bias=True,
mp_group=None, weight_type=1):
super().__init__(input_features, output_features, bias)
self.in_len = input_features
self.out_len = output_features
self.weight_shape = (self.out_len, self.in_len)
self.weight_length = self.out_len * self.in_len
self.qtype = qtype
self.conver_to_half = conver_to_half
self.qtype = ggml_tensor_qtype["fp16"]
self.mp_group = mp_group
# weight_type = 1 means original weight
# weight_type = 2 means weight has been transposed
# weight_type = 3 means weight has been transposed by esimd method
self.weight_type = 1
def forward(self, x: torch.Tensor):
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
x_shape = x.shape
x_2d = x.view(-1, x_shape[-1])
x0 = self.weight.data
# only work for GPU
invalidInputError(x0.device.type == "xpu",
"FP16 only works for GPU")
try:
import intel_extension_for_pytorch
import linear_fp16_esimd
except ModuleNotFoundError:
invalidInputError(False,
"Please `pip install bigdl_core_xe` first.")
invalidInputError(x.device.type == "xpu",
"FP16Linear only works for Intel GPUs")
x = x.to(torch.float16)
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if self.weight is not None and self.weight.dtype != x.dtype:
self.weight.data = self.weight.data.to(x.dtype)
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
if x_2d.shape[0] > 1:
# first token or batch size > 1, re-convert weight
original_weight = self.weight.data.transpose(1, 2)
original_weight = original_weight.reshape(self.out_len, self.in_len)
result = F.linear(x_2d, original_weight.contiguous())
del original_weight
if not self.use_esimd_kernel(x):
if get_ipex_version() < "2.1.10+xpu":
if self.weight_type == 2:
self.weight = self.weight.transpose(0, 1).contiguous()
self.weight_type = 1
return F.linear(x, self.weight, self.bias)
else:
if self.weight_type == 1:
self.weight = self.weight.transpose(0, 1).contiguous()
self.weight_type = 2
return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
else:
# rest token, use esimd optimization
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
if self.weight_type != 3:
# convert weight first to use esimd fp16 kernel
self.convert_weight_for_esimd_kernel()
# esimd fp16 kernel for inference
x_shape = x.shape
x_2d = x.view(-1, x_shape[-1])
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
try:
import intel_extension_for_pytorch
import linear_fp16_esimd
except ModuleNotFoundError:
invalidInputError(False,
"Please `pip install bigdl_core_xe_esimd` first.")
return result.to(x.dtype)
if x_2d.shape[0] > 1:
# first token or batch size > 1, re-convert weight
original_weight = self.weight.data.transpose(1, 2)
original_weight = original_weight.reshape(self.out_len, self.in_len)
result = F.linear(x_2d, original_weight.contiguous())
del original_weight
else:
# subsequent tokens, use the esimd optimization
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
return result.to(x.dtype)
def use_esimd_kernel(self, x):
gpu_type = get_xpu_device_type(x)
# esimd kernel can only be used for Arc and Flex
if gpu_type not in ["arc", "flex"]:
return False
# for now, the esimd kernel can only be used for specific shapes (llama2-7b)
if self.in_len == 11008 and self.out_features == 4096:
return True
if self.in_len == 4096 and self.out_features in [4096, 11008]:
# seems to have some issues with Mistral;
# needs a further look to check whether it can be used for other out_features
return True
return False
def convert_weight_for_esimd_kernel(self):
m, n = self.out_len, self.in_len
if self.in_len == 11008:
if self.weight_type == 2:
trans_weight = self.weight.data.transpose(0, 1)
else:
trans_weight = self.weight.data
trans_weight = trans_weight.data.reshape(m//8, 8, n)
trans_weight = trans_weight.transpose(1, 2).contiguous()
self.weight.data = trans_weight
elif self.in_len == 4096:
if self.weight_type == 2:
trans_weight = self.weight.data.transpose(0, 1)
else:
trans_weight = self.weight.data
trans_weight = trans_weight.data.reshape(m//16, 16, n)
trans_weight = trans_weight.transpose(1, 2).contiguous()
self.weight.data = trans_weight
self.weight_type = 3
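
To make the re-layout above easier to follow, here is a standalone sketch of the same reshape/transpose on plain CPU tensors (block size 8 when in_len is 11008, 16 when it is 4096, matching the two branches); the helper name is ours, not part of the patch:

import torch

def esimd_layout(weight: torch.Tensor, block: int) -> torch.Tensor:
    # weight has the usual nn.Linear shape (out_len, in_len).
    m, n = weight.shape
    # Split the output rows into groups of `block`, then swap the group-row
    # axis with the column axis so each group is stored transposed.
    w = weight.reshape(m // block, block, n)
    return w.transpose(1, 2).contiguous()     # shape: (m // block, n, block)

# llama2-7b down_proj: in_len 11008, out_len 4096 -> block size 8
w = torch.randn(4096, 11008, dtype=torch.float16)
print(esimd_layout(w, 8).shape)               # torch.Size([512, 11008, 8])

Once converted, weight_type becomes 3, and the prefill branch of forward undoes the transform with transpose(1, 2).reshape(out_len, in_len) before calling F.linear.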
class BF16Linear(nn.Linear):


@@ -100,8 +100,9 @@ def llama_mlp_forward(
x: torch.Tensor,
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None)
if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
and qtype == ggml_tensor_qtype["sym_int4"] \
and not (self.training and x.requires_grad):
import linear_q4_0
if not x_2d.is_contiguous():
@@ -147,7 +148,8 @@ def llama_attention_forward_4_31(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
is_q4_0 = self.q_proj.qtype == SYM_INT4
qtype = getattr(self.q_proj, "qtype", None)
is_q4_0 = qtype == SYM_INT4
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
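
The switch to `getattr(..., "qtype", None)` is defensive: a projection layer is not guaranteed to carry a `qtype` attribute (for example, a plain `nn.Linear` left unconverted), and a bare attribute read would raise `AttributeError`. A minimal sketch, with a stand-in value for `SYM_INT4`:

import torch.nn as nn

SYM_INT4 = 2   # stand-in; the real value comes from ggml_tensor_qtype["sym_int4"]

def is_q4_0(proj: nn.Module) -> bool:
    # getattr with a default never raises, even when qtype is missing.
    return getattr(proj, "qtype", None) == SYM_INT4

print(is_q4_0(nn.Linear(8, 8)))   # False instead of AttributeError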


@@ -149,3 +149,29 @@ def get_autocast_dtype(x):
else:
invalidInputError(False,
f"Device {x.device} is not supported.")
_ipex_version = None
def get_ipex_version():
global _ipex_version
if _ipex_version is not None:
return _ipex_version
import intel_extension_for_pytorch as ipex
_ipex_version = ipex.__version__
return _ipex_version
def get_xpu_device_type(x):
name = torch.xpu.get_device_name(x.device.index)
if name.startswith("Intel(R) Arc(TM) A"):
return "arc"
elif name.startswith("Intel(R) Data Center GPU Flex"):
return "flex"
elif name.startswith("Intel(R) Data Center GPU Max"):
return "pvc"
else:
return "others"