diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 0080b40a..d918e4a7 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -103,7 +103,7 @@ def run_model( class LLMBaseNNFactory(NNFactory): def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU", - n_splits_linear=1, n_splits_down_proj=1, group_size=False): + n_splits_linear=1, n_splits_down_proj=1, group_size=0): super().__init__(profile, device) self.cache_parameter_ops = [] self.input_ops = []