draft mmint4 (#10031)
Squashed commit messages:
* change to llm.cpp support transposed format
* revert
* implement qkv fuse
* fix style
* change to vertically pack
* change to enable_xetla
* fix mlp_fusion_check
* remove comments
* address comments
* add some comments
* fix style
This commit is contained in: parent d85f7c78df · commit c581c6db30
8 changed files with 160 additions and 23 deletions
@@ -192,7 +192,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 current_key_name=None, convert_shape_only=False,
 cpu_embedding=False, prefix_name='',
 imatrix_data=None, embedding_qtype=None,
-model_type=None, torch_dtype=torch.float32):
+model_type=None, torch_dtype=torch.float32,
+enable_xetla=False):
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
     FP16Linear, BF16Linear
 from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding

@@ -223,6 +224,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 qtype=qtype,
 bias=has_bias,
 mp_group=mp_group,
+enable_xetla=enable_xetla,
 )
 device = module.qweight.data.device
 invalidInputError(device.type != "meta",

@@ -234,7 +236,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 quantized=True,
 _shape=(out_features, in_features),
 convert_shape_only=convert_shape_only,
-qtype=qtype).to(device)
+qtype=qtype,
+enable_xetla=enable_xetla).to(device)
 new_linear._parameters['weight'] = paramsLowBit
 if has_bias:
 new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\

@@ -249,6 +252,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 qtype,
 module.bias is not None,
 mp_group=mp_group,
+enable_xetla=enable_xetla,
 )
 cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
 full_module_name,

@@ -263,7 +267,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 convert_shape_only=convert_shape_only,
 qtype=cur_qtype,
 imatrix=cur_imatrix,
-in_features=in_features).to(device)
+in_features=in_features,
+enable_xetla=enable_xetla).to(device)
 new_linear._parameters['weight'] = paramsLowBit
 if module.bias is not None:
 new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\

@@ -368,7 +373,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 imatrix_data=imatrix_data,
 embedding_qtype=embedding_qtype,
 model_type=model_type,
-torch_dtype=torch_dtype
+torch_dtype=torch_dtype,
+enable_xetla=enable_xetla,
 )
 has_been_replaced = _flag or has_been_replaced
 return model, has_been_replaced

@@ -571,7 +577,9 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 convert_shape_only=False, device="cpu",
 modules_to_not_convert=None, cpu_embedding=False,
 lightweight_bmm=False, torch_dtype="auto",
-imatrix_data=None, embedding_qtype=None):
+imatrix_data=None,
+embedding_qtype=None,
+enable_xetla=False):
 logger.info(f"Converting the current model to "
 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
 f"format......")

@@ -601,7 +609,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 imatrix_data=imatrix_data,
 embedding_qtype=embedding_qtype,
 model_type=model_type,
-torch_dtype=torch_dtype
+torch_dtype=torch_dtype,
+enable_xetla=enable_xetla,
 )
 if not has_been_replaced:
 warnings.warn(
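For reference, a minimal sketch of how the new `enable_xetla` flag could be passed at this level. The call shape follows the updated `ggml_convert_low_bit` signature above, while the import path, the qtype value and the `model` object are illustrative assumptions, not something this commit shows:

```python
# Hedged usage sketch, not taken from this commit: the import location and the
# qtype value are assumptions; `model` is any module supported by the converter.
from bigdl.llm.ggml.quantize import ggml_tensor_qtype  # assumed location

model = ggml_convert_low_bit(model,
                             qtype=ggml_tensor_qtype["sym_int4"],
                             optimize_model=True,
                             device="cpu",   # weights are repacked for XeTLA later, on .to("xpu")
                             enable_xetla=True)
```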
@@ -75,6 +75,56 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
 Q2_K = ggml_tensor_qtype["q2_k"]
 
 
+# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
+#
+# The returned weight is row major and packs two rows at a stride of 16//2.
+# 16 is the tile_size_y used in mm_int4, so that we can do something like
+# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
+#
+# A more complex packing strategy is to permute the weight so that the
+# new_weight_tile is directly VNNI packed, but I did not find significant
+# performance improvement.
+#
+# Note this format cannot be used directly in IPEX's mm_int4, which expects
+# row major but packs two consecutive columns.
+def q4_0_xpu_transpose(ggml_weight, weight_shape):
+    from bigdl.llm.transformers.low_bit_linear import get_block_size
+    Q4_0 = get_block_size("sym_int4")
+
+    n, k = weight_shape
+    ggml_weight_only = ggml_weight[:n*k//2]
+    ggml_scales = ggml_weight[n*k//2:]
+
+    qweight = ggml_weight_only.clone()
+    scales = ggml_scales.view(torch.float16).clone()
+
+    qweight_0 = qweight & 0x0F
+    qweight_1 = qweight >> 4
+
+    qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
+    qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
+    qweight = torch.cat([qweight_0, qweight_1], dim=-1)
+    qweight = qweight.reshape(n, k//16, 2, 8)
+    qweight = qweight.bitwise_left_shift(
+        torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
+
+    qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
+    qweight = qweight.reshape(n, k//2)
+    qweight = qweight.transpose(0, 1).contiguous()
+
+    scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
+
+    # 119 is the value of 0x77
+    zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
+
+    qweight_bytes = qweight.view(torch.uint8).view(-1)
+    scales_bytes = scales.view(torch.uint8).view(-1)
+    zeros_bytes = zeros.view(torch.uint8).view(-1)
+
+    weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
+    return weight
+
+
 def get_block_size(qtype: str):
     return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
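To make the packing comment above concrete, here is a small self-contained sketch (toy sizes, not the library API) that performs the same nibble regrouping on a dummy buffer and checks that no int4 value is lost. `Q4_0 = 64` is an assumption chosen to mirror the //64 group size used by `fuse_qkv_weight` later in this commit; the real value comes from `get_block_size("sym_int4")`, and the real function additionally transposes the result and appends the zeros and fp16 scales:

```python
# Toy illustration of the nibble regrouping in q4_0_xpu_transpose; the sizes
# are assumptions picked so that one row is exactly one quantization block.
import torch

n, k, Q4_0 = 4, 64, 64
qweight = torch.arange(n * k // 2, dtype=torch.uint8).reshape(n, k // 2)

low = qweight & 0x0F    # first int4 value stored in each byte
high = qweight >> 4     # second int4 value stored in each byte

# Regroup the two nibble planes of each block and re-pack two values per byte,
# mirroring the reshape / bitwise_left_shift / bitwise_or sequence above.
low = low.reshape(n, -1, Q4_0 // 2)
high = high.reshape(n, -1, Q4_0 // 2)
repacked = torch.cat([low, high], dim=-1).reshape(n, k // 16, 2, 8)
shift = torch.tensor([0, 4], dtype=torch.uint8).reshape(1, 1, 2, 1)
repacked = repacked.bitwise_left_shift(shift)
repacked = torch.bitwise_or(repacked[:, :, 0, :], repacked[:, :, 1, :]).reshape(n, k // 2)

# The repack only moves values around: the multiset of int4 values is unchanged.
before = torch.cat([low.reshape(-1), high.reshape(-1)]).sort().values
after = torch.cat([(repacked & 0x0F).reshape(-1), (repacked >> 4).reshape(-1)]).sort().values
assert torch.equal(before, after)
```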
@@ -208,7 +258,8 @@ class FP4Params(torch.nn.Parameter):
 convert_shape_only=False,
 qtype=None,
 imatrix=None,
-in_features=None):
+in_features=None,
+enable_xetla=False,):
 if data is None:
 data = torch.empty(0)

@@ -220,6 +271,7 @@ class FP4Params(torch.nn.Parameter):
 self.convert_shape_only = convert_shape_only
 self.imatrix = imatrix
 self.in_features = in_features
+self.enable_xetla = enable_xetla
 return self

 def ggml_mse(self, w, ggml_qtype, device):

@@ -308,13 +360,20 @@ class FP4Params(torch.nn.Parameter):
 self.data = ggml_q_format_convet_cpu2xpu(self.data,
 reduce(mul, self._shape, 1),
 self.qtype)
+if self.enable_xetla:
+    self.data = q4_0_xpu_transpose(self.data, self._shape)
 new_param = FP4Params(super().to(device=device,
 dtype=dtype,
 non_blocking=non_blocking),
 requires_grad=self.requires_grad,
 quantized=self.quantized,
 _shape=self._shape,
-qtype=self.qtype)
+qtype=self.qtype,
+enable_xetla=self.enable_xetla)
+if self.enable_xetla:
+    device_type = get_xpu_device_type(new_param.data)
+    invalidInputError(device_type == "pvc",
+                      f"xetla is only supported on PVC, but got {device_type}")
 return new_param
 elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
 new_param = FP4Params(super().to(device=device,

@@ -323,7 +382,11 @@ class FP4Params(torch.nn.Parameter):
 requires_grad=self.requires_grad,
 quantized=self.quantized,
 _shape=self._shape,
-qtype=self.qtype)
+qtype=self.qtype,
+enable_xetla=self.enable_xetla)
+if self.enable_xetla:
+    invalidInputError(False,
+                      "xetla is not supported on CPUs but got enable_xetla=True")
 new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
 reduce(mul, new_param._shape, 1),
 new_param.qtype)

@@ -335,7 +398,8 @@ class FP4Params(torch.nn.Parameter):
 requires_grad=self.requires_grad,
 quantized=self.quantized,
 _shape=self._shape,
-qtype=self.qtype)
+qtype=self.qtype,
+enable_xetla=self.enable_xetla)
 return new_param

@@ -441,11 +505,12 @@ class MatMulLowBitCPU(torch.autograd.Function):
 
 class LowBitLinear(nn.Linear):
 def __init__(self, input_features, output_features, qtype, bias=True,
-conver_to_half=True, mp_group=None):
+conver_to_half=True, mp_group=None, enable_xetla=False):
 super().__init__(input_features, output_features, bias)
 self.weight = FP4Params(self.weight.data,
 requires_grad=False,
-quantized=False, _shape=None, qtype=qtype)
+quantized=False, _shape=None, qtype=qtype,
+enable_xetla=enable_xetla)
 self.in_len = input_features
 self.out_len = output_features
 self.weight_shape = (self.out_len, self.in_len)

@@ -454,6 +519,7 @@ class LowBitLinear(nn.Linear):
 self.conver_to_half = conver_to_half
 self.mp_group = mp_group
 self.compute_dtype = None # only for training
+self.enable_xetla = enable_xetla

 def forward(self, x: torch.Tensor):
 # Due to inconsistent training status in some models like Baichuan-7b-Chat,

@@ -510,6 +576,9 @@ class LowBitLinear(nn.Linear):
 result = linear_q4_0.forward_new(x_2d, self.weight.data,
 self.weight.qtype,
 input_seq_size)
+elif self.enable_xetla:
+    x_2d = x_2d.half()
+    result = linear_q4_0.mm_int4(x_2d, self.weight.data)
 else:
 # inference path
 # current workaround to reduce first token latency of fp32 input
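Putting the low_bit_linear changes together, a hedged end-to-end sketch: the qtype value, the "xpu" device string and the availability of the `linear_q4_0` extension are assumptions, and only PVC devices pass the check added above:

```python
# Sketch only: a LowBitLinear built with enable_xetla=True is repacked by
# q4_0_xpu_transpose inside FP4Params.to() when moved to the GPU, and its
# forward pass then takes the new `elif self.enable_xetla` branch (mm_int4).
layer = LowBitLinear(4096, 4096,
                     qtype=ggml_tensor_qtype["sym_int4"],
                     bias=False,
                     enable_xetla=True)
layer = layer.to("xpu")                  # repacked for XeTLA on the move to xpu (PVC only)

x = torch.randn(1, 4096, device="xpu")
y = layer(x)                             # fp16 activations -> linear_q4_0.mm_int4
```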
@@ -358,6 +358,7 @@ class _BaseAutoModelClass:
 embedding_qtype = kwargs.pop("embedding_qtype", None)
 if embedding_qtype is not None:
 embedding_qtype = ggml_tensor_qtype[embedding_qtype]
+enable_xetla = kwargs.pop("enable_xetla", False)
 _args = copy.deepcopy(args)
 _kwargs = copy.deepcopy(kwargs)
 awq_config = None

@@ -421,7 +422,8 @@ class _BaseAutoModelClass:
 cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
 torch_dtype=kwargs.get("torch_dtype", 'auto'),
 imatrix_data=imatrix_data,
-embedding_qtype=embedding_qtype)
+embedding_qtype=embedding_qtype,
+enable_xetla=enable_xetla,)
 model.config.update({"bigdl_transformers_low_bit": q_k})

 # enable tie_word_embeddings for MPT
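Since `_BaseAutoModelClass` now pops an `enable_xetla` kwarg and forwards it into the low-bit conversion, opting in at load time would presumably look like this (a hedged sketch; the model id and device placement are placeholders, not part of this diff):

```python
# Hypothetical usage based on the kwargs handling above; the model id is a placeholder.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True,
                                             enable_xetla=True)
model = model.to("xpu")   # low-bit weights are repacked for XeTLA on the move to the GPU
```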
@@ -73,7 +73,7 @@ def baichuan_mlp_forward(
 ) -> torch.Tensor:
 x_2d = x.view(-1, x.shape[-1])
 qtype = getattr(self.gate_proj, "qtype", None)
-if mlp_fusion_check(x_2d, qtype, self.training):
+if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
 import linear_q4_0
 if not x_2d.is_contiguous():
 x_2d = x_2d.contiguous()

@@ -111,7 +111,7 @@ def llama_mlp_forward(
 x_2d = x.view(-1, x.shape[-1])
 bsz, hidden_size = x_2d.shape
 qtype = getattr(self.gate_proj, "qtype", None)
-if mlp_fusion_check(x_2d, qtype, self.training):
+if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
 import linear_q4_0
 if not x_2d.is_contiguous():
 x_2d = x_2d.contiguous()

@@ -216,6 +216,35 @@ def llama_decoder_forward(
 return outputs
 
 
+def fuse_qkv_weight(q_proj, k_proj, v_proj):
+    weight_size = q_proj.out_len * q_proj.in_len // 2
+    zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
+    zeros_end = weight_size + zeros_size
+    weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
+    zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
+    scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
+    qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                            k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                            v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                            ], dim=-1).reshape(-1)
+    qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                           k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                           v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                           ], dim=-1).reshape(-1)
+    qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                            k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                            v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                            ], dim=-1).reshape(-1)
+    q_proj.weight.data = torch.empty(0)
+    k_proj.weight.data = torch.empty(0)
+    v_proj.weight.data = torch.empty(0)
+    return torch.cat([qweight, qzeros, qscales], dim=0)
+
+
+def should_use_mm_int4_qkv(self, device):
+    return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla
+
+
 def llama_attention_forward_4_31(
 self,
 hidden_states: torch.Tensor,
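The slicing arithmetic in `fuse_qkv_weight` assumes the XeTLA int4 blob layout produced by `q4_0_xpu_transpose`: qweight bytes first, then the packed zeros, then the fp16 scales, with a quantization group size of 64. A quick worked example with 4096x4096 projections (an illustrative assumption) shows where each region starts and ends:

```python
# Size check for the buffer layout assumed by fuse_qkv_weight above;
# in_len = out_len = 4096 is an assumption (llama-7B-style q/k/v projections).
in_len = out_len = 4096

weight_size = out_len * in_len // 2            # int4: two weights per byte      -> 8_388_608 bytes
zeros_size = in_len * out_len // 2 // 64       # half a byte per 64-weight group ->   131_072 bytes
zeros_end = weight_size + zeros_size           # scales start at byte 8_519_680
scales_size = (in_len // 64) * (out_len * 2)   # one fp16 scale per group        ->   524_288 bytes

print(weight_size, zeros_size, scales_size)    # 8388608 131072 524288
print(weight_size + zeros_size + scales_size)  # 9043968 bytes per projection blob
```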
@@ -459,6 +488,7 @@ def llama_attention_forward_4_31_original(
 no_tp = not self.config.pretraining_tp > 1
 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
 enough_kv_room and bsz * q_len == 1)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 # single batch decoding fast path
 # forward_qkv takes will perform QKV projection, rotary position embedding

@@ -524,9 +554,20 @@ def llama_attention_forward_4_31_original(
 query_states, key_states, value_states
 )
 else:
-query_states = self.q_proj(hidden_states)
-key_states = self.k_proj(hidden_states)
-value_states = self.v_proj(hidden_states)
+if should_use_mm_int4_qkv(self, device):
+    if not hasattr(self, "qkv_proj_qweight"):
+        self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
+                                                self.k_proj,
+                                                self.v_proj)
+    import linear_q4_0
+    qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
+    query_states = qkv_states[:, :, :hidden_size]
+    key_states = qkv_states[:, :, hidden_size:2*hidden_size]
+    value_states = qkv_states[:, :, 2*hidden_size:]
+else:
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

 query_states = query_states.view(bsz, q_len,
 self.num_heads, self.head_dim).transpose(1, 2)
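In short, when `should_use_mm_int4_qkv` is true the three int4 projection blobs are fused once into `self.qkv_proj_qweight`, a single `mm_int4` call produces a tensor whose last dimension is `3 * hidden_size`, and plain slicing recovers Q, K and V. A shape-only sketch (toy dimensions; random data standing in for the real kernel output):

```python
# Shape-only illustration of the fused-QKV branch above; no real int4 kernel is
# called here, torch.randn simply stands in for linear_q4_0.mm_int4's output.
import torch

bsz, q_len, hidden_size = 1, 1, 4096
qkv_states = torch.randn(bsz, q_len, 3 * hidden_size, dtype=torch.float16)

query_states = qkv_states[:, :, :hidden_size]
key_states = qkv_states[:, :, hidden_size:2 * hidden_size]
value_states = qkv_states[:, :, 2 * hidden_size:]
assert query_states.shape == key_states.shape == value_states.shape == (bsz, q_len, hidden_size)
```

Fusing the projections means one kernel call per decode step instead of three, which appears to be the motivation behind the "implement qkv fuse" commit in this PR.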
@@ -682,6 +723,7 @@ def llama_attention_selective_batching_forward_4_31(
 no_tp = not self.config.pretraining_tp > 1
 decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 bsz * q_len == 1)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 updated_past_key_values = []
 # single batch decoding fast path

@@ -874,6 +916,7 @@ def llama_attention_forward_4_36(
 no_tp = not self.config.pretraining_tp > 1
 decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 enough_kv_room and bsz * q_len == 1)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 # single batch decoding fast path
 # forward_qkv takes will perform QKV projection, rotary position embedding

@@ -944,9 +987,20 @@ def llama_attention_forward_4_36(
 query_states, key_states, value_states
 )
 else:
-query_states = self.q_proj(hidden_states)
-key_states = self.k_proj(hidden_states)
-value_states = self.v_proj(hidden_states)
+if should_use_mm_int4_qkv(self, device):
+    if not hasattr(self, "qkv_proj_qweight"):
+        self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
+                                                self.k_proj,
+                                                self.v_proj)
+    import linear_q4_0
+    qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
+    query_states = qkv_states[:, :, :hidden_size]
+    key_states = qkv_states[:, :, hidden_size:2*hidden_size]
+    value_states = qkv_states[:, :, 2*hidden_size:]
+else:
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

 query_states = query_states.view(bsz, q_len,
 self.num_heads, self.head_dim).transpose(1, 2)

@@ -140,6 +140,7 @@ def mistral_attention_forward(
 use_fuse_rope,
 enough_kv_room,
 bsz * q_len)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 if decoding_fast_path:
 hidden_states = hidden_states.view(1, -1)

@@ -288,6 +289,7 @@ def mistral_attention_forward_4_36(
 use_fuse_rope,
 enough_kv_room,
 bsz * q_len)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 if decoding_fast_path:
 hidden_states = hidden_states.view(1, -1)

@@ -153,6 +153,7 @@ def mixtral_attention_forward(
 use_fuse_rope,
 enough_kv_room,
 bsz * q_len)
+decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

 if decoding_fast_path:
 hidden_states = hidden_states.view(1, -1)

@@ -330,7 +331,7 @@ def mixtral_mlp_forward(
 routing_weights
 ) -> torch.Tensor:
 qtype = getattr(self.w1, "qtype", None)
-if mlp_fusion_check(x, qtype, self.training):
+if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla:
 import linear_q4_0
 return self.w2(linear_q4_0.mlp_forward_xpu(
 x, self.w1.weight.data, self.w3.weight.data,

@@ -285,7 +285,7 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
 x_2d = x.view(-1, x.shape[-1])
 qtype = getattr(self.w1, "qtype", None)
-if mlp_fusion_check(x_2d, qtype, self.training):
+if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla:
 import linear_q4_0
 if not x_2d.is_contiguous():
 x_2d = x_2d.contiguous()