diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 2c86c83b..70a353a6 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -65,6 +65,8 @@ class _BaseAutoModelClass:
         if load_in_4bit or load_in_low_bit:
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
+            # set default torch_dtype='auto'
+            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, *args, **kwargs)
         else:
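
For context, a minimal usage sketch of what the new default means for callers, assuming the existing bigdl-llm AutoModelForCausalLM.from_pretrained entry point; the model path below is a placeholder:

# Minimal sketch: callers that omit torch_dtype now get torch_dtype='auto'
# when loading in low-bit mode (the branch modified in the diff above).
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/model",    # placeholder, replace with a real checkpoint
    load_in_4bit=True,  # takes the low-bit branch shown above
)

# An explicitly passed torch_dtype still takes precedence, since
# kwargs.get("torch_dtype", 'auto') only fills in a missing key.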