diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index bf579e45..3d133762 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -303,17 +303,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         module.weight = None
         elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
-            if platform.system().lower() == 'windows':
-                model._modules[name] = LLMEmbedding(
-                    num_embeddings=module.num_embeddings,
-                    embedding_dim=module.embedding_dim,
-                    padding_idx=module.padding_idx,
-                    max_norm=module.max_norm,
-                    norm_type=module.norm_type,
-                    scale_grad_by_freq=module.scale_grad_by_freq,
-                    sparse=module.sparse,
-                    _weight=module.weight.data,
-                )
+            model._modules[name] = LLMEmbedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                max_norm=module.max_norm,
+                norm_type=module.norm_type,
+                scale_grad_by_freq=module.scale_grad_by_freq,
+                sparse=module.sparse,
+                _weight=module.weight.data,
+            )
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py
index f0bfc705..ba71b5c0 100644
--- a/python/llm/src/bigdl/llm/transformers/embedding.py
+++ b/python/llm/src/bigdl/llm/transformers/embedding.py
@@ -25,11 +25,27 @@ from typing import Optional
 # To prevent insufficient available memory when moving embedding from XPU back to CPU,
 # we can pin the embedding to CPU if `cpu_embedding==True`.
 class CPUPinnedParam(Parameter):
+    # Overwrite the device attribute for CPUPinnedParam so that its device will be the same as
+    # the device passed to model.to(device);
+    # with this device attribute, model.device will be the same as the device
+    # passed to model.to(device) even when cpu_embedding==True.
+    @property
+    def device(self):
+        try:
+            return self._device
+        except AttributeError:
+            return super().device
+
+    @device.setter
+    def device(self, to_device):
+        self._device = to_device
+
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is None:
             return super().to(*args, **kwargs)
         elif device.type == 'xpu':
+            self.device = device
             if convert_to_format is not None and self.dim() in (4, 5):
                 return super().to('cpu', dtype, non_blocking, memory_format=convert_to_format)
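
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the idea behind CPUPinnedParam as shown in the diff above: keep the parameter's storage on CPU while reporting the device the surrounding model was moved to, so that model.device stays consistent with model.to(device). The class name PinnedParamSketch is hypothetical, and the internal torch._C._nn._parse_to helper is reused only because the patched code relies on it; behavior may vary across PyTorch versions.

    import torch
    from torch.nn import Parameter


    class PinnedParamSketch(Parameter):
        """Hypothetical sketch: data stays on CPU, reported device follows .to()."""

        @property
        def device(self):
            # Report the remembered target device if .to(<non-CPU device>) was called;
            # otherwise fall back to the real device of the underlying (CPU) storage.
            try:
                return self._device
            except AttributeError:
                return super().device

        def to(self, *args, **kwargs):
            device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
            if device is not None and device.type != 'cpu':
                # Remember the requested device but keep the data pinned on CPU.
                self._device = device
                return super().to('cpu', dtype, non_blocking)
            return super().to(*args, **kwargs)

With a parameter like this used as an embedding weight, moving the model to an accelerator leaves the embedding storage on the CPU, while the device reported for the parameter (and hence a model-level device attribute derived from it) still reflects the device requested in model.to(device).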