Fix cpu pinned embedding (#9556)
This commit is contained in:
		
							parent
							
								
									557bb6bbdb
								
							
						
					
					
						commit
						34503efa6a
					
				
					 1 changed files with 3 additions and 1 deletions
				
			
		| 
						 | 
					@ -27,7 +27,9 @@ from typing import Optional
 | 
				
			||||||
class CPUPinnedParam(Parameter):
 | 
					class CPUPinnedParam(Parameter):
 | 
				
			||||||
    def to(self, *args, **kwargs):
 | 
					    def to(self, *args, **kwargs):
 | 
				
			||||||
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 | 
					        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 | 
				
			||||||
        if device.type == 'xpu':
 | 
					        if device is None:
 | 
				
			||||||
 | 
					            return super().to(*args, **kwargs)
 | 
				
			||||||
 | 
					        elif device.type == 'xpu':
 | 
				
			||||||
            if convert_to_format is not None and self.dim() in (4, 5):
 | 
					            if convert_to_format is not None and self.dim() in (4, 5):
 | 
				
			||||||
                return super().to('cpu', dtype,
 | 
					                return super().to('cpu', dtype,
 | 
				
			||||||
                                  non_blocking, memory_format=convert_to_format)
 | 
					                                  non_blocking, memory_format=convert_to_format)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue