add disk embedding (#11543)

parent 76a5802acf
commit 7dc6756d86

1 changed file with 45 additions and 2 deletions

@@ -15,6 +15,7 @@
 #
 
 
+import numpy
 import torch
 from torch import Tensor
 from torch.nn import functional as F
@@ -68,14 +69,56 @@ class LLMEmbedding(torch.nn.Embedding):
                  _freeze: bool = False,
                  device=None, dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
-                         max_norm, norm_type, scale_grad_by_freq, sparse,
-                         _weight, device, dtype)
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
         self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
 
+class DiskEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
+        self.filename = "embeddings.bin"
+        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
+    def forward(self, input_ids: Tensor):
+        ids = input_ids.cpu().flatten()
+
+        embeds = []
+        with open(self.filename, 'rb') as f:
+            for idx in ids:
+                f.seek(idx * self.embedding_dim * 2)
+                buffer = f.read(self.embedding_dim * 2)
+                embeds.append(torch.frombuffer(buffer, dtype=torch.half))
+        embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
+        return embeds.view(*input_ids.size(), self.embedding_dim)
+
+    def restore(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
+
+
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
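
How the new class behaves: the DiskEmbedding constructor serializes the full table to embeddings.bin in fp16 and swaps self.weight for an empty placeholder, and forward() then fetches only the requested rows with seek/read; row i starts at byte offset i * embedding_dim * 2, since fp16 takes 2 bytes per element. A minimal usage sketch, assuming the DiskEmbedding class from the diff above is in scope and using toy sizes:

import torch

# Toy sizes; a real LLM vocabulary table would be far larger.
vocab_size, hidden_size = 1000, 64
ref = torch.nn.Embedding(vocab_size, hidden_size)

# Construction writes the fp16 table to embeddings.bin and frees the
# in-memory copy (self.weight becomes an empty placeholder Parameter).
disk_emb = DiskEmbedding(vocab_size, hidden_size, _weight=ref.weight.data)

input_ids = torch.tensor([[1, 5, 42], [7, 7, 0]])
out = disk_emb(input_ids)     # one seek/read pair per token id
print(out.shape)              # torch.Size([2, 3, 64])

# restore() reloads the whole table into memory as a frozen Parameter.
disk_emb.restore()
print(disk_emb.weight.shape)  # torch.Size([1000, 64])

Since every id costs a seek/read pair, this trade suits memory-constrained hosts where the embedding table dominates RAM, not throughput-critical serving paths.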
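
A quick sanity check of the disk path against the standard lookup, under the same assumption that DiskEmbedding is in scope. Both sides read the same fp16 bytes (forward() row by row, restore() in one pass), so the outputs should match:

import torch
from torch.nn import functional as F

emb = DiskEmbedding(1000, 64)           # writes embeddings.bin
ids = torch.randint(0, 1000, (4, 8))

disk_out = emb(ids)                     # row-by-row reads from disk

emb.restore()                           # reload the full table into memory
ref_out = F.embedding(ids, emb.weight)  # standard in-memory lookup

assert torch.allclose(disk_out, ref_out)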