diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py index a6fc5589..f0bfc705 100644 --- a/python/llm/src/bigdl/llm/transformers/embedding.py +++ b/python/llm/src/bigdl/llm/transformers/embedding.py @@ -27,7 +27,9 @@ from typing import Optional class CPUPinnedParam(Parameter): def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device.type == 'xpu': + if device is None: + return super().to(*args, **kwargs) + elif device.type == 'xpu': if convert_to_format is not None and self.dim() in (4, 5): return super().to('cpu', dtype, non_blocking, memory_format=convert_to_format)