Fix cpu pinned embedding (#9556)
This commit is contained in:
parent
557bb6bbdb
commit
34503efa6a
1 changed files with 3 additions and 1 deletions
|
|
@ -27,7 +27,9 @@ from typing import Optional
|
||||||
class CPUPinnedParam(Parameter):
|
class CPUPinnedParam(Parameter):
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||||
if device.type == 'xpu':
|
if device is None:
|
||||||
|
return super().to(*args, **kwargs)
|
||||||
|
elif device.type == 'xpu':
|
||||||
if convert_to_format is not None and self.dim() in (4, 5):
|
if convert_to_format is not None and self.dim() in (4, 5):
|
||||||
return super().to('cpu', dtype,
|
return super().to('cpu', dtype,
|
||||||
non_blocking, memory_format=convert_to_format)
|
non_blocking, memory_format=convert_to_format)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue