remove unused code again (#12624)

This commit is contained in:
Yishuo Wang 2024-12-27 14:17:11 +08:00 committed by GitHub
parent 46eeab4479
commit c72a5db757
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 19 additions and 171 deletions

View file

@ -92,8 +92,7 @@ def train(
load_in_low_bit="bf16", load_in_low_bit="bf16",
optimize_model=True, optimize_model=True,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
trust_remote_code=True, trust_remote_code=True
enable_xetla=False
) )
model = model.to("xpu") model = model.to("xpu")

View file

@ -257,8 +257,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
optimize_model=optimize_llm, optimize_model=optimize_llm,
modules_to_not_convert=modules_to_not_convert, modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding, cpu_embedding=cpu_embedding,
lightweight_bmm=lightweight_bmm, lightweight_bmm=lightweight_bmm)
enable_xetla=kwargs.pop("enable_xetla", False))
# add save_low_bit to pretrained model dynamically # add save_low_bit to pretrained model dynamically
import types import types
model._bigdl_config = dict() model._bigdl_config = dict()

View file

@ -232,7 +232,7 @@ def is_linear_module(module):
def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype, def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
enable_xetla, optimize_lm_head, enable_scale_search): optimize_lm_head, enable_scale_search):
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from ipex_llm.transformers.low_bit_linear import LowBitLinear, \ from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
@ -261,7 +261,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
cur_qtype, cur_qtype,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head, optimize_lm_head=optimize_lm_head,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
) )
@ -289,7 +288,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
cur_qtype, cur_qtype,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head, optimize_lm_head=optimize_lm_head,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
) )
@ -473,7 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
prefix_name='', prefix_name='',
imatrix_data=None, embedding_qtype=None, imatrix_data=None, embedding_qtype=None,
model_config=None, torch_dtype=torch.float32, model_config=None, torch_dtype=torch.float32,
enable_xetla=False,
mixed_precision=False, mixed_precision=False,
act_order=False, act_order=False,
enable_scale_search=False, enable_scale_search=False,
@ -523,7 +520,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=qtype, qtype=qtype,
bias=has_bias, bias=has_bias,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head, optimize_lm_head=optimize_lm_head,
act_order=act_order, act_order=act_order,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
@ -544,7 +540,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
_shape=(out_features, in_features), _shape=(out_features, in_features),
convert_shape_only=convert_shape_only, convert_shape_only=convert_shape_only,
qtype=qtype, qtype=qtype,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device) enable_scale_search=enable_scale_search).to(device)
new_linear._parameters['weight'] = paramsLowBit new_linear._parameters['weight'] = paramsLowBit
if has_bias: if has_bias:
@ -562,7 +557,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=qtype, qtype=qtype,
bias=has_bias, bias=has_bias,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False, optimize_lm_head=False,
act_order=act_order, act_order=act_order,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
@ -581,7 +575,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=cur_qtype, qtype=cur_qtype,
imatrix=cur_imatrix, imatrix=cur_imatrix,
in_features=in_features, in_features=in_features,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device) enable_scale_search=enable_scale_search).to(device)
else: else:
new_linear = vLLMLowBitLinear( new_linear = vLLMLowBitLinear(
@ -590,7 +583,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=qtype, qtype=qtype,
bias=has_bias, bias=has_bias,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False, optimize_lm_head=False,
act_order=act_order, act_order=act_order,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
@ -609,7 +601,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
_shape=(out_features, in_features), _shape=(out_features, in_features),
convert_shape_only=convert_shape_only, convert_shape_only=convert_shape_only,
qtype=qtype, qtype=qtype,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device) enable_scale_search=enable_scale_search).to(device)
new_linear._parameters['weight'] = paramsLowBit new_linear._parameters['weight'] = paramsLowBit
if has_bias: if has_bias:
@ -639,7 +630,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features, out_features,
mp_group, mp_group,
cur_qtype, cur_qtype,
enable_xetla,
optimize_lm_head, optimize_lm_head,
enable_scale_search) enable_scale_search)
else: else:
@ -649,7 +639,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
cur_qtype, cur_qtype,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head, optimize_lm_head=optimize_lm_head,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
) )
@ -663,7 +652,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
qtype=cur_qtype, qtype=cur_qtype,
imatrix=cur_imatrix, imatrix=cur_imatrix,
in_features=in_features, in_features=in_features,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device) enable_scale_search=enable_scale_search).to(device)
new_linear._parameters['weight'] = paramsLowBit new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None: if module.bias is not None:
@ -762,7 +750,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
embedding_qtype=embedding_qtype, embedding_qtype=embedding_qtype,
model_config=model_config, model_config=model_config,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision, mixed_precision=mixed_precision,
act_order=act_order, act_order=act_order,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,
@ -1094,7 +1081,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
lightweight_bmm=False, torch_dtype="auto", lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None, imatrix_data=None,
embedding_qtype=None, embedding_qtype=None,
enable_xetla=False,
mixed_precision=False): mixed_precision=False):
if qtype in ggml_tensor_qtype.values(): if qtype in ggml_tensor_qtype.values():
index = list(ggml_tensor_qtype.values()).index(qtype) index = list(ggml_tensor_qtype.values()).index(qtype)
@ -1138,7 +1124,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
embedding_qtype=embedding_qtype, embedding_qtype=embedding_qtype,
model_config=model_config, model_config=model_config,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision, mixed_precision=mixed_precision,
act_order=act_order, act_order=act_order,
enable_scale_search=enable_scale_search, enable_scale_search=enable_scale_search,

