[NPU] initial support of asym_int4_rtn (#12484)
* initial support of q4_1
* fix
* fix
* update
* update min to Z1
* update
* fix
* update
* fix style
* fix
* support qwen2 optimize_model=True mp version
* temp save
* fix
* fix style
* replace min with zero
* support split linear for q4_1
* fix lm_head with mixed_precision=True
* fix style
* revert test code
* add down proj back for q4_0
* remove print
This commit is contained in:
parent
60bafab855
commit
49ab8974fa
12 changed files with 264 additions and 81 deletions
@@ -52,6 +52,7 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "fp6_k": 30,
                      "sym_int4_rtn": 31,
                      "sym_int8_rtn": 32,
+                     "asym_int4_rtn": 33,
                      }

 # mixed precison from llama.cpp
@@ -84,8 +84,10 @@ Q5_K = ggml_tensor_qtype["q5_k"]
 FP6_K = ggml_tensor_qtype["fp6_k"]
 SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
+ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
 RTN_DTYPE = {
     SYM_INT4_RTN: torch.uint8,
+    ASYM_INT4_RTN: torch.uint8,
     SYM_INT8_RTN: torch.int8,
 }

@@ -223,10 +225,14 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                       f"Last dim of input tensor must be multiple of {QK}")

     dst_size = (n // QK) * block_size_in_bytes
-    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
                                  device=device)
         dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
-        scale = torch.empty(n // k, dtype=torch.float32,
-                            device=device)
+        if qtype == ASYM_INT4_RTN:
+            scale = torch.empty((n // k) * 2, dtype=torch.float32,
+                                device=device)
+        else:
+            scale = torch.empty(n // k, dtype=torch.float32,
+                                device=device)
     elif qtype == NF4:

@@ -244,7 +250,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
     hist = (ctypes.c_int64 * 16)()
     if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
             if imatrix is None:
                 ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,

@@ -269,7 +275,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                 ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
                                                        n // in_features, in_features,
                                                        hist, imatrix)
-    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         return dst_tensor, scale.type(torch.float16)
     else:
         return dst_tensor
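For asym_int4_rtn the native quantizer returns per-group zero points as well as scales, packed into a single float buffer; that is why the scale tensor above is allocated at twice the usual size (`(n // k) * 2`), and callers later split it into halves. Below is a conceptual sketch of asymmetric round-to-nearest INT4 for one group; it is illustrative only (the real work happens inside `ggml_quantize_tensor_rtn` in C, and the helper name here is made up).

```python
import torch

# Conceptual sketch only: the real quantization is done by the ggml C code;
# this helper name and the group size are illustrative.
def asym_int4_rtn_group(x: torch.Tensor):
    xmin, xmax = x.min(), x.max()
    scale = (xmax - xmin) / 15.0                  # unsigned 4-bit range 0..15
    zero = xmin                                   # "min" replaced by "zero" in this PR's wording
    q = torch.clamp(torch.round((x - zero) / scale), 0, 15).to(torch.uint8)
    return q, scale, zero

q, scale, zero = asym_int4_rtn_group(torch.randn(64))
dequant = q.to(torch.float32) * scale + zero      # mirrors the DequantizedLinear path below
```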
@@ -103,6 +103,7 @@ class _BaseAutoModelClass:
             qtype_map = {
                 "sym_int4": "sym_int4_rtn",
                 "sym_int8": "sym_int8_rtn",
+                "asym_int4": "asym_int4_rtn",
             }

             invalidInputError(
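With "asym_int4" added to qtype_map, loading presumably follows the same pattern as the existing sym_int4/sym_int8 NPU options. A hedged usage sketch (the module path and arguments mirror existing ipex-llm NPU examples; the model id is illustrative):

```python
# Hedged sketch: assumes the NPU AutoModel entry point used by existing
# ipex-llm NPU examples; the model id is only an example.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    load_in_low_bit="asym_int4",   # mapped to "asym_int4_rtn" by qtype_map above
    optimize_model=True,
    trust_remote_code=True,
)
```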
@@ -154,7 +155,7 @@ class _BaseAutoModelClass:
                     f"but got {quantization_group_size}"
                 )
             )
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)

         try:

@@ -270,6 +271,7 @@ class _BaseAutoModelClass:
         with torch.no_grad():
             model.config.update({"mixed_precision": mixed_precision})
             model.config.update({"group_size": quantization_group_size})
+            model.config.update({"asym": qtype == "asym_int4_rtn"})
             optimize_llm_pre(model, qtype, mixed_precision,
                              quantization_group_size=quantization_group_size)
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,

@@ -416,9 +418,9 @@ class _BaseAutoModelClass:
         )

         invalidInputError(
-            qtype in ["sym_int8_rtn", "sym_int4_rtn"],
+            qtype in ["sym_int8_rtn", "sym_int4_rtn", "asym_int4_rtn"],
             f"Unknown bigdl_transformers_low_bit value: {qtype},"
-            f" expected: sym_int8_rtn, sym_int4_rtn. "
+            f" expected: sym_int8_rtn, sym_int4_rtn, asym_int4_rtn. "
         )

         if enable_cpp_backend:
@@ -88,10 +88,13 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
-        if qtype == "sym_int4_rtn":
+        if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
             # workaround for qwen2-7B & int4
-            if (layer.in_features == 3584 and layer.out_features == 152064) or \
-                    (layer.in_features == 18944 and layer.out_features == 3584):
+            if (layer.in_features == 3584 and layer.out_features == 152064):
+                qtype = "sym_int8_rtn"
+                iqtype = ggml_tensor_qtype[qtype]
+        if qtype == "sym_int4_rtn":
+            if (layer.in_features == 18944 and layer.out_features == 3584):
                 qtype = "sym_int8_rtn"
                 iqtype = ggml_tensor_qtype[qtype]
         enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"

@@ -99,8 +102,12 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                              iqtype, device=device,
                                              enable_scale_search=enable_scale_search,
                                              imatrix=imatrix)
-        return QuantizedLinear(qweights, scale, layer.bias,
-                               group_size=group_size)
+        zero = None
+        # split scale to scale & zero
+        if qtype == "asym_int4_rtn":
+            scale, zero = torch.split(scale, scale.shape[0] // 2)
+        return QuantizedLinear(qweights, scale, zero, layer.bias,
+                               group_size=group_size, qtype=qtype)


 @module_optimization
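A short sketch of the split convention assumed above: the packed buffer returned for asym_int4_rtn is taken to hold the scales in its first half and the zero points in its second half along dim 0.

```python
import torch

packed = torch.arange(8, dtype=torch.float32)        # e.g. 4 scales followed by 4 zeros
scale, zero = torch.split(packed, packed.shape[0] // 2)
assert scale.shape == zero.shape == (4,)
```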
@@ -111,12 +118,21 @@ def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+        if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
+            # workaround for qwen2-7B & int4
+            if (layer.in_features == 3584 and layer.out_features == 152064):
+                qtype = "sym_int8_rtn"
+                iqtype = ggml_tensor_qtype[qtype]
         enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                              iqtype, device=device,
                                              enable_scale_search=enable_scale_search,
                                              imatrix=imatrix)
-        return DequantizedLinear(qweights, scale, layer.bias)
+        zero = None
+        # split scale to scale & zero
+        if qtype == "asym_int4_rtn":
+            scale, zero = torch.split(scale, scale.shape[0] // 2)
+        return DequantizedLinear(qweights, scale, zero, layer.bias, qtype)


 @module_optimization
@@ -128,7 +128,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         from ipex_llm.transformers.npu_models.common import split_linears
         if quantization_group_size == 0:
             n_splits_linear = 1
-            if qtype == "sym_int8_rtn":
+            if qtype in ["sym_int8_rtn", "asym_int4_rtn"]:
                 # do not split mlp down_proj for Qwen2-7B & sym_int8
                 n_splits_down_proj = 1
             else:

@@ -154,18 +154,21 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
             # workaround for MiniCPM-2B
             new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
                                          bias=model.lm_head_0.bias, use_split=True,
-                                         group_size=quantization_group_size)
+                                         group_size=quantization_group_size,
+                                         asym=(qtype == "asym_int4_rtn"))
             del model.lm_head_0
             model.lm_head_0 = new_lm_head_0
             new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
                                          bias=model.lm_head_1.bias, use_split=True,
-                                         group_size=quantization_group_size)
+                                         group_size=quantization_group_size,
+                                         asym=(qtype == "asym_int4_rtn"))
             del model.lm_head_1
             model.lm_head_1 = new_lm_head_1
         else:
             new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
                                        bias=model.lm_head.bias, use_split=True,
-                                       group_size=quantization_group_size)
+                                       group_size=quantization_group_size,
+                                       asym=(qtype == "asym_int4_rtn"))
             del model.lm_head
             model.lm_head = new_lm_head
@@ -176,11 +179,13 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
     # Do not split lm_head and use sym_int8 instead when mixed_precison is True
     if quantization_group_size == 0:
         # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-        is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+        is_split = (not mixed_precision) and qtype in ["sym_int4_rtn", "asym_int4_rtn"]
         split_num = 14 if is_split else 1
         new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
                                    bias=model.lm_head.bias, use_split=True,
-                                   group_size=quantization_group_size)
+                                   group_size=quantization_group_size,
+                                   asym=((qtype == "asym_int4_rtn") and
+                                         (not mixed_precision)))
         del model.lm_head
         model.lm_head = new_lm_head
@@ -129,7 +129,9 @@ class QuantizedLinear(torch.nn.Module):
         self,
         weight: torch.Tensor,
         scale: torch.Tensor,
+        zero: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        qtype: Optional[str] = "sym_int4_rtn",
         group_size: int = 0,
     ):
         """Initialize the QuantizedLinear class.

@@ -137,8 +139,10 @@ class QuantizedLinear(torch.nn.Module):
         Args:
             weight (torch.Tensor): Linear operation weight
             scale (torch.Tensor): Quantization scale
+            zero (Optional[torch.Tensor], optional): Quantization zero for asym_int4_rtn
             bias (Optional[torch.Tensor], optional): Linear operation optional bias.
                                                      Defaults to None.
+            qtype (Optional[str], optional): qtype of this Linear

         Raises:
             RuntimeError: Quantized weight must be in torch.int8 format

@@ -155,14 +159,19 @@ class QuantizedLinear(torch.nn.Module):
             )
         )
         self.outC, self.inC = self.weight.shape
+        self.zero = None
         if group_size != 0:
             self.scale = Parameter(scale, requires_grad=False)
+            self.zero = Parameter(zero, requires_grad=False)
         else:
             if self.weight.dtype == torch.uint8:
                 # Int4 we need to double the input channels because weights are compressed
                 self.inC *= 2
             self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
+            if zero is not None:
+                self.zero = Parameter(zero * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
+        self.qtype = qtype
         self.op_id = str(uuid.uuid4())

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -195,7 +204,8 @@ class QuantizedLinear(torch.nn.Module):
                 )
             )

-        out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
+        zero_data = self.zero.data if self.zero is not None else None
+        out = run_matmul(x, self.weight.data, self.scale.data, zero_data, self.op_id)

         if self.bias is None:
             return out
@@ -209,14 +219,18 @@ class DequantizedLinear(torch.nn.Module):
         self,
         weight: torch.Tensor,
         scale: torch.Tensor,
+        zero: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        qtype: Optional[str] = "sym_int4_rtn",
     ):
         """Initialize the DequantizedLinear class.
         Args:
             weight (torch.Tensor): Linear operation quantized weight
             scale (torch.Tensor): Quantization scale
+            zero (Optional[torch.Tensor], optional): Quantization zero for asym_int4_rtn
             bias (Optional[torch.Tensor], optional): Linear operation optional bias.
                                                      Defaults to None.
+            qtype (Optional[str], optional): qtype of this Linear
         Raises:
             RuntimeError: Quantized weight must be in torch.int8 format
         """
@@ -240,6 +254,9 @@ class DequantizedLinear(torch.nn.Module):
             decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
             dequantized_weight = decompressed_weight.to(torch.float32) * \
                 torch.unsqueeze(scale.to(torch.float32), dim=1)
+            if qtype == "asym_int4_rtn" and zero is not None:
+                dequantized_weight = dequantized_weight + torch.unsqueeze(zero.to(torch.float32),
+                                                                          dim=1)
             self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
         else:
             dequantized_weight = weight.to(torch.float32) * \
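For asym_int4_rtn the dequantized weight adds the per-output-channel zero point after scaling, i.e. W ≈ Q · scale + zero broadcast along the input dimension. A minimal sketch, with shapes assumed from the surrounding code:

```python
import torch

def dequantize_asym(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # q: [outC, inC] unpacked int4 values; scale, zero: [outC] (shapes assumed)
    w = q.to(torch.float32) * scale.to(torch.float32).unsqueeze(1)
    return w + zero.to(torch.float32).unsqueeze(1)
```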
@@ -36,6 +36,7 @@ class LMHeadLinear(NNFactory):
         dtype: np.dtype = np.int8,
         use_split: bool = False,
         group_size: int = 0,
+        asym: bool = False,
     ):
         """Initialize the LMHeadLinear class.

@@ -54,11 +55,10 @@ class LMHeadLinear(NNFactory):
         self.batch = batch

         self.split_num = split_num
-
         if use_split:
             input = self.parameter((1, self.batch, self.inC))
             res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
-                                       scale_factor=(group_size == 0))
+                                       scale_factor=(group_size == 0), asym=asym)
         else:
             input = self.parameter((self.batch, self.inC))
             split_size = self.inC // split_num // 2 * 2
@@ -69,7 +69,7 @@ class LMHeadLinear(NNFactory):
                 input_slice = self.slice(input, begin=[0, start_idx],
                                          end=[self.batch, end_idx])
                 linear_slice = self.linear(input_slice, outC, split_size, bias=False,
-                                           wt_dtype=dtype)
+                                           wt_dtype=dtype, asym=asym)
                 if i == 0:
                     res = linear_slice
                 else:

@@ -109,7 +109,7 @@ class LMHeadLinear(NNFactory):


 class SlicedLMHead(nn.Module):
-    def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
+    def __init__(self, weight, bias, split_num, use_split=False, group_size=0, asym=False):
         super().__init__()
         self.split_num = split_num
         self.outC, self.inC = weight.shape
@@ -128,6 +128,7 @@ class SlicedLMHead(nn.Module):
             self.lm_heads.append(new_linear)
         self.bias = bias
         self.use_split = use_split
+        self.asym = asym

     def forward(self, hidden_states):
         if hidden_states.size(0) * hidden_states.size(1) == 1:
@@ -162,15 +163,29 @@ class SlicedLMHead(nn.Module):
         np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
         self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
                                           False, "NPU", dtype=np_dtype, use_split=self.use_split,
-                                          group_size=self.group_size)
+                                          group_size=self.group_size, asym=self.asym)
         if self.use_split:
             weights = []
             scales = []
+            zeros = []
             for i in range(self.split_num):
                 weights.append(self.lm_heads[i].weight)
                 scales.append(self.lm_heads[i].scale)
-            fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(),
-                                     torch.stack(scales, axis=0).numpy())
+                if self.lm_heads[i].zero is not None:
+                    zeros.append(self.lm_heads[i].zero)
+            if len(zeros):
+                fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
+                                          torch.stack(scales, axis=0).numpy(),
+                                          torch.stack(zeros, axis=0).numpy())]
+            else:
+                fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
+                                          torch.stack(scales, axis=0).numpy())]
         else:
+            if self.asym:
+                fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
+                                          self.lm_heads[i].scale.data.numpy(),
+                                          self.lm_heads[i].zero.data.numpy())
+                                         for i in range(self.split_num)]
+            else:
                 fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
                                           self.lm_heads[i].scale.data.numpy())
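When the sliced LM head is asymmetric, each slice also carries a zero tensor, so the fused weights become 3-tuples (weight, scale, zero) instead of 2-tuples; downstream code (run_model and the fused decoder wrappers) branches on the tuple length. A small sketch of the layout, with a hypothetical slice count of 2 and placeholder tensors:

```python
import torch

# Hypothetical slice count of 2; tensors are placeholders, not real LM-head data.
slices = [{"w": torch.ones(4, 2), "s": torch.ones(4), "z": torch.zeros(4)} for _ in range(2)]
fused_sym = [(torch.stack([s["w"] for s in slices], axis=0).numpy(),
              torch.stack([s["s"] for s in slices], axis=0).numpy())]
fused_asym = [(torch.stack([s["w"] for s in slices], axis=0).numpy(),
               torch.stack([s["s"] for s in slices], axis=0).numpy(),
               torch.stack([s["z"] for s in slices], axis=0).numpy())]
```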
@@ -59,9 +59,16 @@ def run_model(
     op_args_flatten = []
     for w in weights:
         if isinstance(w, tuple):  # from QuantizedLinear
-            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
-            op_args_flatten.append(op_args[-1][0])
-            op_args_flatten.append(op_args[-1][1])
+            if len(w) == 2:
+                op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
+                op_args_flatten.append(op_args[-1][0])
+                op_args_flatten.append(op_args[-1][1])
+            else:
+                op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy(),
+                                set_contiguous(w[2]).numpy()))
+                op_args_flatten.append(op_args[-1][0])
+                op_args_flatten.append(op_args[-1][1])
+                op_args_flatten.append(op_args[-1][2])
         elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
             op_args.append(w.numpy())
             op_args_flatten.append(op_args[-1])
@@ -104,7 +111,7 @@ def run_model(
 class LLMBaseNNFactory(NNFactory):

     def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU",
-                 n_splits_linear=1, n_splits_down_proj=1, group_size=0):
+                 n_splits_linear=1, n_splits_down_proj=1, group_size=0, asym=False):
         super().__init__(profile, device)
         self.cache_parameter_ops = []
         self.input_ops = []

@@ -117,6 +124,7 @@ class LLMBaseNNFactory(NNFactory):
         self.n_splits_linear = n_splits_linear
         self.n_splits_down_proj = n_splits_down_proj
         self.group_size = group_size
+        self.asym = asym

     def attention(self,
                   *,
@@ -149,7 +157,8 @@ class LLMBaseNNFactory(NNFactory):
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         key_states = self.linear(

@@ -160,7 +169,8 @@ class LLMBaseNNFactory(NNFactory):
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         value_states = self.linear(

@@ -171,7 +181,8 @@ class LLMBaseNNFactory(NNFactory):
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         if q_bias is not None:

@@ -260,7 +271,8 @@ class LLMBaseNNFactory(NNFactory):
             attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         return attn_output, new_key_states, new_value_states

@@ -428,13 +440,15 @@ class LLMBaseNNFactory(NNFactory):
             hidden_states, self.intermediate_size, self.hidden_size, bias=False,
             wt_dtype=self.dtype, n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         mm2 = self.linear(
             hidden_states, self.intermediate_size, self.hidden_size, bias=False,
             wt_dtype=self.dtype, n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

@@ -442,7 +456,8 @@ class LLMBaseNNFactory(NNFactory):
             mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_down_proj,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         return hidden_states

@@ -558,17 +573,20 @@ class LLMBaseNNFactory(NNFactory):
                wt_dtype: npt.DTypeLike = np.float16,
                n_splits: int = 1,
                scale_factor: bool = True,
-               is_prefill: bool = False):
+               is_prefill: bool = False,
+               asym: bool = False):
         if n_splits == 1:
             op = super().linear(input_node, output_channels,
                                 input_channels, bias, act_dtype,
-                                wt_dtype, scale_factor=scale_factor)
+                                wt_dtype, scale_factor=scale_factor,
+                                asym=asym)
         else:
             op = super().dq_split_linear(input_node, n_splits,
                                          output_channels, input_channels,
                                          bias=bias, act_dtype=act_dtype,
                                          wt_dtype=wt_dtype, scale_factor=scale_factor,
-                                         is_prefill=is_prefill)
+                                         is_prefill=is_prefill,
+                                         asym=asym)
         self.linear_ops.append(op)
         return op

@@ -580,10 +598,11 @@ class LLMBaseNNFactory(NNFactory):
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
                         scale_factor: bool = False,
-                        is_prefill: bool = False):
+                        is_prefill: bool = False,
+                        asym: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
                                      False, act_dtype, wt_dtype, scale_factor,
                                      is_prefill=is_prefill, asym=asym)
         self.linear_ops.append(op)
         return op

@@ -97,7 +97,8 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         intermediate_size,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,

@@ -106,7 +107,8 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -311,6 +313,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
         group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()

@@ -318,8 +321,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):

         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
@@ -375,7 +380,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)

@@ -461,6 +467,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
         group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters

@@ -491,7 +498,8 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -580,6 +588,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn

@@ -592,9 +601,16 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)

@@ -630,7 +646,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym
     )

     dist.barrier()
@@ -809,6 +826,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn

@@ -821,9 +839,16 @@ def run_prefill(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)

@@ -850,7 +875,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

         layer_weights.extend(weights)
@@ -86,6 +86,7 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
         device: str = "NPU",
         n_splits: int = 1,
         group_size: int = 0,
+        asym: bool = False
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,

@@ -119,6 +120,7 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
             hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=n_splits,
             scale_factor=(group_size == 0),
+            asym=asym
         )

         # define outputs
@@ -201,7 +201,7 @@ def convert_llm(model: torch.nn.Module,
     layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "1") == "1"
     if group_size == 0:
         n_splits_linear = 1
-        if qtype == "sym_int8_rtn":
+        if qtype in ["sym_int8_rtn", "asym_int4_rtn"]:
             # do not split mlp down_proj for Qwen2-7B & sym_int8
             n_splits_down_proj = 1
         else:

@@ -434,6 +434,12 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         os.mkdir(weight_dir)
     layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "1") == "1"

+    lm_head_low_bit = getattr(model.config, "bigdl_transformers_low_bit", "sym_int4_rtn")
+    if not isinstance(model.lm_head, SlicedLMHead):
+        lm_head_low_bit = model.lm_head.qtype
+    else:
+        lm_head_low_bit = model.lm_head.lm_heads[0].qtype
+
     if model.config.model_type == "qwen2":
         if group_size == 0:
             if model.config.hidden_size == 1536:

@@ -456,7 +462,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "weight_num": 7,
                        "weight_idx": 8,
                        "n_splits_linear": n_splits_linear,
-                       "n_splits_down_proj": n_splits_down_proj}
+                       "n_splits_down_proj": n_splits_down_proj,
+                       "lm_head_low_bit": lm_head_low_bit}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)

@@ -517,7 +524,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "embedding_post": embedding_post,
                        "cos_sin_input": cos_sin_input,
                        "n_splits_linear": n_splits_linear,
-                       "n_splits_down_proj": n_splits_down_proj}
+                       "n_splits_down_proj": n_splits_down_proj,
+                       "lm_head_low_bit": lm_head_low_bit}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)

@@ -556,7 +564,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "model_type": "minicpm",
                        "embedding_post": True,
                        "n_splits_linear": n_splits_linear,
-                       "n_splits_down_proj": n_splits_down_proj}
+                       "n_splits_down_proj": n_splits_down_proj,
+                       "lm_head_low_bit": lm_head_low_bit}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)
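The lm_head qtype is recorded in the saved config so the deployment runtime knows how the LM head was quantized. An illustrative view of the extra field (values are examples, not taken from a real run):

```python
# Illustrative fields added to the saved config by the update_dict above.
update_dict = {
    "n_splits_linear": 1,
    "n_splits_down_proj": 1,
    "lm_head_low_bit": "asym_int4_rtn",  # how the LM head was quantized
}
```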
@@ -31,15 +31,30 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
     model_norm = model.model.norm
     lm_head = model.lm_head
     lm_head_n_splits = 1
+    asym = getattr(model.config, "asym", False)
+
     if not isinstance(lm_head, SlicedLMHead):
-        weights = [(lm_head.weight, lm_head.scale)]
+        asym = lm_head.qtype == "asym_int4_rtn"
+        if asym:
+            weights = [(lm_head.weight, lm_head.scale, lm_head.zero)]
+        else:
+            weights = [(lm_head.weight, lm_head.scale)]
     else:
         lm_heads = lm_head.lm_heads
+        asym = lm_heads[0].qtype == "asym_int4_rtn"
         lm_head_weights = []
         scales = []
+        zeros = []
         for l in lm_heads:
             lm_head_weights.append(l.weight)
             scales.append(l.scale)
-        weights = [(torch.stack(lm_head_weights, axis=0),
-                    torch.stack(scales, axis=0))]
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0),
+                        torch.stack(zeros, axis=0))]
+        else:
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0))]
         lm_head_n_splits = lm_head.split_num

@@ -60,6 +75,7 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
         vocab_size=vocab_size,
         n_splits=lm_head_n_splits,
         group_size=group_size,
+        asym=asym
     )

     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, f"lm_head",

@@ -67,9 +83,15 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,

     # save weights bins files
     if not isinstance(lm_head, SlicedLMHead):
-        weight_numpy = [
-            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-        ]
+        if not asym:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+            ]
+        else:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+                lm_head.zero.data.numpy()
+            ]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]

@@ -104,6 +126,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     head_dim = model.model.layers[0].self_attn.head_dim
     intermediate_size = model.config.intermediate_size
     rms_norm_eps = model.config.rms_norm_eps
+    asym = getattr(model.config, "asym", False)

     from ipex_llm.transformers.npu_models.qwen2_mp import LowBitQwenMultiDecoderlayer
     curr_layer = model.model.layers[layer_idx]

@@ -117,9 +140,16 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)

@@ -164,7 +194,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                             decoder_name,

@@ -188,11 +219,23 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         k_bias.data.numpy().tofile(k_bias_bin_file)
         v_bias.data.numpy().tofile(v_bias_bin_file)
         # 6, 7 are past k/v
-        for idx, (weight, scale) in enumerate(weights):
-            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
-            weight.numpy().tofile(bin_file)
-            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
-            scale.numpy().tofile(bin_file)
+        if not asym:
+            for idx, (weight, scale) in enumerate(weights):
+                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+3+idx*2}.bin")
+                weight.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+3+idx*2+1}.bin")
+                scale.numpy().tofile(bin_file)
+        else:
+            for idx, (weight, scale, zero) in enumerate(weights):
+                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+3+idx*3}.bin")
+                weight.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+3+idx*3+1}.bin")
+                scale.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+3+idx*3+2}.bin")
+                zero.numpy().tofile(bin_file)

         del single_decoder

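With asym enabled, each weight group of a decoder layer is exported as three bin files (weight, scale, zero) and the per-group index stride grows from 2 to 3. A small sketch of the resulting file-name pattern; the layer, index, and directory values are made up for illustration, only the naming scheme mirrors the loop above.

```python
import os

# Illustrative values; only the naming pattern follows the export loop above.
weight_dir, layer_idx, st_idx = "model_weights", 0, 5
for idx in range(2):                       # two weight groups, for illustration
    base = st_idx + 3 + idx * 3            # stride of 3 per group in the asym case
    files = [os.path.join(weight_dir, f"model_{layer_idx}_input_{base + k}.bin")
             for k in range(3)]            # weight, scale, zero
    print(files)
```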
@@ -207,6 +250,7 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
     rms_norm_eps = model.config.rms_norm_eps
     layer_num = len(model.model.layers)
     fused_layer_num = layer_num // fused_layers
+    asym = getattr(model.config, "asym", False)

     from ipex_llm.transformers.npu_models.qwen2_mp import LowBitQwenMultiDecoderlayer
     for i in range(fused_layers):

@@ -233,9 +277,16 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
                                mlp_layer.down_proj_dq_list]:
                 l_weights = []
                 scales = []
+                zeros = []
                 for l in layer_list:
                     l_weights.append(l.weight)
                     scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                    if l.zero is not None:
+                        zeros.append(l.zero)
+                if len(zeros):
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                    torch.stack(zeros, axis=0)))
+                else:
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)

@@ -264,12 +315,25 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
             k_biases[-1].data.numpy().tofile(k_bias_bin_file)
             v_biases[-1].data.numpy().tofile(v_bias_bin_file)
             # 6, 7 are past k/v
-            for idx, (weight, scale) in enumerate(weights):
-                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+3+idx*2}.bin")
-                weight.numpy().tofile(bin_file)
-                bin_file = os.path.join(weight_dir,
-                                        f"model_{layer_idx}_input_{st_idx+3+idx*2+1}.bin")
-                scale.numpy().tofile(bin_file)
+            if not asym:
+                for idx, (weight, scale) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+3+idx*2}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+3+idx*2+1}.bin")
+                    scale.numpy().tofile(bin_file)
+            else:
+                for idx, (weight, scale, zero) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+3+idx*3}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+3+idx*3+1}.bin")
+                    scale.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+3+idx*3+2}.bin")
+                    zero.numpy().tofile(bin_file)

         if isinstance(weights[0], tuple):
             np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8

@@ -296,7 +360,8 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         update_names_of_IR_and_export_blob(fused_decoder,
                                            f"decoder_layer_{i}",