remove unused code again (#12624)
This commit is contained in:
parent
46eeab4479
commit
c72a5db757
10 changed files with 19 additions and 171 deletions
|
|
@ -92,8 +92,7 @@ def train(
|
|||
load_in_low_bit="bf16",
|
||||
optimize_model=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True,
|
||||
enable_xetla=False
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
model = model.to("xpu")
|
||||
|
|
@ -156,7 +155,7 @@ def train(
|
|||
callbacks=trainer_callbacks
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
# model.save_pretrained(output_dir)
|
||||
|
|
|
|||
|
|
@ -257,8 +257,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
|
|||
optimize_model=optimize_llm,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
cpu_embedding=cpu_embedding,
|
||||
lightweight_bmm=lightweight_bmm,
|
||||
enable_xetla=kwargs.pop("enable_xetla", False))
|
||||
lightweight_bmm=lightweight_bmm)
|
||||
# add save_low_bit to pretrained model dynamically
|
||||
import types
|
||||
model._bigdl_config = dict()
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ def is_linear_module(module):
|
|||
|
||||
|
||||
def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
|
||||
enable_xetla, optimize_lm_head, enable_scale_search):
|
||||
optimize_lm_head, enable_scale_search):
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
|
||||
FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
|
||||
|
|
@ -261,7 +261,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
|
|||
cur_qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
enable_scale_search=enable_scale_search,
|
||||
)
|
||||
|
|
@ -289,7 +288,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
|
|||
cur_qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
enable_scale_search=enable_scale_search,
|
||||
)
|
||||
|
|
@ -473,7 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
prefix_name='',
|
||||
imatrix_data=None, embedding_qtype=None,
|
||||
model_config=None, torch_dtype=torch.float32,
|
||||
enable_xetla=False,
|
||||
mixed_precision=False,
|
||||
act_order=False,
|
||||
enable_scale_search=False,
|
||||
|
|
@ -523,7 +520,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
qtype=qtype,
|
||||
bias=has_bias,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
act_order=act_order,
|
||||
enable_scale_search=enable_scale_search,
|
||||
|
|
@ -544,7 +540,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
_shape=(out_features, in_features),
|
||||
convert_shape_only=convert_shape_only,
|
||||
qtype=qtype,
|
||||
enable_xetla=enable_xetla,
|
||||
enable_scale_search=enable_scale_search).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if has_bias:
|
||||
|
|
@ -562,7 +557,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
qtype=qtype,
|
||||
bias=has_bias,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=False,
|
||||
act_order=act_order,
|
||||
enable_scale_search=enable_scale_search,
|
||||
|
|
@ -581,7 +575,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
qtype=cur_qtype,
|
||||
imatrix=cur_imatrix,
|
||||
in_features=in_features,
|
||||
enable_xetla=enable_xetla,
|
||||
enable_scale_search=enable_scale_search).to(device)
|
||||
else:
|
||||
new_linear = vLLMLowBitLinear(
|
||||
|
|
@ -590,7 +583,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
qtype=qtype,
|
||||
bias=has_bias,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=False,
|
||||
act_order=act_order,
|
||||
enable_scale_search=enable_scale_search,
|
||||
|
|
@ -609,7 +601,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
_shape=(out_features, in_features),
|
||||
convert_shape_only=convert_shape_only,
|
||||
qtype=qtype,
|
||||
enable_xetla=enable_xetla,
|
||||
enable_scale_search=enable_scale_search).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if has_bias:
|
||||
|
|
@ -639,7 +630,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
out_features,
|
||||
mp_group,
|
||||
cur_qtype,
|
||||
enable_xetla,
|
||||
optimize_lm_head,
|
||||
enable_scale_search)
|
||||
else:
|
||||
|
|
@ -649,7 +639,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
cur_qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
enable_scale_search=enable_scale_search,
|
||||
)
|
||||
|
|
@ -663,7 +652,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
qtype=cur_qtype,
|
||||
imatrix=cur_imatrix,
|
||||
in_features=in_features,
|
||||
enable_xetla=enable_xetla,
|
||||
enable_scale_search=enable_scale_search).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if module.bias is not None:
|
||||
|
|
@ -762,7 +750,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
embedding_qtype=embedding_qtype,
|
||||
model_config=model_config,
|
||||
torch_dtype=torch_dtype,
|
||||
enable_xetla=enable_xetla,
|
||||
mixed_precision=mixed_precision,
|
||||
act_order=act_order,
|
||||
enable_scale_search=enable_scale_search,
|
||||
|
|
@ -1094,7 +1081,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
lightweight_bmm=False, torch_dtype="auto",
|
||||
imatrix_data=None,
|
||||
embedding_qtype=None,
|
||||
enable_xetla=False,
|
||||
mixed_precision=False):
|
||||
if qtype in ggml_tensor_qtype.values():
|
||||
index = list(ggml_tensor_qtype.values()).index(qtype)
|
||||
|
|
@ -1138,7 +1124,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
embedding_qtype=embedding_qtype,
|
||||
model_config=model_config,
|
||||
torch_dtype=torch_dtype,
|
||||
enable_xetla=enable_xetla,
|
||||
mixed_precision=mixed_precision,
|
||||
act_order=act_order,
|
||||
enable_scale_search=enable_scale_search,
|
||||
|
|
|
|||
|
|
@ -92,113 +92,6 @@ RTN_DTYPE = {
|
|||
}
|
||||
|
||||
|
||||
# For sym_int4
|
||||
# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
|
||||
#
|
||||
# The returning weight is row major and packs two rows at a stride of 16//2.
|
||||
# 16 is the tile_size_y used in mm_xetla, so that we can do something like
|
||||
# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
|
||||
#
|
||||
# A more complex packing strategy is to permute the weight so that the
|
||||
# new_weight_tile is directly VNNI packed, but I did not find significant
|
||||
# performance improvement.
|
||||
#
|
||||
# Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
|
||||
# row major but packing two consecutive columns.
|
||||
#
|
||||
# For fp8, just remove the scales (which are all ones) and transpose
|
||||
def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
|
||||
if qtype == ggml_tensor_qtype["sym_int4"]:
|
||||
from ipex_llm.transformers.low_bit_linear import get_block_size
|
||||
Q4_0 = get_block_size("sym_int4")
|
||||
|
||||
n, k = weight_shape
|
||||
ggml_weight_only = ggml_weight[:n*k//2]
|
||||
ggml_scales = ggml_weight[n*k//2:]
|
||||
|
||||
qweight = ggml_weight_only.clone()
|
||||
scales = ggml_scales.view(torch.float16).clone()
|
||||
|
||||
qweight_0 = qweight & 0x0F
|
||||
qweight_1 = qweight >> 4
|
||||
|
||||
qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
|
||||
qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
|
||||
qweight = torch.cat([qweight_0, qweight_1], dim=-1)
|
||||
qweight = qweight.reshape(n, k//16, 2, 8)
|
||||
qweight = qweight.bitwise_left_shift(
|
||||
torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
|
||||
|
||||
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
|
||||
qweight = qweight.reshape(n, k//2)
|
||||
qweight = qweight.transpose(0, 1).contiguous()
|
||||
|
||||
scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
|
||||
|
||||
# 119 is the value of 0x77
|
||||
zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
|
||||
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
zeros_bytes = zeros.view(torch.uint8).view(-1)
|
||||
|
||||
weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
|
||||
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
|
||||
n, k = weight_shape
|
||||
weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported qtype {qtype}")
|
||||
return weight
|
||||
|
||||
|
||||
def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
|
||||
from ipex_llm.transformers.low_bit_linear import get_block_size
|
||||
if qtype == ggml_tensor_qtype["sym_int4"]:
|
||||
Q4_0 = get_block_size("sym_int4")
|
||||
n, k = weight_shape
|
||||
weight_size = n*k//2
|
||||
zeros_size = n*k//Q4_0//2
|
||||
scales_size = n*k//Q4_0 * 2
|
||||
xetla_weight_only = xetla_weight[:weight_size]
|
||||
scales_start = weight_size + zeros_size
|
||||
xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
|
||||
|
||||
qweight = xetla_weight_only.clone()
|
||||
scales = xetla_scales.view(torch.float16).clone()
|
||||
|
||||
qweight_0 = qweight & 0x0F
|
||||
qweight_1 = qweight >> 4
|
||||
qweight_0 = qweight_0.reshape(-1, 8, n)
|
||||
qweight_1 = qweight_1.reshape(-1, 8, n)
|
||||
qweight = torch.cat([qweight_0, qweight_1], dim=1)
|
||||
|
||||
qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
|
||||
2, Q4_0//2)
|
||||
qweight = qweight.bitwise_left_shift(
|
||||
torch.tensor([0, 4], dtype=torch.uint8,
|
||||
device=xetla_weight_only.device).reshape(1, 1, 2, 1))
|
||||
|
||||
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
|
||||
qweight = qweight.reshape(n, k//2)
|
||||
|
||||
scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
|
||||
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
|
||||
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
|
||||
Q8_0 = get_block_size("fp8_e5m2")
|
||||
n, k = weight_shape
|
||||
qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
|
||||
scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported qtype {qtype}")
|
||||
return weight
|
||||
|
||||
|
||||
def get_block_size(qtype: str):
|
||||
return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
|
||||
|
||||
|
|
@ -422,7 +315,6 @@ class FP4Params(torch.nn.Parameter):
|
|||
qtype=None,
|
||||
imatrix=None,
|
||||
in_features=None,
|
||||
enable_xetla=False,
|
||||
enable_scale_search=False):
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
|
|
@ -435,7 +327,6 @@ class FP4Params(torch.nn.Parameter):
|
|||
self.convert_shape_only = convert_shape_only
|
||||
self.imatrix = imatrix
|
||||
self.in_features = in_features
|
||||
self.enable_xetla = enable_xetla
|
||||
self.enable_scale_search = enable_scale_search
|
||||
return self
|
||||
|
||||
|
|
@ -529,8 +420,6 @@ class FP4Params(torch.nn.Parameter):
|
|||
self.data = ggml_q_format_convet_cpu2xpu(self.data,
|
||||
reduce(mul, self._shape, 1),
|
||||
self.qtype)
|
||||
if self.enable_xetla:
|
||||
self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
|
||||
new_param = FP4Params(super().to(device=device,
|
||||
dtype=dtype,
|
||||
non_blocking=non_blocking),
|
||||
|
|
@ -538,12 +427,7 @@ class FP4Params(torch.nn.Parameter):
|
|||
quantized=self.quantized,
|
||||
_shape=self._shape,
|
||||
qtype=self.qtype,
|
||||
enable_xetla=self.enable_xetla,
|
||||
enable_scale_search=self.enable_scale_search)
|
||||
if self.enable_xetla:
|
||||
device_type = get_xpu_device_type(new_param.data)
|
||||
invalidInputError(device_type == "pvc",
|
||||
f"xetla is only supported on PVC, but got {device_type}")
|
||||
return new_param
|
||||
elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
|
||||
new_param = FP4Params(super().to(device=device,
|
||||
|
|
@ -553,14 +437,8 @@ class FP4Params(torch.nn.Parameter):
|
|||
quantized=self.quantized,
|
||||
_shape=self._shape,
|
||||
qtype=self.qtype,
|
||||
enable_xetla=self.enable_xetla,
|
||||
enable_scale_search=self.enable_scale_search)
|
||||
if self.enable_xetla:
|
||||
ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
|
||||
new_param._shape,
|
||||
new_param.qtype)
|
||||
else:
|
||||
ggml_xpu = new_param.data
|
||||
ggml_xpu = new_param.data
|
||||
new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
|
||||
reduce(mul, new_param._shape, 1),
|
||||
new_param.qtype)
|
||||
|
|
@ -573,7 +451,6 @@ class FP4Params(torch.nn.Parameter):
|
|||
quantized=self.quantized,
|
||||
_shape=self._shape,
|
||||
qtype=self.qtype,
|
||||
enable_xetla=self.enable_xetla,
|
||||
enable_scale_search=self.enable_scale_search)
|
||||
return new_param
|
||||
|
||||
|
|
@ -691,14 +568,13 @@ class MatMulLowBitCPU(torch.autograd.Function):
|
|||
|
||||
class LowBitLinear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, qtype, bias=True,
|
||||
conver_to_half=True, mp_group=None, enable_xetla=False,
|
||||
conver_to_half=True, mp_group=None,
|
||||
optimize_lm_head=False, act_order=False,
|
||||
enable_scale_search=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.weight = FP4Params(self.weight.data,
|
||||
requires_grad=False,
|
||||
quantized=False, _shape=None, qtype=qtype,
|
||||
enable_xetla=enable_xetla,
|
||||
enable_scale_search=enable_scale_search)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
|
|
@ -708,7 +584,6 @@ class LowBitLinear(nn.Linear):
|
|||
self.conver_to_half = conver_to_half
|
||||
self.mp_group = mp_group
|
||||
self.compute_dtype = None # only for training
|
||||
self.enable_xetla = enable_xetla
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
self.device = None # detected only once in the first forward
|
||||
# empty cache before and after lm_head at first token (by default on arc)
|
||||
|
|
@ -799,9 +674,6 @@ class LowBitLinear(nn.Linear):
|
|||
self.weight.data,
|
||||
self.weight.qtype,
|
||||
input_seq_size)
|
||||
elif self.enable_xetla:
|
||||
x_2d = x_2d.half()
|
||||
result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
|
||||
else:
|
||||
# inference path
|
||||
# current workaround to reduce first token latency of fp32 input
|
||||
|
|
@ -880,8 +752,7 @@ class LowBitLinear(nn.Linear):
|
|||
|
||||
class FP16Linear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True,
|
||||
mp_group=None, weight_type=1, enable_xetla=False,
|
||||
optimize_lm_head=False):
|
||||
mp_group=None, weight_type=1, optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
|
|
@ -894,7 +765,6 @@ class FP16Linear(nn.Linear):
|
|||
# weigh_type = 3 means weight has been transposed by esimd method
|
||||
self.weight_type = 1
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
self.enable_xetla = enable_xetla
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# only work for GPU
|
||||
|
|
@ -1010,8 +880,7 @@ class FP16Linear(nn.Linear):
|
|||
|
||||
class BF16Linear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True,
|
||||
mp_group=None, compute_dtype=None, enable_xetla=False,
|
||||
optimize_lm_head=False):
|
||||
mp_group=None, compute_dtype=None, optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
|
|
@ -1021,7 +890,6 @@ class BF16Linear(nn.Linear):
|
|||
self.mp_group = mp_group
|
||||
self.compute_dtype = compute_dtype
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
self.enable_xetla = enable_xetla
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.optimize_lm_head:
|
||||
|
|
@ -1050,11 +918,11 @@ class BF16Linear(nn.Linear):
|
|||
|
||||
class vLLMLowBitLinear(LowBitLinear):
|
||||
def __init__(self, input_features, output_features, qtype, bias=True,
|
||||
conver_to_half=True, mp_group=None, enable_xetla=False,
|
||||
conver_to_half=True, mp_group=None,
|
||||
optimize_lm_head=False, act_order=False,
|
||||
enable_scale_search=False):
|
||||
super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
|
||||
enable_xetla, optimize_lm_head, act_order, enable_scale_search)
|
||||
optimize_lm_head, act_order, enable_scale_search)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
|
|
@ -1063,9 +931,9 @@ class vLLMLowBitLinear(LowBitLinear):
|
|||
|
||||
class vLLMFP16Linear(FP16Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
|
||||
enable_xetla=False, optimize_lm_head=False):
|
||||
optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias, mp_group, weight_type,
|
||||
enable_xetla, optimize_lm_head)
|
||||
optimize_lm_head)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
|
|
@ -1074,9 +942,9 @@ class vLLMFP16Linear(FP16Linear):
|
|||
|
||||
class vLLMBF16Linear(BF16Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, mp_group=None,
|
||||
compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
|
||||
compute_dtype=None, optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
|
||||
enable_xetla, optimize_lm_head)
|
||||
optimize_lm_head)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
|
|
|
|||
|
|
@ -448,7 +448,6 @@ class _BaseAutoModelClass:
|
|||
mixed_precision = kwargs.pop("mixed_precision", False)
|
||||
if embedding_qtype is not None:
|
||||
embedding_qtype = ggml_tensor_qtype[embedding_qtype]
|
||||
enable_xetla = kwargs.pop("enable_xetla", False)
|
||||
_args = copy.deepcopy(args)
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
awq_config = None
|
||||
|
|
@ -518,7 +517,6 @@ class _BaseAutoModelClass:
|
|||
torch_dtype=kwargs.get("torch_dtype", 'auto'),
|
||||
imatrix_data=imatrix_data,
|
||||
embedding_qtype=embedding_qtype,
|
||||
enable_xetla=enable_xetla,
|
||||
mixed_precision=mixed_precision)
|
||||
|
||||
if disk_embedding:
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ def baichuan_mlp_forward(
|
|||
) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
qtype = getattr(self.gate_proj, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
|
||||
if mlp_fusion_check(x_2d, qtype, self.training):
|
||||
import xe_linear
|
||||
if not x_2d.is_contiguous():
|
||||
x_2d = x_2d.contiguous()
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ def mixtral_mlp_forward(
|
|||
routing_weights
|
||||
) -> torch.Tensor:
|
||||
qtype = getattr(self.w1, "qtype", None)
|
||||
if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla:
|
||||
if mlp_fusion_check(x, qtype, self.training):
|
||||
import xe_linear
|
||||
return self.w2(xe_linear.mlp_forward_xpu(
|
||||
x, self.w1.weight.data, self.w3.weight.data,
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ def qwen_attention_forward_registered(
|
|||
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
qtype = getattr(self.w1, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla:
|
||||
if mlp_fusion_check(x_2d, qtype, self.training):
|
||||
import xe_linear
|
||||
if not x_2d.is_contiguous():
|
||||
x_2d = x_2d.contiguous()
|
||||
|
|
|
|||
|
|
@ -612,7 +612,7 @@ def qwen2_mlp_forward(
|
|||
) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
qtype = getattr(self.gate_proj, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
|
||||
if mlp_fusion_check(x_2d, qtype, self.training):
|
||||
import xe_linear
|
||||
return self.down_proj(xe_linear.mlp_forward_xpu(
|
||||
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
|
||||
|
|
|
|||
|
|
@ -337,8 +337,7 @@ def use_decoding_fast_path(proj,
|
|||
return False
|
||||
if bs != 1:
|
||||
return False
|
||||
if proj.enable_xetla:
|
||||
return False
|
||||
|
||||
if device in ["uhd"]:
|
||||
return False
|
||||
return True
|
||||
|
|
|
|||
Loading…
Reference in a new issue