Fix abnormal output for Qwen2-7B when sym_int8 (#12446)
parent 71e1f11aa6
commit 303b104c10
1 changed file with 5 additions and 1 deletion
@@ -128,6 +128,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
     from ipex_llm.transformers.npu_models.common import split_linears
     if quantization_group_size == 0:
         n_splits_linear = 1
-        n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+        if qtype == "sym_int8_rtn":
+            # do not split mlp down_proj for Qwen2-7B & sym_int8
+            n_splits_down_proj = 1
+        else:
+            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
     else:
         invalidInputError(
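For context, a minimal sketch of what n_splits_down_proj controls. It assumes, for illustration only, that splitting a Linear means partitioning its weight along the input dimension so the partial products sum to the original output; the helper split_linear below and its signature are hypothetical and not ipex_llm's actual split_linears API.

import torch
import torch.nn as nn

def split_linear(linear: nn.Linear, n_splits: int) -> nn.ModuleList:
    # Hypothetical sketch, not ipex_llm's split_linears: partition the
    # weight [out_features, in_features] into n_splits column blocks.
    # The partial outputs of the blocks sum to the unsplit output
    # (bias, if any, would be added once after the sum).
    in_features = linear.in_features
    assert in_features % n_splits == 0, "input dim must divide evenly"
    chunk = in_features // n_splits
    parts = nn.ModuleList()
    with torch.no_grad():
        for i in range(n_splits):
            part = nn.Linear(chunk, linear.out_features, bias=False)
            part.weight.copy_(linear.weight[:, i * chunk:(i + 1) * chunk])
            parts.append(part)
    return parts

# Sanity check with Qwen2-7B's down_proj shape
# (intermediate_size 18944 -> hidden_size 3584, no bias).
down_proj = nn.Linear(18944, 3584, bias=False)
x = torch.randn(1, 18944)
parts = split_linear(down_proj, 2)
y_split = sum(p(x[:, i * 9472:(i + 1) * 9472]) for i, p in enumerate(parts))
assert torch.allclose(y_split, down_proj(x), atol=1e-4)

The intermediate_size == 18944 check is how the code identifies Qwen2-7B, whose down_proj was previously split in two. With sym_int8_rtn the commit keeps that layer whole (n_splits_down_proj = 1); presumably quantizing the two halves separately under symmetric int8 is what produced the abnormal output, though the commit message states only the fix, not the root cause.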