[LLM] Fix the model.device problem when cpu_embedding=True (#9971)
* Overwrite the device attribute for CPUPinnedParam * Expose cpu_embedding=True for Linux users * Fix python style
This commit is contained in:
parent
f82782cd3b
commit
8d28aa8e2b
2 changed files with 26 additions and 11 deletions
|
|
@ -303,17 +303,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
module.weight = None
|
module.weight = None
|
||||||
elif cpu_embedding and type(module) == nn.Embedding:
|
elif cpu_embedding and type(module) == nn.Embedding:
|
||||||
# skip user-defined Embedding layer
|
# skip user-defined Embedding layer
|
||||||
if platform.system().lower() == 'windows':
|
model._modules[name] = LLMEmbedding(
|
||||||
model._modules[name] = LLMEmbedding(
|
num_embeddings=module.num_embeddings,
|
||||||
num_embeddings=module.num_embeddings,
|
embedding_dim=module.embedding_dim,
|
||||||
embedding_dim=module.embedding_dim,
|
padding_idx=module.padding_idx,
|
||||||
padding_idx=module.padding_idx,
|
max_norm=module.max_norm,
|
||||||
max_norm=module.max_norm,
|
norm_type=module.norm_type,
|
||||||
norm_type=module.norm_type,
|
scale_grad_by_freq=module.scale_grad_by_freq,
|
||||||
scale_grad_by_freq=module.scale_grad_by_freq,
|
sparse=module.sparse,
|
||||||
sparse=module.sparse,
|
_weight=module.weight.data,
|
||||||
_weight=module.weight.data,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Remove the last key for recursion
|
# Remove the last key for recursion
|
||||||
if len(list(module.children())) > 0:
|
if len(list(module.children())) > 0:
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,27 @@ from typing import Optional
|
||||||
# To prevent insufficient available memory when moving embedding from XPU back to CPU,
|
# To prevent insufficient available memory when moving embedding from XPU back to CPU,
|
||||||
# we can pin the embedding to CPU if `cpu_embedding==True`.
|
# we can pin the embedding to CPU if `cpu_embedding==True`.
|
||||||
class CPUPinnedParam(Parameter):
|
class CPUPinnedParam(Parameter):
|
||||||
|
# Overwrite the device attribute for CPUPinnedParam so that its device will be the same as
|
||||||
|
# the device for model.to(device);
|
||||||
|
# With this device attribute, model.device will be the same as
|
||||||
|
# the device for model.to(device) even with cpu_embedding==True
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
try:
|
||||||
|
return self._device
|
||||||
|
except AttributeError:
|
||||||
|
return super().device
|
||||||
|
|
||||||
|
@device.setter
|
||||||
|
def device(self, to_device):
|
||||||
|
self._device = to_device
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||||
if device is None:
|
if device is None:
|
||||||
return super().to(*args, **kwargs)
|
return super().to(*args, **kwargs)
|
||||||
elif device.type == 'xpu':
|
elif device.type == 'xpu':
|
||||||
|
self.device = device
|
||||||
if convert_to_format is not None and self.dim() in (4, 5):
|
if convert_to_format is not None and self.dim() in (4, 5):
|
||||||
return super().to('cpu', dtype,
|
return super().to('cpu', dtype,
|
||||||
non_blocking, memory_format=convert_to_format)
|
non_blocking, memory_format=convert_to_format)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue