From a425eaabfcdd896700f4537a695421c1d46095af Mon Sep 17 00:00:00 2001 From: "Chen, Zhentao" Date: Mon, 11 Mar 2024 16:06:12 +0800 Subject: [PATCH] fix from_pretrained when device_map=None (#10361) * pr trigger * fix error when device_map=None * fix device_map=None --- .github/workflows/llm-harness-evaluation.yml | 2 +- python/llm/src/bigdl/llm/transformers/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/llm-harness-evaluation.yml b/.github/workflows/llm-harness-evaluation.yml index d3f04228..cca5ffd0 100644 --- a/.github/workflows/llm-harness-evaluation.yml +++ b/.github/workflows/llm-harness-evaluation.yml @@ -185,7 +185,7 @@ jobs: # set --limit if it's pr-triggered to accelerate pr action if ${{github.event_name == 'pull_request'}}; then - export LIMIT="--limit 4" + export LIMIT="--limit 6" fi python run_llb.py \ diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 248d61a9..b72b3d34 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -145,7 +145,7 @@ class _BaseAutoModelClass: invalidInputError(model_hub in ["huggingface", "modelscope"], "The parameter `model_hub` is supposed to be `huggingface` or " f"`modelscope`, but got {model_hub}.") - invalidInputError(not ('device_map' in kwargs and 'xpu' in kwargs['device_map']), + invalidInputError(not (kwargs.get('device_map') and 'xpu' in kwargs['device_map']), "Please do not use `device_map` " "with `xpu` value as an argument. " "Use model.to('xpu') instead.")