add disk embedding api (#11585)

This commit is contained in:
Yishuo Wang 2024-07-16 10:43:39 +08:00 committed by GitHub
parent 79c742dfd5
commit c279849d27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -88,10 +88,11 @@ class DiskEmbedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
_freeze: bool = False,
device=None, dtype=None) -> None:
device=None,
dtype=None):
super().__init__(num_embeddings, embedding_dim, padding_idx,
max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, _freeze, device, dtype)
sparse, _weight, True, device, dtype)
self.filename = "embeddings.bin"
self.weight.data.flatten().half().numpy().tofile(self.filename)
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
@ -118,6 +119,22 @@ class DiskEmbedding(torch.nn.Embedding):
)
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
@classmethod
def from_embedding(cls, embedding: torch.nn.Embedding):
    """Construct an instance of ``cls`` from an existing embedding module.

    The source module's configuration attributes, its weight data, and the
    weight's device and dtype are forwarded positionally to ``cls(...)``.

    Args:
        embedding: The ``torch.nn.Embedding`` whose settings and weights
            are read.

    Returns:
        The object produced by calling ``cls`` with those arguments.
    """
    source_weight = embedding.weight
    # Order matters: every value below is handed to the constructor
    # positionally, so the tuple must mirror the constructor's signature.
    ctor_args = (
        embedding.num_embeddings,
        embedding.embedding_dim,
        embedding.padding_idx,
        embedding.max_norm,
        embedding.norm_type,
        embedding.scale_grad_by_freq,
        embedding.sparse,
        source_weight.data,
        True,  # slot right after the weight; presumably `_freeze` -- confirm
        source_weight.device,
        source_weight.dtype,
    )
    return cls(*ctor_args)
class LowBitEmbedding(torch.nn.Embedding):
def __init__(self,