From 34503efa6a84497fcfe4f1386fe5e26354a14fc8 Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:27:56 +0800 Subject: [PATCH] Fix cpu pinned embedding (#9556) --- python/llm/src/bigdl/llm/transformers/embedding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py index a6fc5589..f0bfc705 100644 --- a/python/llm/src/bigdl/llm/transformers/embedding.py +++ b/python/llm/src/bigdl/llm/transformers/embedding.py @@ -27,7 +27,9 @@ from typing import Optional class CPUPinnedParam(Parameter): def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device.type == 'xpu': + if device is None: + return super().to(*args, **kwargs) + elif device.type == 'xpu': if convert_to_format is not None and self.dim() in (4, 5): return super().to('cpu', dtype, non_blocking, memory_format=convert_to_format)