add disk embedding (#11543)
parent 76a5802acf
commit 7dc6756d86
1 changed file with 45 additions and 2 deletions
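In short, the diff below makes two changes. First, LLMEmbedding.__init__ previously omitted _freeze when forwarding arguments to torch.nn.Embedding.__init__, so device and dtype were passed one position too early; the call now forwards sparse, _weight, _freeze, device, dtype in the correct order. Second, it adds a DiskEmbedding class that writes the embedding table to a float16 binary file at construction time, replaces the in-memory weight with an empty placeholder, and reads individual rows back from disk on each forward pass.

A minimal sketch (not part of the commit) of the byte layout DiskEmbedding relies on: each row holds embedding_dim float16 values, so row idx starts at byte offset idx * embedding_dim * 2. The file name demo_embeddings.bin is hypothetical.

import numpy

embedding_dim = 8
table = numpy.random.rand(4, embedding_dim).astype(numpy.float16)
table.tofile("demo_embeddings.bin")  # raw float16 rows stored back to back

with open("demo_embeddings.bin", "rb") as f:
    idx = 2
    f.seek(idx * embedding_dim * 2)  # 2 bytes per float16 element
    row = numpy.frombuffer(f.read(embedding_dim * 2), dtype=numpy.float16)

assert numpy.array_equal(row, table[idx])  # exact round-trip of one row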
@@ -15,6 +15,7 @@
 #
 
 
+import numpy
 import torch
 from torch import Tensor
 from torch.nn import functional as F
@@ -68,14 +69,56 @@ class LLMEmbedding(torch.nn.Embedding):
                  _freeze: bool = False,
                  device=None, dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
-                         max_norm, norm_type, scale_grad_by_freq, sparse,
-                         _weight, device, dtype)
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
         self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
 
+class DiskEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
+        self.filename = "embeddings.bin"
+        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
+    def forward(self, input_ids: Tensor):
+        ids = input_ids.cpu().flatten()
+
+        embeds = []
+        with open(self.filename, 'rb') as f:
+            for idx in ids:
+                f.seek(idx * self.embedding_dim * 2)
+                buffer = f.read(self.embedding_dim * 2)
+                embeds.append(torch.frombuffer(buffer, dtype=torch.half))
+        embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
+        return embeds.view(*input_ids.size(), self.embedding_dim)
+
+    def restore(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
+
+
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
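A hedged usage sketch (not from the commit): exercising the new DiskEmbedding end to end, assuming the class above is importable from the changed module. The sizes are illustrative, and the constructor writes embeddings.bin into the current working directory.

import torch

vocab_size, hidden_size = 1000, 64            # illustrative, not from the source
emb = DiskEmbedding(vocab_size, hidden_size)  # spills the table to embeddings.bin
                                              # and frees the in-memory weight

input_ids = torch.tensor([[1, 5, 42]])
out = emb(input_ids)                          # each token's row is seek/read from disk
assert out.shape == (1, 3, hidden_size)

emb.restore()                                 # reload the full table into memory
assert emb.weight.shape == (vocab_size, hidden_size)

The design trades lookup latency for memory: after construction the full table no longer resides in RAM, and restore() gives an explicit way back to an ordinary in-memory embedding when the saving is not needed.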