Fix cpu pinned embedding (#9556)

This commit is contained in:
Yuwen Hu 2023-11-29 18:27:56 +08:00 committed by GitHub
parent 557bb6bbdb
commit 34503efa6a

View file

@ -27,7 +27,9 @@ from typing import Optional
class CPUPinnedParam(Parameter): class CPUPinnedParam(Parameter):
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device.type == 'xpu': if device is None:
return super().to(*args, **kwargs)
elif device.type == 'xpu':
if convert_to_format is not None and self.dim() in (4, 5): if convert_to_format is not None and self.dim() in (4, 5):
return super().to('cpu', dtype, return super().to('cpu', dtype,
non_blocking, memory_format=convert_to_format) non_blocking, memory_format=convert_to_format)