[LLM] Fix the model.device problem when cpu_embedding=True (#9971)

* Overwrite the device attribute for CPUPinnedParam
* Expose cpu_embedding=True for Linux users
* Fix python style

parent f82782cd3b
commit 8d28aa8e2b

2 changed files with 26 additions and 11 deletions
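
The user-visible effect: cpu_embedding=True is no longer Windows-only, and model.device now reports the device passed to model.to(device) instead of falling back to cpu. A minimal usage sketch, assuming bigdl-llm's AutoModelForCausalLM wrapper and an available XPU device (the model id is a placeholder, and load_in_4bit follows the usual bigdl-llm loading pattern):

from bigdl.llm.transformers import AutoModelForCausalLM

# Load with the embedding table pinned to CPU; after this commit the
# option also takes effect on Linux, not just Windows.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    load_in_4bit=True,
    cpu_embedding=True,
)
model = model.to('xpu')

# Before this fix this could report 'cpu', because the pinned embedding
# parameter never leaves the CPU; with the overwritten device attribute
# it reports 'xpu'.
print(model.device)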
				
			
@@ -303,17 +303,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         module.weight = None
         elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
-            if platform.system().lower() == 'windows':
-                model._modules[name] = LLMEmbedding(
-                    num_embeddings=module.num_embeddings,
-                    embedding_dim=module.embedding_dim,
-                    padding_idx=module.padding_idx,
-                    max_norm=module.max_norm,
-                    norm_type=module.norm_type,
-                    scale_grad_by_freq=module.scale_grad_by_freq,
-                    sparse=module.sparse,
-                    _weight=module.weight.data,
-                )
+            model._modules[name] = LLMEmbedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                max_norm=module.max_norm,
+                norm_type=module.norm_type,
+                scale_grad_by_freq=module.scale_grad_by_freq,
+                sparse=module.sparse,
+                _weight=module.weight.data,
+            )
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
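
LLMEmbedding itself is defined outside this hunk. A minimal sketch of the idea behind it, assuming it keeps the full embedding table on CPU and copies only the looked-up rows to the input's device (illustrative class, not the actual BigDL implementation):

import torch
import torch.nn as nn

class CPUPinnedEmbedding(nn.Embedding):
    # Illustrative stand-in for LLMEmbedding: the (large) weight table
    # stays on CPU; only the few looked-up rows are copied to the device
    # the input ids live on.
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        rows = super().forward(input_ids.to('cpu'))
        return rows.to(input_ids.device)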
@@ -25,11 +25,27 @@ from typing import Optional
 # To prevent insufficient available memory when moving embedding from XPU back to CPU,
 # we can pin the embedding to CPU if `cpu_embedding==True`.
 class CPUPinnedParam(Parameter):
+    # Overwrite the device attribute for CPUPinnedParam so that its device
+    # will be the same as the device passed to model.to(device);
+    # with this attribute, model.device matches the model.to(device) target
+    # even when cpu_embedding==True.
+    @property
+    def device(self):
+        try:
+            return self._device
+        except AttributeError:
+            return super().device
+
+    @device.setter
+    def device(self, to_device):
+        self._device = to_device
+
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is None:
             return super().to(*args, **kwargs)
         elif device.type == 'xpu':
+            self.device = device
             if convert_to_format is not None and self.dim() in (4, 5):
                 return super().to('cpu', dtype,
                                   non_blocking, memory_format=convert_to_format)
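
The property override matters because a torch.nn.Parameter normally reports the device its data lives on, so a CPU-pinned parameter would drag the reported model.device back to cpu after model.to('xpu'). A self-contained sketch of the same pattern on stock PyTorch (constructing a torch.device('xpu') object does not require the XPU backend; illustrative, not the BigDL code):

import torch
from torch.nn import Parameter

class PinnedParam(Parameter):
    # Same pattern as CPUPinnedParam above: remember the device the
    # caller asked for while the data itself stays on CPU.
    @property
    def device(self):
        try:
            return self._device
        except AttributeError:
            return super().device  # fall back to the real tensor device

    @device.setter
    def device(self, to_device):
        self._device = to_device

p = PinnedParam(torch.zeros(4, 8))
print(p.device)                  # cpu -- no override recorded yet
p.device = torch.device('xpu')   # what .to('xpu') would record
print(p.device)                  # xpu, although the data never moved
print(p.data.device)             # cpu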