Fix abnormal output for Qwen2-7B when sym_int8 (#12446)

This commit is contained in:
Yuwen Hu 2024-11-26 15:53:04 +08:00 committed by GitHub
parent 71e1f11aa6
commit 303b104c10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -128,7 +128,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
     from ipex_llm.transformers.npu_models.common import split_linears
     if quantization_group_size == 0:
         n_splits_linear = 1
-        n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+        if qtype == "sym_int8_rtn":
+            # do not split mlp down_proj for Qwen2-7B & sym_int8
+            n_splits_down_proj = 1
+        else:
+            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
     else:
         invalidInputError(
             model.config.hidden_size % quantization_group_size == 0 and