[LLM] Fix the model.device problem when cpu_embedding=True (#9971)
				
					
				
			* Overwrite the device attribute for CPUPinnedParam * Expose cpu_embedding=True for Linux users * Fix python style
This commit is contained in:
		
							parent
							
								
									f82782cd3b
								
							
						
					
					
						commit
						8d28aa8e2b
					
				
					 2 changed files with 26 additions and 11 deletions
				
			
		| 
						 | 
					@ -303,7 +303,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                        module.weight = None
 | 
					                        module.weight = None
 | 
				
			||||||
        elif cpu_embedding and type(module) == nn.Embedding:
 | 
					        elif cpu_embedding and type(module) == nn.Embedding:
 | 
				
			||||||
            # skip user-defined Embedding layer
 | 
					            # skip user-defined Embedding layer
 | 
				
			||||||
            if platform.system().lower() == 'windows':
 | 
					 | 
				
			||||||
            model._modules[name] = LLMEmbedding(
 | 
					            model._modules[name] = LLMEmbedding(
 | 
				
			||||||
                num_embeddings=module.num_embeddings,
 | 
					                num_embeddings=module.num_embeddings,
 | 
				
			||||||
                embedding_dim=module.embedding_dim,
 | 
					                embedding_dim=module.embedding_dim,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,11 +25,27 @@ from typing import Optional
 | 
				
			||||||
# To prevent insufficient available memory when moving embedding from XPU back to CPU,
 | 
					# To prevent insufficient available memory when moving embedding from XPU back to CPU,
 | 
				
			||||||
# we can pin the embedding to CPU if `cpu_embedding==True`.
 | 
					# we can pin the embedding to CPU if `cpu_embedding==True`.
 | 
				
			||||||
class CPUPinnedParam(Parameter):
 | 
					class CPUPinnedParam(Parameter):
 | 
				
			||||||
 | 
					    # Overwrite the device attribute for CPUPinnedParam so that its device will be same as
 | 
				
			||||||
 | 
					    # the device for model.to(device);
 | 
				
			||||||
 | 
					    # With this device attribute, model.device will be same as the
 | 
				
			||||||
 | 
					    # the device for model.to(device) even with cpu_embedding==True
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def device(self):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return self._device
 | 
				
			||||||
 | 
					        except AttributeError:
 | 
				
			||||||
 | 
					            return super().device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @device.setter
 | 
				
			||||||
 | 
					    def device(self, to_device):
 | 
				
			||||||
 | 
					        self._device = to_device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to(self, *args, **kwargs):
 | 
					    def to(self, *args, **kwargs):
 | 
				
			||||||
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 | 
					        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 | 
				
			||||||
        if device is None:
 | 
					        if device is None:
 | 
				
			||||||
            return super().to(*args, **kwargs)
 | 
					            return super().to(*args, **kwargs)
 | 
				
			||||||
        elif device.type == 'xpu':
 | 
					        elif device.type == 'xpu':
 | 
				
			||||||
 | 
					            self.device = device
 | 
				
			||||||
            if convert_to_format is not None and self.dim() in (4, 5):
 | 
					            if convert_to_format is not None and self.dim() in (4, 5):
 | 
				
			||||||
                return super().to('cpu', dtype,
 | 
					                return super().to('cpu', dtype,
 | 
				
			||||||
                                  non_blocking, memory_format=convert_to_format)
 | 
					                                  non_blocking, memory_format=convert_to_format)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue