diff --git a/python/llm/src/ipex_llm/transformers/embedding.py b/python/llm/src/ipex_llm/transformers/embedding.py
index 2a8a23fb..836c1eb7 100644
--- a/python/llm/src/ipex_llm/transformers/embedding.py
+++ b/python/llm/src/ipex_llm/transformers/embedding.py
@@ -88,10 +88,11 @@ class DiskEmbedding(torch.nn.Embedding):
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None):
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
+                         sparse, _weight, True, device, dtype)
         self.filename = "embeddings.bin"
         self.weight.data.flatten().half().numpy().tofile(self.filename)
         dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
@@ -118,6 +119,22 @@ class DiskEmbedding(torch.nn.Embedding):
         )
         self.weight = torch.nn.Parameter(embeds, requires_grad=False)
 
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
+
 
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
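
For context, a minimal usage sketch of the new `from_embedding` classmethod introduced by this diff. The vocabulary and hidden sizes below are illustrative assumptions, not values taken from the change:

```python
import torch
from ipex_llm.transformers.embedding import DiskEmbedding

# A plain embedding layer, e.g. the word-embedding table of a loaded model.
# (Sizes here are hypothetical, chosen only for illustration.)
embedding = torch.nn.Embedding(num_embeddings=32000, embedding_dim=4096)

# Converting it invokes DiskEmbedding.__init__, which dumps the
# half-precision weights to "embeddings.bin" and keeps only an empty
# dummy tensor in memory until rows are restored on demand.
disk_embedding = DiskEmbedding.from_embedding(embedding)
```

Note that both the classmethod and the modified `super().__init__` call now pass `True` for `_freeze`, so a converted embedding never tracks gradients regardless of the `_freeze` argument left in the signature.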