From ba42a6da63143221433ce76824c1ec13f0c9c137 Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Fri, 21 Jul 2023 17:55:00 +0800 Subject: [PATCH] [LLM] Set torch_dtype default value to 'auto' for transformers low bit from_pretrained API --- python/llm/src/bigdl/llm/transformers/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 2c86c83b..70a353a6 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -65,6 +65,8 @@ class _BaseAutoModelClass: if load_in_4bit or load_in_low_bit: # load int x-bit kwargs["low_cpu_mem_usage"] = True + # set default torch_dtype='auto' + kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto') q_k = load_in_low_bit if load_in_low_bit else "sym_int4" model = cls.load_convert(q_k, *args, **kwargs) else: