add disk embedding (#11543)

Yishuo Wang 2024-07-09 17:38:40 +08:00 committed by GitHub
parent 76a5802acf
commit 7dc6756d86

@@ -15,6 +15,7 @@
 #
 import numpy
 import torch
 from torch import Tensor
 from torch.nn import functional as F
@@ -68,14 +69,56 @@ class LLMEmbedding(torch.nn.Embedding):
                  _freeze: bool = False,
                  device=None, dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
-                         max_norm, norm_type, scale_grad_by_freq, sparse,
-                         _weight, device, dtype)
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
         self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
 
+class DiskEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
+        self.filename = "embeddings.bin"
+        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
+    def forward(self, input_ids: Tensor):
+        ids = input_ids.cpu().flatten()
+        embeds = []
+        with open(self.filename, 'rb') as f:
+            for idx in ids:
+                f.seek(idx * self.embedding_dim * 2)
+                buffer = f.read(self.embedding_dim * 2)
+                embeds.append(torch.frombuffer(buffer, dtype=torch.half))
+        embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
+        return embeds.view(*input_ids.size(), self.embedding_dim)
+
+    def restore(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
+
+
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
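
A rough usage sketch of the new DiskEmbedding class (illustrative only; the import path and the vocabulary/hidden sizes below are assumptions, not part of this commit): the constructor mirrors torch.nn.Embedding, dumps the full table to embeddings.bin as fp16 on construction, forward() reads only the requested rows (embedding_dim * 2 bytes each) back from that file, and restore() reloads the whole table into an in-memory parameter.

# Illustrative sketch only; the import path and sizes are assumptions, not shown in this diff.
import torch
from ipex_llm.transformers.embedding import DiskEmbedding

vocab_size, hidden_size = 32000, 4096                       # hypothetical sizes
pretrained = torch.randn(vocab_size, hidden_size).half()    # stand-in for a real embedding table

# Constructor mirrors torch.nn.Embedding; `_weight` supplies the pretrained table,
# which is written to "embeddings.bin" and replaced in memory by an empty placeholder.
emb = DiskEmbedding(vocab_size, hidden_size, _weight=pretrained)

# forward() seeks to row_index * embedding_dim * 2 in the file and reads one fp16 row per id.
input_ids = torch.tensor([[1, 2, 3]])
out = emb(input_ids)        # shape: (1, 3, hidden_size)

# restore() reads the whole file back and rebuilds a normal in-memory weight parameter.
emb.restore()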