remove unused code again (#12624)
Parent: 46eeab4479
Commit: c72a5db757

10 changed files with 19 additions and 171 deletions
@@ -92,8 +92,7 @@ def train(
         load_in_low_bit="bf16",
         optimize_model=True,
         torch_dtype=torch.bfloat16,
-        trust_remote_code=True,
-        enable_xetla=False
+        trust_remote_code=True
     )

     model = model.to("xpu")
@@ -156,7 +155,7 @@ def train(
         callbacks=trainer_callbacks
     )
     model.config.use_cache = False

     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

     # model.save_pretrained(output_dir)
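The net effect of the two hunks above is that the example's model-loading call no longer passes enable_xetla. A minimal sketch of the updated usage, assuming the example loads a Hugging Face checkpoint through IPEX-LLM's transformers wrapper (the model id below is illustrative, not taken from the diff):

import torch
from ipex_llm.transformers import AutoModelForCausalLM

# Hypothetical checkpoint; the keyword arguments mirror the ones shown in the hunk above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_low_bit="bf16",
    optimize_model=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,   # no enable_xetla argument any more
)
model = model.to("xpu")       # requires an Intel-GPU (xpu) build of IPEX-LLM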
@@ -257,8 +257,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         optimize_model=optimize_llm,
         modules_to_not_convert=modules_to_not_convert,
         cpu_embedding=cpu_embedding,
-        lightweight_bmm=lightweight_bmm,
-        enable_xetla=kwargs.pop("enable_xetla", False))
+        lightweight_bmm=lightweight_bmm)

     # add save_low_bit to pretrained model dynamically
     import types
     model._bigdl_config = dict()
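For context, the hunk above only changes an internal call; the public entry point keeps the signature shown in the hunk header, minus the popped enable_xetla kwarg. A hedged usage sketch (the model id and the device move are illustrative assumptions, not part of the diff):

import torch
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model

# Load any causal LM with plain transformers, then quantize it in place.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = optimize_model(model, low_bit="sym_int4")  # same default low_bit as in the signature above
model = model.to("xpu")  # optional; needs an Intel GPU and the xpu-enabled install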
@@ -232,7 +232,7 @@ def is_linear_module(module):


 def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
-        enable_xetla, optimize_lm_head, enable_scale_search):
+        optimize_lm_head, enable_scale_search):
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
         FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear

@@ -261,7 +261,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
         cur_qtype,
         module.bias is not None,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=optimize_lm_head,
         enable_scale_search=enable_scale_search,
     )

@@ -289,7 +288,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
         cur_qtype,
         module.bias is not None,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=optimize_lm_head,
         enable_scale_search=enable_scale_search,
     )
@@ -473,7 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         prefix_name='',
         imatrix_data=None, embedding_qtype=None,
         model_config=None, torch_dtype=torch.float32,
-        enable_xetla=False,
         mixed_precision=False,
         act_order=False,
         enable_scale_search=False,

@@ -523,7 +520,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         qtype=qtype,
         bias=has_bias,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=optimize_lm_head,
         act_order=act_order,
         enable_scale_search=enable_scale_search,

@@ -544,7 +540,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         _shape=(out_features, in_features),
         convert_shape_only=convert_shape_only,
         qtype=qtype,
-        enable_xetla=enable_xetla,
         enable_scale_search=enable_scale_search).to(device)
     new_linear._parameters['weight'] = paramsLowBit
     if has_bias:

@@ -562,7 +557,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         qtype=qtype,
         bias=has_bias,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=False,
         act_order=act_order,
         enable_scale_search=enable_scale_search,

@@ -581,7 +575,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         qtype=cur_qtype,
         imatrix=cur_imatrix,
         in_features=in_features,
-        enable_xetla=enable_xetla,
         enable_scale_search=enable_scale_search).to(device)
     else:
         new_linear = vLLMLowBitLinear(

@@ -590,7 +583,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         qtype=qtype,
         bias=has_bias,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=False,
         act_order=act_order,
         enable_scale_search=enable_scale_search,

@@ -609,7 +601,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         _shape=(out_features, in_features),
         convert_shape_only=convert_shape_only,
         qtype=qtype,
-        enable_xetla=enable_xetla,
         enable_scale_search=enable_scale_search).to(device)
     new_linear._parameters['weight'] = paramsLowBit
     if has_bias:
@@ -639,7 +630,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         out_features,
         mp_group,
         cur_qtype,
-        enable_xetla,
         optimize_lm_head,
         enable_scale_search)
     else:

@@ -649,7 +639,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         cur_qtype,
         module.bias is not None,
         mp_group=mp_group,
-        enable_xetla=enable_xetla,
         optimize_lm_head=optimize_lm_head,
         enable_scale_search=enable_scale_search,
     )

@@ -663,7 +652,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         qtype=cur_qtype,
         imatrix=cur_imatrix,
         in_features=in_features,
-        enable_xetla=enable_xetla,
         enable_scale_search=enable_scale_search).to(device)
     new_linear._parameters['weight'] = paramsLowBit
     if module.bias is not None:

@@ -762,7 +750,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
         torch_dtype=torch_dtype,
-        enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
         enable_scale_search=enable_scale_search,
@@ -1094,7 +1081,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         lightweight_bmm=False, torch_dtype="auto",
         imatrix_data=None,
         embedding_qtype=None,
-        enable_xetla=False,
         mixed_precision=False):
     if qtype in ggml_tensor_qtype.values():
         index = list(ggml_tensor_qtype.values()).index(qtype)

@@ -1138,7 +1124,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
         torch_dtype=torch_dtype,
-        enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
         enable_scale_search=enable_scale_search,
@@ -92,113 +92,6 @@ RTN_DTYPE = {
 }


-# For sym_int4
-# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
-#
-# The returning weight is row major and packs two rows at a stride of 16//2.
-# 16 is the tile_size_y used in mm_xetla, so that we can do something like
-# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
-#
-# A more complex packing strategy is to permute the weight so that the
-# new_weight_tile is directly VNNI packed, but I did not find significant
-# performance improvement.
-#
-# Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
-# row major but packing two consecutive columns.
-#
-# For fp8, just remove the scales (which are all ones) and transpose
-def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
-    if qtype == ggml_tensor_qtype["sym_int4"]:
-        from ipex_llm.transformers.low_bit_linear import get_block_size
-        Q4_0 = get_block_size("sym_int4")
-
-        n, k = weight_shape
-        ggml_weight_only = ggml_weight[:n*k//2]
-        ggml_scales = ggml_weight[n*k//2:]
-
-        qweight = ggml_weight_only.clone()
-        scales = ggml_scales.view(torch.float16).clone()
-
-        qweight_0 = qweight & 0x0F
-        qweight_1 = qweight >> 4
-
-        qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
-        qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
-        qweight = torch.cat([qweight_0, qweight_1], dim=-1)
-        qweight = qweight.reshape(n, k//16, 2, 8)
-        qweight = qweight.bitwise_left_shift(
-            torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
-
-        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
-        qweight = qweight.reshape(n, k//2)
-        qweight = qweight.transpose(0, 1).contiguous()
-
-        scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
-
-        # 119 is the value of 0x77
-        zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
-
-        qweight_bytes = qweight.view(torch.uint8).view(-1)
-        scales_bytes = scales.view(torch.uint8).view(-1)
-        zeros_bytes = zeros.view(torch.uint8).view(-1)
-
-        weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
-    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
-        n, k = weight_shape
-        weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
-    else:
-        invalidInputError(False, f"Unsupported qtype {qtype}")
-    return weight
-
-
-def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
-    from ipex_llm.transformers.low_bit_linear import get_block_size
-    if qtype == ggml_tensor_qtype["sym_int4"]:
-        Q4_0 = get_block_size("sym_int4")
-        n, k = weight_shape
-        weight_size = n*k//2
-        zeros_size = n*k//Q4_0//2
-        scales_size = n*k//Q4_0 * 2
-        xetla_weight_only = xetla_weight[:weight_size]
-        scales_start = weight_size + zeros_size
-        xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
-
-        qweight = xetla_weight_only.clone()
-        scales = xetla_scales.view(torch.float16).clone()
-
-        qweight_0 = qweight & 0x0F
-        qweight_1 = qweight >> 4
-        qweight_0 = qweight_0.reshape(-1, 8, n)
-        qweight_1 = qweight_1.reshape(-1, 8, n)
-        qweight = torch.cat([qweight_0, qweight_1], dim=1)
-
-        qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
-                                                                             2, Q4_0//2)
-        qweight = qweight.bitwise_left_shift(
-            torch.tensor([0, 4], dtype=torch.uint8,
-                         device=xetla_weight_only.device).reshape(1, 1, 2, 1))
-
-        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
-        qweight = qweight.reshape(n, k//2)
-
-        scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
-
-        qweight_bytes = qweight.view(torch.uint8).view(-1)
-        scales_bytes = scales.view(torch.uint8).view(-1)
-        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
-    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
-        Q8_0 = get_block_size("fp8_e5m2")
-        n, k = weight_shape
-        qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
-        scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
-        qweight_bytes = qweight.view(torch.uint8).view(-1)
-        scales_bytes = scales.view(torch.uint8).view(-1)
-        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
-    else:
-        invalidInputError(False, f"Unsupported qtype {qtype}")
-    return weight
-
-
 def get_block_size(qtype: str):
     return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
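The comments in the deleted block describe the conversion: split each packed byte into its low and high nibble (two packed rows), lay the rows out at a stride of Q4_0, then re-pack two nibbles per byte at the 16-element tile stride used by mm_xetla. A standalone sketch of that regrouping on dummy data, with Q4_0 = 32 assumed as ggml's sym_int4 block size and scale/zero-point handling omitted:

import torch

Q4_0 = 32
n, k = 4, 64                                                       # toy weight shape
packed = torch.randint(0, 256, (n * k // 2,), dtype=torch.uint8)   # two 4-bit values per byte

low = (packed & 0x0F).reshape(n, -1, Q4_0 // 2)    # first row of each packed pair
high = (packed >> 4).reshape(n, -1, Q4_0 // 2)     # second row of each packed pair
regrouped = torch.cat([low, high], dim=-1)         # rows now sit at a stride of Q4_0

# Re-pack two nibbles per byte at the 16-element tile stride, so a tile can later be
# unpacked as concat(tile & 0x0F, tile >> 4).
tiles = regrouped.reshape(n, k // 16, 2, 8)
tiles = tiles.bitwise_left_shift(torch.tensor([0, 4], dtype=torch.uint8).reshape(1, 1, 2, 1))
repacked = (tiles[:, :, 0, :] | tiles[:, :, 1, :]).reshape(n, k // 2)
xetla_weight = repacked.transpose(0, 1).contiguous()   # transposed at the end, mirroring the deleted helper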
@@ -422,7 +315,6 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_xetla=False,
                 enable_scale_search=False):
         if data is None:
             data = torch.empty(0)

@@ -435,7 +327,6 @@ class FP4Params(torch.nn.Parameter):
         self.convert_shape_only = convert_shape_only
         self.imatrix = imatrix
         self.in_features = in_features
-        self.enable_xetla = enable_xetla
         self.enable_scale_search = enable_scale_search
         return self

@@ -529,8 +420,6 @@ class FP4Params(torch.nn.Parameter):
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                 reduce(mul, self._shape, 1),
                 self.qtype)
-            if self.enable_xetla:
-                self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
             new_param = FP4Params(super().to(device=device,
                 dtype=dtype,
                 non_blocking=non_blocking),
@@ -538,12 +427,7 @@ class FP4Params(torch.nn.Parameter):
                 quantized=self.quantized,
                 _shape=self._shape,
                 qtype=self.qtype,
-                enable_xetla=self.enable_xetla,
                 enable_scale_search=self.enable_scale_search)
-            if self.enable_xetla:
-                device_type = get_xpu_device_type(new_param.data)
-                invalidInputError(device_type == "pvc",
-                    f"xetla is only supported on PVC, but got {device_type}")
             return new_param
         elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
             new_param = FP4Params(super().to(device=device,

@@ -553,14 +437,8 @@ class FP4Params(torch.nn.Parameter):
                 quantized=self.quantized,
                 _shape=self._shape,
                 qtype=self.qtype,
-                enable_xetla=self.enable_xetla,
                 enable_scale_search=self.enable_scale_search)
-            if self.enable_xetla:
-                ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
-                    new_param._shape,
-                    new_param.qtype)
-            else:
-                ggml_xpu = new_param.data
+            ggml_xpu = new_param.data
             new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                 reduce(mul, new_param._shape, 1),
                 new_param.qtype)

@@ -573,7 +451,6 @@ class FP4Params(torch.nn.Parameter):
                 quantized=self.quantized,
                 _shape=self._shape,
                 qtype=self.qtype,
-                enable_xetla=self.enable_xetla,
                 enable_scale_search=self.enable_scale_search)
         return new_param
@@ -691,14 +568,13 @@ class MatMulLowBitCPU(torch.autograd.Function):

 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-        conver_to_half=True, mp_group=None, enable_xetla=False,
+        conver_to_half=True, mp_group=None,
         optimize_lm_head=False, act_order=False,
         enable_scale_search=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
             requires_grad=False,
             quantized=False, _shape=None, qtype=qtype,
-            enable_xetla=enable_xetla,
             enable_scale_search=enable_scale_search)
         self.in_len = input_features
         self.out_len = output_features

@@ -708,7 +584,6 @@ class LowBitLinear(nn.Linear):
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
         self.compute_dtype = None  # only for training
-        self.enable_xetla = enable_xetla
         self.optimize_lm_head = optimize_lm_head
         self.device = None  # detected only once in the first forward
         # empty cache before and after lm_head at first token (by default on arc)

@@ -799,9 +674,6 @@ class LowBitLinear(nn.Linear):
                 self.weight.data,
                 self.weight.qtype,
                 input_seq_size)
-            elif self.enable_xetla:
-                x_2d = x_2d.half()
-                result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
             else:
                 # inference path
                 # current workaround to reduce first token latency of fp32 input
@@ -880,8 +752,7 @@

 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-        mp_group=None, weight_type=1, enable_xetla=False,
-        optimize_lm_head=False):
+        mp_group=None, weight_type=1, optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features

@@ -894,7 +765,6 @@ class FP16Linear(nn.Linear):
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
         self.optimize_lm_head = optimize_lm_head
-        self.enable_xetla = enable_xetla

     def forward(self, x: torch.Tensor):
         # only work for GPU
@@ -1010,8 +880,7 @@

 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-        mp_group=None, compute_dtype=None, enable_xetla=False,
-        optimize_lm_head=False):
+        mp_group=None, compute_dtype=None, optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features

@@ -1021,7 +890,6 @@ class BF16Linear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
         self.optimize_lm_head = optimize_lm_head
-        self.enable_xetla = enable_xetla

     def forward(self, x: torch.Tensor):
         if self.optimize_lm_head:
@@ -1050,11 +918,11 @@

 class vLLMLowBitLinear(LowBitLinear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-        conver_to_half=True, mp_group=None, enable_xetla=False,
+        conver_to_half=True, mp_group=None,
         optimize_lm_head=False, act_order=False,
         enable_scale_search=False):
         super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
-            enable_xetla, optimize_lm_head, act_order, enable_scale_search)
+            optimize_lm_head, act_order, enable_scale_search)

     def forward(self, x: torch.Tensor):
         result = super().forward(x)

@@ -1063,9 +931,9 @@ class vLLMLowBitLinear(LowBitLinear):

 class vLLMFP16Linear(FP16Linear):
     def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
-        enable_xetla=False, optimize_lm_head=False):
+        optimize_lm_head=False):
         super().__init__(input_features, output_features, bias, mp_group, weight_type,
-            enable_xetla, optimize_lm_head)
+            optimize_lm_head)

     def forward(self, x: torch.Tensor):
         result = super().forward(x)

@@ -1074,9 +942,9 @@ class vLLMFP16Linear(FP16Linear):

 class vLLMBF16Linear(BF16Linear):
     def __init__(self, input_features, output_features, bias=True, mp_group=None,
-        compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
+        compute_dtype=None, optimize_lm_head=False):
         super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
-            enable_xetla, optimize_lm_head)
+            optimize_lm_head)

     def forward(self, x: torch.Tensor):
         result = super().forward(x)
@@ -448,7 +448,6 @@ class _BaseAutoModelClass:
         mixed_precision = kwargs.pop("mixed_precision", False)
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
-        enable_xetla = kwargs.pop("enable_xetla", False)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         awq_config = None

@@ -518,7 +517,6 @@ class _BaseAutoModelClass:
             torch_dtype=kwargs.get("torch_dtype", 'auto'),
             imatrix_data=imatrix_data,
             embedding_qtype=embedding_qtype,
-            enable_xetla=enable_xetla,
             mixed_precision=mixed_precision)

         if disk_embedding:
@@ -67,7 +67,7 @@ def baichuan_mlp_forward(
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import xe_linear
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
@@ -380,7 +380,7 @@ def mixtral_mlp_forward(
     routing_weights
 ) -> torch.Tensor:
     qtype = getattr(self.w1, "qtype", None)
-    if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla:
+    if mlp_fusion_check(x, qtype, self.training):
         import xe_linear
         return self.w2(xe_linear.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
@@ -259,7 +259,7 @@ def qwen_attention_forward_registered(
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.w1, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla:
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import xe_linear
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
@@ -612,7 +612,7 @@ def qwen2_mlp_forward(
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import xe_linear
         return self.down_proj(xe_linear.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
@@ -337,8 +337,7 @@ def use_decoding_fast_path(proj,
         return False
     if bs != 1:
         return False
-    if proj.enable_xetla:
-        return False
     if device in ["uhd"]:
         return False
     return True