[LLM] Fix the model.device problem when cpu_embedding=True (#9971)

* Overwrite the device attribute for CPUPinnedParam

* Expose cpu_embedding=True for Linux users

* Fix python style
This commit is contained in:
Yuwen Hu 2024-01-23 18:51:11 +08:00 committed by GitHub
parent f82782cd3b
commit 8d28aa8e2b
2 changed files with 26 additions and 11 deletions

View file

@ -303,7 +303,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.weight = None
elif cpu_embedding and type(module) == nn.Embedding:
# skip user-defined Embedding layer
if platform.system().lower() == 'windows':
model._modules[name] = LLMEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,

View file

@ -25,11 +25,27 @@ from typing import Optional
# To prevent insufficient available memory when moving embedding from XPU back to CPU,
# we can pin the embedding to CPU if `cpu_embedding==True`.
class CPUPinnedParam(Parameter):
# Overwrite the device attribute for CPUPinnedParam so that its device will be the same as
# the device passed to model.to(device);
# With this device attribute, model.device will be the same as
# the device passed to model.to(device) even with cpu_embedding==True
@property
def device(self):
    """Return the device recorded through the setter, if any.

    When no device has been recorded on this instance yet, defer to
    the parent class's ``device`` attribute.
    """
    # A private sentinel distinguishes "never assigned" from any real value.
    not_recorded = object()
    recorded = getattr(self, "_device", not_recorded)
    if recorded is not_recorded:
        return super().device
    return recorded

@device.setter
def device(self, to_device):
    """Record ``to_device`` as the device this parameter reports."""
    self._device = to_device
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is None:
return super().to(*args, **kwargs)
elif device.type == 'xpu':
self.device = device
if convert_to_format is not None and self.dim() in (4, 5):
return super().to('cpu', dtype,
non_blocking, memory_format=convert_to_format)