draft mmint4 (#10031)

change to llm.cpp

support transposed format

revert

implement qkv fuse

fix style

change to vertically pack

change to enable_xetla

fix mlp_fusion_check

remove comments

address comments

add some comments

fix style
Yang Wang 2024-02-27 14:55:16 -08:00 committed by GitHub
parent d85f7c78df
commit c581c6db30
8 changed files with 160 additions and 23 deletions

View file

@@ -192,7 +192,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
cpu_embedding=False, prefix_name='',
imatrix_data=None, embedding_qtype=None,
model_type=None, torch_dtype=torch.float32):
model_type=None, torch_dtype=torch.float32,
enable_xetla=False):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -223,6 +224,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=qtype,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
)
device = module.qweight.data.device
invalidInputError(device.type != "meta",
@@ -234,7 +236,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
quantized=True,
_shape=(out_features, in_features),
convert_shape_only=convert_shape_only,
qtype=qtype).to(device)
qtype=qtype,
enable_xetla=enable_xetla).to(device)
new_linear._parameters['weight'] = paramsLowBit
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -249,6 +252,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
)
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
@@ -263,7 +267,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
convert_shape_only=convert_shape_only,
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features).to(device)
in_features=in_features,
enable_xetla=enable_xetla).to(device)
new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -368,7 +373,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_type=model_type,
torch_dtype=torch_dtype
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@@ -571,7 +577,9 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None, cpu_embedding=False,
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None, embedding_qtype=None):
imatrix_data=None,
embedding_qtype=None,
enable_xetla=False):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
@@ -601,7 +609,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_type=model_type,
torch_dtype=torch_dtype
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
)
if not has_been_replaced:
warnings.warn(

View file

@@ -75,6 +75,56 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
Q2_K = ggml_tensor_qtype["q2_k"]
# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
#
# The returned weight is row major and packs two rows at a stride of 16//2.
# 16 is the tile_size_y used in mm_int4, so that we can do something like
# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
#
# A more complex packing strategy is to permute the weight so that the
# new_weight_tile is directly VNNI packed, but I did not find a significant
# performance improvement.
#
# Note that this format cannot be used directly by IPEX's mm_int4, which expects
# row-major data but packs two consecutive columns.
def q4_0_xpu_transpose(ggml_weight, weight_shape):
from bigdl.llm.transformers.low_bit_linear import get_block_size
Q4_0 = get_block_size("sym_int4")
n, k = weight_shape
ggml_weight_only = ggml_weight[:n*k//2]
ggml_scales = ggml_weight[n*k//2:]
qweight = ggml_weight_only.clone()
scales = ggml_scales.view(torch.float16).clone()
qweight_0 = qweight & 0x0F
qweight_1 = qweight >> 4
qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
qweight = torch.cat([qweight_0, qweight_1], dim=-1)
qweight = qweight.reshape(n, k//16, 2, 8)
qweight = qweight.bitwise_left_shift(
torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
qweight = qweight.reshape(n, k//2)
qweight = qweight.transpose(0, 1).contiguous()
scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
# 119 is the value of 0x77
zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
qweight_bytes = qweight.view(torch.uint8).view(-1)
scales_bytes = scales.view(torch.uint8).view(-1)
zeros_bytes = zeros.view(torch.uint8).view(-1)
weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
return weight
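
The comment above is the only description of the repacked layout, so here is a toy round-trip sketch of the nibble packing it describes, on synthetic data; pack_group16 and unpack_group16 are illustrative names, not functions in this patch:

import torch

# Within each group of 16 consecutive 4-bit values along k, value i sits in the
# low nibble of byte i and value i + 8 in the high nibble, so a 16-wide tile is
# recovered with concat(bytes & 0x0F, bytes >> 4), as noted above.
def pack_group16(vals):
    # vals: uint8 tensor of shape (..., 16) holding 4-bit values in [0, 15]
    low, high = vals[..., :8], vals[..., 8:]
    return low | (high << 4)                                 # shape (..., 8)

def unpack_group16(packed):
    return torch.cat([packed & 0x0F, packed >> 4], dim=-1)   # shape (..., 16)

group = torch.randint(0, 16, (4, 16), dtype=torch.uint8)
assert torch.equal(unpack_group16(pack_group16(group)), group)
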
def get_block_size(qtype: str):
return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
@@ -208,7 +258,8 @@ class FP4Params(torch.nn.Parameter):
convert_shape_only=False,
qtype=None,
imatrix=None,
in_features=None):
in_features=None,
enable_xetla=False,):
if data is None:
data = torch.empty(0)
@@ -220,6 +271,7 @@ class FP4Params(torch.nn.Parameter):
self.convert_shape_only = convert_shape_only
self.imatrix = imatrix
self.in_features = in_features
self.enable_xetla = enable_xetla
return self
def ggml_mse(self, w, ggml_qtype, device):
@@ -308,13 +360,20 @@ class FP4Params(torch.nn.Parameter):
self.data = ggml_q_format_convet_cpu2xpu(self.data,
reduce(mul, self._shape, 1),
self.qtype)
if self.enable_xetla:
self.data = q4_0_xpu_transpose(self.data, self._shape)
new_param = FP4Params(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
qtype=self.qtype,
enable_xetla=self.enable_xetla)
if self.enable_xetla:
device_type = get_xpu_device_type(new_param.data)
invalidInputError(device_type == "pvc",
f"xetla is only supported on PVC, but got {device_type}")
return new_param
elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
new_param = FP4Params(super().to(device=device,
@@ -323,7 +382,11 @@ class FP4Params(torch.nn.Parameter):
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
qtype=self.qtype,
enable_xetla=self.enable_xetla)
if self.enable_xetla:
invalidInputError(False,
"xetla is not supported on CPUs but got enable_xetla=True")
new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
reduce(mul, new_param._shape, 1),
new_param.qtype)
@@ -335,7 +398,8 @@ class FP4Params(torch.nn.Parameter):
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
qtype=self.qtype,
enable_xetla=self.enable_xetla)
return new_param
@@ -441,11 +505,12 @@ class MatMulLowBitCPU(torch.autograd.Function):
class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None):
conver_to_half=True, mp_group=None, enable_xetla=False):
super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data,
requires_grad=False,
quantized=False, _shape=None, qtype=qtype)
quantized=False, _shape=None, qtype=qtype,
enable_xetla=enable_xetla)
self.in_len = input_features
self.out_len = output_features
self.weight_shape = (self.out_len, self.in_len)
@@ -454,6 +519,7 @@ class LowBitLinear(nn.Linear):
self.conver_to_half = conver_to_half
self.mp_group = mp_group
self.compute_dtype = None # only for training
self.enable_xetla = enable_xetla
def forward(self, x: torch.Tensor):
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
@@ -510,6 +576,9 @@ class LowBitLinear(nn.Linear):
result = linear_q4_0.forward_new(x_2d, self.weight.data,
self.weight.qtype,
input_seq_size)
elif self.enable_xetla:
x_2d = x_2d.half()
result = linear_q4_0.mm_int4(x_2d, self.weight.data)
else:
# inference path
# current workaround to reduce first token latency of fp32 input

View file

@@ -358,6 +358,7 @@ class _BaseAutoModelClass:
embedding_qtype = kwargs.pop("embedding_qtype", None)
if embedding_qtype is not None:
embedding_qtype = ggml_tensor_qtype[embedding_qtype]
enable_xetla = kwargs.pop("enable_xetla", False)
_args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs)
awq_config = None
@@ -421,7 +422,8 @@ class _BaseAutoModelClass:
cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype)
embedding_qtype=embedding_qtype,
enable_xetla=enable_xetla,)
model.config.update({"bigdl_transformers_low_bit": q_k})
# enable tie_word_embeddings for MPT
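
The hunk above pops enable_xetla from the from_pretrained kwargs and forwards it to ggml_convert_low_bit, so the flag is exposed on the public loading API. A minimal usage sketch, assuming the usual bigdl-llm low-bit loading flow; the model id and device placement are illustrative only:

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # example model id
    load_in_4bit=True,                # sym_int4 weights, required for mm_int4
    enable_xetla=True,                # route int4 matmuls through XeTLA
)
# The PVC-only check in FP4Params.to() runs when the model is moved to XPU.
model = model.to("xpu")
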

View file

@@ -73,7 +73,7 @@ def baichuan_mlp_forward(
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training):
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
import linear_q4_0
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()

View file

@@ -111,7 +111,7 @@ def llama_mlp_forward(
x_2d = x.view(-1, x.shape[-1])
bsz, hidden_size = x_2d.shape
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training):
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
import linear_q4_0
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
@@ -216,6 +216,35 @@ def llama_decoder_forward(
return outputs
def fuse_qkv_weight(q_proj, k_proj, v_proj):
weight_size = q_proj.out_len * q_proj.in_len // 2
zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
zeros_end = weight_size + zeros_size
weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
], dim=-1).reshape(-1)
qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
], dim=-1).reshape(-1)
qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
], dim=-1).reshape(-1)
q_proj.weight.data = torch.empty(0)
k_proj.weight.data = torch.empty(0)
v_proj.weight.data = torch.empty(0)
return torch.cat([qweight, qzeros, qscales], dim=0)
def should_use_mm_int4_qkv(self, device):
return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla
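
For reference, fuse_qkv_weight slices each projection's flat uint8 blob according to the [qweight | zeros | scales] layout that q4_0_xpu_transpose writes, with a 64-element quantization group hard-coded into the byte shapes. A small sketch of the offset arithmetic on an example projection size; xetla_blob_offsets is an illustrative helper, not part of the patch:

# Offsets used when slicing one projection's flat uint8 blob
# (out_len = n, in_len = k, 64-element quantization groups as hard-coded above).
def xetla_blob_offsets(out_len, in_len, group=64):
    weight_size = out_len * in_len // 2            # qweight bytes, viewed as (k//2, n)
    zeros_size = out_len * in_len // 2 // group    # zero-point bytes, (k//group, n//2)
    scales_size = out_len * in_len // group * 2    # fp16 scales as bytes, (k//group, 2*n)
    return weight_size, weight_size + zeros_size, weight_size + zeros_size + scales_size

weight_size, zeros_end, total = xetla_blob_offsets(4096, 4096)  # e.g. a 4096x4096 projection
assert zeros_end - weight_size == 4096 * 4096 // 128
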
def llama_attention_forward_4_31(
self,
hidden_states: torch.Tensor,
@@ -459,6 +488,7 @@ def llama_attention_forward_4_31_original(
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
@@ -524,9 +554,20 @@ def llama_attention_forward_4_31_original(
query_states, key_states, value_states
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if should_use_mm_int4_qkv(self, device):
if not hasattr(self, "qkv_proj_qweight"):
self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
self.k_proj,
self.v_proj)
import linear_q4_0
qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
query_states = qkv_states[:, :, :hidden_size]
key_states = qkv_states[:, :, hidden_size:2*hidden_size]
value_states = qkv_states[:, :, 2*hidden_size:]
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -682,6 +723,7 @@ def llama_attention_selective_batching_forward_4_31(
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
bsz * q_len == 1)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
updated_past_key_values = []
# single batch decoding fast path
@@ -874,6 +916,7 @@ def llama_attention_forward_4_36(
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
@@ -944,9 +987,20 @@ def llama_attention_forward_4_36(
query_states, key_states, value_states
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if should_use_mm_int4_qkv(self, device):
if not hasattr(self, "qkv_proj_qweight"):
self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
self.k_proj,
self.v_proj)
import linear_q4_0
qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
query_states = qkv_states[:, :, :hidden_size]
key_states = qkv_states[:, :, hidden_size:2*hidden_size]
value_states = qkv_states[:, :, 2*hidden_size:]
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)

View file

@@ -140,6 +140,7 @@ def mistral_attention_forward(
use_fuse_rope,
enough_kv_room,
bsz * q_len)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
@@ -288,6 +289,7 @@ def mistral_attention_forward_4_36(
use_fuse_rope,
enough_kv_room,
bsz * q_len)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)

View file

@@ -153,6 +153,7 @@ def mixtral_attention_forward(
use_fuse_rope,
enough_kv_room,
bsz * q_len)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
@@ -330,7 +331,7 @@ def mixtral_mlp_forward(
routing_weights
) -> torch.Tensor:
qtype = getattr(self.w1, "qtype", None)
if mlp_fusion_check(x, qtype, self.training):
if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla:
import linear_q4_0
return self.w2(linear_q4_0.mlp_forward_xpu(
x, self.w1.weight.data, self.w3.weight.data,

View file

@@ -285,7 +285,7 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.w1, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training):
if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla:
import linear_q4_0
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()