View file

@ -92,113 +92,6 @@ RTN_DTYPE = {
} }
# For sym_int4
# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
#
# The returning weight is row major and packs two rows at a stride of 16//2.
# 16 is the tile_size_y used in mm_xetla, so that we can do something like
# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
#
# A more complex packing strategy is to permute the weight so that the
# new_weight_tile is directly VNNI packed, but I did not find significant
# performance improvement.
#
# Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
# row major but packing two consecutive columns.
#
# For fp8, just remove the scales (which are all ones) and transpose
def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
if qtype == ggml_tensor_qtype["sym_int4"]:
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_0 = get_block_size("sym_int4")
n, k = weight_shape
ggml_weight_only = ggml_weight[:n*k//2]
ggml_scales = ggml_weight[n*k//2:]
qweight = ggml_weight_only.clone()
scales = ggml_scales.view(torch.float16).clone()
qweight_0 = qweight & 0x0F
qweight_1 = qweight >> 4
qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
qweight = torch.cat([qweight_0, qweight_1], dim=-1)
qweight = qweight.reshape(n, k//16, 2, 8)
qweight = qweight.bitwise_left_shift(
torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
qweight = qweight.reshape(n, k//2)
qweight = qweight.transpose(0, 1).contiguous()
scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
# 119 is the value of 0x77
zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
qweight_bytes = qweight.view(torch.uint8).view(-1)
scales_bytes = scales.view(torch.uint8).view(-1)
zeros_bytes = zeros.view(torch.uint8).view(-1)
weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
n, k = weight_shape
weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
else:
invalidInputError(False, f"Unsupported qtype {qtype}")
return weight
def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
from ipex_llm.transformers.low_bit_linear import get_block_size
if qtype == ggml_tensor_qtype["sym_int4"]:
Q4_0 = get_block_size("sym_int4")
n, k = weight_shape
weight_size = n*k//2
zeros_size = n*k//Q4_0//2
scales_size = n*k//Q4_0 * 2
xetla_weight_only = xetla_weight[:weight_size]
scales_start = weight_size + zeros_size
xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
qweight = xetla_weight_only.clone()
scales = xetla_scales.view(torch.float16).clone()
qweight_0 = qweight & 0x0F
qweight_1 = qweight >> 4
qweight_0 = qweight_0.reshape(-1, 8, n)
qweight_1 = qweight_1.reshape(-1, 8, n)
qweight = torch.cat([qweight_0, qweight_1], dim=1)
qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
2, Q4_0//2)
qweight = qweight.bitwise_left_shift(
torch.tensor([0, 4], dtype=torch.uint8,
device=xetla_weight_only.device).reshape(1, 1, 2, 1))
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
qweight = qweight.reshape(n, k//2)
scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
qweight_bytes = qweight.view(torch.uint8).view(-1)
scales_bytes = scales.view(torch.uint8).view(-1)
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
Q8_0 = get_block_size("fp8_e5m2")
n, k = weight_shape
qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
qweight_bytes = qweight.view(torch.uint8).view(-1)
scales_bytes = scales.view(torch.uint8).view(-1)
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
else:
invalidInputError(False, f"Unsupported qtype {qtype}")
return weight
def get_block_size(qtype: str): def get_block_size(qtype: str):
return ggml.ggml_qk_size(ggml_tensor_qtype[qtype]) return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
@ -422,7 +315,6 @@ class FP4Params(torch.nn.Parameter):
qtype=None, qtype=None,
imatrix=None, imatrix=None,
in_features=None, in_features=None,
enable_xetla=False,
enable_scale_search=False): enable_scale_search=False):
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
@ -435,7 +327,6 @@ class FP4Params(torch.nn.Parameter):
self.convert_shape_only = convert_shape_only self.convert_shape_only = convert_shape_only
self.imatrix = imatrix self.imatrix = imatrix
self.in_features = in_features self.in_features = in_features
self.enable_xetla = enable_xetla
self.enable_scale_search = enable_scale_search self.enable_scale_search = enable_scale_search
return self return self
@ -529,8 +420,6 @@ class FP4Params(torch.nn.Parameter):
self.data = ggml_q_format_convet_cpu2xpu(self.data, self.data = ggml_q_format_convet_cpu2xpu(self.data,
reduce(mul, self._shape, 1), reduce(mul, self._shape, 1),
self.qtype) self.qtype)
if self.enable_xetla:
self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
new_param = FP4Params(super().to(device=device, new_param = FP4Params(super().to(device=device,
dtype=dtype, dtype=dtype,
non_blocking=non_blocking), non_blocking=non_blocking),
@ -538,12 +427,7 @@ class FP4Params(torch.nn.Parameter):
quantized=self.quantized, quantized=self.quantized,
_shape=self._shape, _shape=self._shape,
qtype=self.qtype, qtype=self.qtype,
enable_xetla=self.enable_xetla,
enable_scale_search=self.enable_scale_search) enable_scale_search=self.enable_scale_search)
if self.enable_xetla:
device_type = get_xpu_device_type(new_param.data)
invalidInputError(device_type == "pvc",
f"xetla is only supported on PVC, but got {device_type}")
return new_param return new_param
elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"): elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
new_param = FP4Params(super().to(device=device, new_param = FP4Params(super().to(device=device,
@ -553,13 +437,7 @@ class FP4Params(torch.nn.Parameter):
quantized=self.quantized, quantized=self.quantized,
_shape=self._shape, _shape=self._shape,
qtype=self.qtype, qtype=self.qtype,
enable_xetla=self.enable_xetla,
enable_scale_search=self.enable_scale_search) enable_scale_search=self.enable_scale_search)
if self.enable_xetla:
ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
new_param._shape,
new_param.qtype)
else:
ggml_xpu = new_param.data ggml_xpu = new_param.data
new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu, new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
reduce(mul, new_param._shape, 1), reduce(mul, new_param._shape, 1),
@ -573,7 +451,6 @@ class FP4Params(torch.nn.Parameter):
quantized=self.quantized, quantized=self.quantized,
_shape=self._shape, _shape=self._shape,
qtype=self.qtype, qtype=self.qtype,
enable_xetla=self.enable_xetla,
enable_scale_search=self.enable_scale_search) enable_scale_search=self.enable_scale_search)
return new_param return new_param
@ -691,14 +568,13 @@ class MatMulLowBitCPU(torch.autograd.Function):
class LowBitLinear(nn.Linear): class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True, def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False, conver_to_half=True, mp_group=None,
optimize_lm_head=False, act_order=False, optimize_lm_head=False, act_order=False,
enable_scale_search=False): enable_scale_search=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data, self.weight = FP4Params(self.weight.data,
requires_grad=False, requires_grad=False,
quantized=False, _shape=None, qtype=qtype, quantized=False, _shape=None, qtype=qtype,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search) enable_scale_search=enable_scale_search)
self.in_len = input_features self.in_len = input_features
self.out_len = output_features self.out_len = output_features
@ -708,7 +584,6 @@ class LowBitLinear(nn.Linear):
self.conver_to_half = conver_to_half self.conver_to_half = conver_to_half
self.mp_group = mp_group self.mp_group = mp_group
self.compute_dtype = None # only for training self.compute_dtype = None # only for training
self.enable_xetla = enable_xetla
self.optimize_lm_head = optimize_lm_head self.optimize_lm_head = optimize_lm_head
self.device = None # detected only once in the first forward self.device = None # detected only once in the first forward
# empty cache before and after lm_head at first token (by default on arc) # empty cache before and after lm_head at first token (by default on arc)
@ -799,9 +674,6 @@ class LowBitLinear(nn.Linear):
self.weight.data, self.weight.data,
self.weight.qtype, self.weight.qtype,
input_seq_size) input_seq_size)
elif self.enable_xetla:
x_2d = x_2d.half()
result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
else: else:
# inference path # inference path
# current workaround to reduce first token latency of fp32 input # current workaround to reduce first token latency of fp32 input
@ -880,8 +752,7 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear): class FP16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True, def __init__(self, input_features, output_features, bias=True,
mp_group=None, weight_type=1, enable_xetla=False, mp_group=None, weight_type=1, optimize_lm_head=False):
optimize_lm_head=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.in_len = input_features self.in_len = input_features
self.out_len = output_features self.out_len = output_features
@ -894,7 +765,6 @@ class FP16Linear(nn.Linear):
# weigh_type = 3 means weight has been transposed by esimd method # weigh_type = 3 means weight has been transposed by esimd method
self.weight_type = 1 self.weight_type = 1
self.optimize_lm_head = optimize_lm_head self.optimize_lm_head = optimize_lm_head
self.enable_xetla = enable_xetla
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
# only work for GPU # only work for GPU
@ -1010,8 +880,7 @@ class FP16Linear(nn.Linear):
class BF16Linear(nn.Linear): class BF16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True, def __init__(self, input_features, output_features, bias=True,
mp_group=None, compute_dtype=None, enable_xetla=False, mp_group=None, compute_dtype=None, optimize_lm_head=False):
optimize_lm_head=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.in_len = input_features self.in_len = input_features
self.out_len = output_features self.out_len = output_features
@ -1021,7 +890,6 @@ class BF16Linear(nn.Linear):
self.mp_group = mp_group self.mp_group = mp_group
self.compute_dtype = compute_dtype self.compute_dtype = compute_dtype
self.optimize_lm_head = optimize_lm_head self.optimize_lm_head = optimize_lm_head
self.enable_xetla = enable_xetla
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if self.optimize_lm_head: if self.optimize_lm_head:
@ -1050,11 +918,11 @@ class BF16Linear(nn.Linear):
class vLLMLowBitLinear(LowBitLinear): class vLLMLowBitLinear(LowBitLinear):
def __init__(self, input_features, output_features, qtype, bias=True, def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False, conver_to_half=True, mp_group=None,
optimize_lm_head=False, act_order=False, optimize_lm_head=False, act_order=False,
enable_scale_search=False): enable_scale_search=False):
super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group, super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
enable_xetla, optimize_lm_head, act_order, enable_scale_search) optimize_lm_head, act_order, enable_scale_search)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
result = super().forward(x) result = super().forward(x)
@ -1063,9 +931,9 @@ class vLLMLowBitLinear(LowBitLinear):
class vLLMFP16Linear(FP16Linear): class vLLMFP16Linear(FP16Linear):
def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1, def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
enable_xetla=False, optimize_lm_head=False): optimize_lm_head=False):
super().__init__(input_features, output_features, bias, mp_group, weight_type, super().__init__(input_features, output_features, bias, mp_group, weight_type,
enable_xetla, optimize_lm_head) optimize_lm_head)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
result = super().forward(x) result = super().forward(x)
@ -1074,9 +942,9 @@ class vLLMFP16Linear(FP16Linear):
class vLLMBF16Linear(BF16Linear): class vLLMBF16Linear(BF16Linear):
def __init__(self, input_features, output_features, bias=True, mp_group=None, def __init__(self, input_features, output_features, bias=True, mp_group=None,
compute_dtype=None, enable_xetla=False, optimize_lm_head=False): compute_dtype=None, optimize_lm_head=False):
super().__init__(input_features, output_features, bias, mp_group, compute_dtype, super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
enable_xetla, optimize_lm_head) optimize_lm_head)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
result = super().forward(x) result = super().forward(x)

View file

@ -448,7 +448,6 @@ class _BaseAutoModelClass:
mixed_precision = kwargs.pop("mixed_precision", False) mixed_precision = kwargs.pop("mixed_precision", False)
if embedding_qtype is not None: if embedding_qtype is not None:
embedding_qtype = ggml_tensor_qtype[embedding_qtype] embedding_qtype = ggml_tensor_qtype[embedding_qtype]
enable_xetla = kwargs.pop("enable_xetla", False)
_args = copy.deepcopy(args) _args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs) _kwargs = copy.deepcopy(kwargs)
awq_config = None awq_config = None
@ -518,7 +517,6 @@ class _BaseAutoModelClass:
torch_dtype=kwargs.get("torch_dtype", 'auto'), torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data, imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype, embedding_qtype=embedding_qtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision) mixed_precision=mixed_precision)
if disk_embedding: if disk_embedding:

View file

@ -67,7 +67,7 @@ def baichuan_mlp_forward(
) -> torch.Tensor: ) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1]) x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None) qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla: if mlp_fusion_check(x_2d, qtype, self.training):
import xe_linear import xe_linear
if not x_2d.is_contiguous(): if not x_2d.is_contiguous():
x_2d = x_2d.contiguous() x_2d = x_2d.contiguous()

View file

@ -380,7 +380,7 @@ def mixtral_mlp_forward(
routing_weights routing_weights
) -> torch.Tensor: ) -> torch.Tensor:
qtype = getattr(self.w1, "qtype", None) qtype = getattr(self.w1, "qtype", None)
if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla: if mlp_fusion_check(x, qtype, self.training):
import xe_linear import xe_linear
return self.w2(xe_linear.mlp_forward_xpu( return self.w2(xe_linear.mlp_forward_xpu(
x, self.w1.weight.data, self.w3.weight.data, x, self.w1.weight.data, self.w3.weight.data,

View file

@ -259,7 +259,7 @@ def qwen_attention_forward_registered(
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1]) x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.w1, "qtype", None) qtype = getattr(self.w1, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla: if mlp_fusion_check(x_2d, qtype, self.training):
import xe_linear import xe_linear
if not x_2d.is_contiguous(): if not x_2d.is_contiguous():
x_2d = x_2d.contiguous() x_2d = x_2d.contiguous()

View file

@ -612,7 +612,7 @@ def qwen2_mlp_forward(
) -> torch.Tensor: ) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1]) x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None) qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla: if mlp_fusion_check(x_2d, qtype, self.training):
import xe_linear import xe_linear
return self.down_proj(xe_linear.mlp_forward_xpu( return self.down_proj(xe_linear.mlp_forward_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,

View file

@ -337,8 +337,7 @@ def use_decoding_fast_path(proj,
return False return False
if bs != 1: if bs != 1:
return False return False
if proj.enable_xetla:
return False
if device in ["uhd"]: if device in ["uhd"]:
return False return False
return True return True