Add disk embedding API (#11585)
This commit is contained in:
		
							parent
							
								
									79c742dfd5
								
							
						
					
					
						commit
						c279849d27
					
				
					 1 changed file with 19 additions and 2 deletions
				
			
		| 
						 | 
					@ -88,10 +88,11 @@ class DiskEmbedding(torch.nn.Embedding):
 | 
				
			||||||
                 sparse: bool = False,
 | 
					                 sparse: bool = False,
 | 
				
			||||||
                 _weight: Optional[Tensor] = None,
 | 
					                 _weight: Optional[Tensor] = None,
 | 
				
			||||||
                 _freeze: bool = False,
 | 
					                 _freeze: bool = False,
 | 
				
			||||||
                 device=None, dtype=None) -> None:
 | 
					                 device=None,
 | 
				
			||||||
 | 
					                 dtype=None):
 | 
				
			||||||
        super().__init__(num_embeddings, embedding_dim, padding_idx,
 | 
					        super().__init__(num_embeddings, embedding_dim, padding_idx,
 | 
				
			||||||
                         max_norm, norm_type, scale_grad_by_freq,
 | 
					                         max_norm, norm_type, scale_grad_by_freq,
 | 
				
			||||||
                         sparse, _weight, _freeze, device, dtype)
 | 
					                         sparse, _weight, True, device, dtype)
 | 
				
			||||||
        self.filename = "embeddings.bin"
 | 
					        self.filename = "embeddings.bin"
 | 
				
			||||||
        self.weight.data.flatten().half().numpy().tofile(self.filename)
 | 
					        self.weight.data.flatten().half().numpy().tofile(self.filename)
 | 
				
			||||||
        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
 | 
					        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
 | 
				
			||||||
| 
						 | 
					@ -118,6 +119,22 @@ class DiskEmbedding(torch.nn.Embedding):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
 | 
					        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_embedding(cls, embedding: torch.nn.Embedding):
 | 
				
			||||||
 | 
					        return cls(
 | 
				
			||||||
 | 
					            embedding.num_embeddings,
 | 
				
			||||||
 | 
					            embedding.embedding_dim,
 | 
				
			||||||
 | 
					            embedding.padding_idx,
 | 
				
			||||||
 | 
					            embedding.max_norm,
 | 
				
			||||||
 | 
					            embedding.norm_type,
 | 
				
			||||||
 | 
					            embedding.scale_grad_by_freq,
 | 
				
			||||||
 | 
					            embedding.sparse,
 | 
				
			||||||
 | 
					            embedding.weight.data,
 | 
				
			||||||
 | 
					            True,
 | 
				
			||||||
 | 
					            embedding.weight.device,
 | 
				
			||||||
 | 
					            embedding.weight.dtype,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LowBitEmbedding(torch.nn.Embedding):
 | 
					class LowBitEmbedding(torch.nn.Embedding):
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue