Fix cpu pinned embedding (#9556)

This commit is contained in:
Yuwen Hu 2023-11-29 18:27:56 +08:00 committed by GitHub
parent 557bb6bbdb
commit 34503efa6a

View file

@ -27,7 +27,9 @@ from typing import Optional
class CPUPinnedParam(Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device.type == 'xpu':
if device is None:
return super().to(*args, **kwargs)
elif device.type == 'xpu':
if convert_to_format is not None and self.dim() in (4, 5):
return super().to('cpu', dtype,
non_blocking, memory_format=convert_to_format)