add disk embedding api (#11585)
This commit is contained in:
parent
79c742dfd5
commit
c279849d27
1 changed files with 19 additions and 2 deletions
|
|
@ -88,10 +88,11 @@ class DiskEmbedding(torch.nn.Embedding):
|
||||||
sparse: bool = False,
|
sparse: bool = False,
|
||||||
_weight: Optional[Tensor] = None,
|
_weight: Optional[Tensor] = None,
|
||||||
_freeze: bool = False,
|
_freeze: bool = False,
|
||||||
device=None, dtype=None) -> None:
|
device=None,
|
||||||
|
dtype=None):
|
||||||
super().__init__(num_embeddings, embedding_dim, padding_idx,
|
super().__init__(num_embeddings, embedding_dim, padding_idx,
|
||||||
max_norm, norm_type, scale_grad_by_freq,
|
max_norm, norm_type, scale_grad_by_freq,
|
||||||
sparse, _weight, _freeze, device, dtype)
|
sparse, _weight, True, device, dtype)
|
||||||
self.filename = "embeddings.bin"
|
self.filename = "embeddings.bin"
|
||||||
self.weight.data.flatten().half().numpy().tofile(self.filename)
|
self.weight.data.flatten().half().numpy().tofile(self.filename)
|
||||||
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
|
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
|
||||||
|
|
@ -118,6 +119,22 @@ class DiskEmbedding(torch.nn.Embedding):
|
||||||
)
|
)
|
||||||
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
|
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_embedding(cls, embedding: torch.nn.Embedding):
|
||||||
|
return cls(
|
||||||
|
embedding.num_embeddings,
|
||||||
|
embedding.embedding_dim,
|
||||||
|
embedding.padding_idx,
|
||||||
|
embedding.max_norm,
|
||||||
|
embedding.norm_type,
|
||||||
|
embedding.scale_grad_by_freq,
|
||||||
|
embedding.sparse,
|
||||||
|
embedding.weight.data,
|
||||||
|
True,
|
||||||
|
embedding.weight.device,
|
||||||
|
embedding.weight.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LowBitEmbedding(torch.nn.Embedding):
|
class LowBitEmbedding(torch.nn.Embedding):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue