From d020ad63974e9d6b34a480a9106fa17b970c1e6e Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 19 Jul 2024 10:34:53 +0800
Subject: [PATCH] add save_low_bit support for DiskEmbedding (#11621)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  8 +---
 .../src/ipex_llm/transformers/embedding.py    | 44 ++++++++++++++-----
 python/llm/src/ipex_llm/transformers/model.py | 25 +++++++++--
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 9fcb49db..cc820fa1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -310,7 +310,6 @@ def use_scale_search(model_config, qtype):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False,
-                                 disk_embedding=False,
                                  prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
@@ -472,8 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # skip user-defined Embedding layer
             elif cpu_embedding and type(module) == nn.Embedding:
                 model._modules[name] = CPUEmbedding.from_embedding(module)
-            elif disk_embedding and type(module) == nn.Embedding:
-                model._modules[name] = DiskEmbedding.from_embedding(module)
             elif embedding_qtype is not None and type(module) == nn.Embedding:
                 model._modules[name] = LowBitEmbedding.from_embedding(module,
                                                                       convert_shape_only,
@@ -486,7 +483,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
-                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -746,7 +742,7 @@ def _optimize_pre(model, qtype=None):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None,
-                         cpu_embedding=False, disk_embedding=False,
+                         cpu_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -788,7 +784,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding, disk_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
diff --git a/python/llm/src/ipex_llm/transformers/embedding.py b/python/llm/src/ipex_llm/transformers/embedding.py
index 773afb69..20eecb76 100644
--- a/python/llm/src/ipex_llm/transformers/embedding.py
+++ b/python/llm/src/ipex_llm/transformers/embedding.py
@@ -110,7 +110,7 @@ class DiskEmbedding(torch.nn.Embedding):
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
         self.filename = "embeddings.bin"
-        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
         dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
         self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
 
@@ -126,15 +126,6 @@ class DiskEmbedding(torch.nn.Embedding):
         embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
         return embeds.view(*input_ids.size(), self.embedding_dim)
 
-    def restore(self):
-        with open(self.filename, 'rb') as f:
-            buffer = f.read()
-        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
-        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
-            device=self.weight.device, dtype=self.weight.dtype
-        )
-        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
-
     @classmethod
     def from_embedding(cls, embedding: torch.nn.Embedding):
         return cls(
@@ -151,6 +142,39 @@ class DiskEmbedding(torch.nn.Embedding):
             embedding.weight.dtype,
         )
 
+    def to_embedding(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        return torch.nn.Embedding(
+            self.num_embeddings,
+            self.embedding_dim,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+            embeds,
+            True,
+            embeds.device,
+            embeds.dtype,
+        )
+
+    @staticmethod
+    def replace_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == torch.nn.Embedding:
+                m._modules[name] = DiskEmbedding.from_embedding(module)
+
+    @staticmethod
+    def restore_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == DiskEmbedding:
+                m._modules[name] = module.to_embedding()
+
 
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 1d6af54d..e40c56d6 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -71,7 +71,15 @@ def save_low_bit(self, *args, **kwargs):
     architectures = getattr(self.config, "architectures", None)
     model_type = getattr(self.config, "model_type", None)
 
-    self.save_pretrained(*args, **kwargs)
+    disk_embedding = getattr(self.config, "bigdl_disk_embedding", False)
+
+    if disk_embedding:
+        from ipex_llm.transformers.embedding import DiskEmbedding
+        self.apply(DiskEmbedding.restore_normal_embedding)
+        self.save_pretrained(*args, **kwargs)
+        self.apply(DiskEmbedding.replace_normal_embedding)
+    else:
+        self.save_pretrained(*args, **kwargs)
 
     if architectures:
         self.config.update({"architectures": architectures})
@@ -511,14 +519,19 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
                                      enable_xetla=enable_xetla,
                                      mixed_precision=mixed_precision)
-        model.config.update({"bigdl_transformers_low_bit": q_k})
+
+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+
+        model.config.update({"bigdl_transformers_low_bit": q_k,
+                             "bigdl_disk_embedding": disk_embedding})
 
         # enable tie_word_embeddings for MPT
         # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
@@ -706,7 +719,6 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype,
                                      torch_dtype=torch_dtype)
@@ -749,6 +761,11 @@ class _BaseAutoModelClass:
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
 
+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+            model.config.update({"bigdl_disk_embedding": disk_embedding})
+
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
 
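A minimal end-to-end sketch of the flow this patch enables, based on the code changed above; the
model and save paths are hypothetical placeholders, and load_in_4bit is used only as one example
of a low-bit loading option:

    from ipex_llm.transformers import AutoModelForCausalLM

    # Quantize with DiskEmbedding enabled: nn.Embedding modules are swapped for
    # DiskEmbedding, whose weights are dumped to "embeddings.bin" and read back
    # row by row at forward time.
    model = AutoModelForCausalLM.from_pretrained("path/to/model",          # hypothetical path
                                                 load_in_4bit=True,
                                                 disk_embedding=True)

    # save_low_bit reads "bigdl_disk_embedding" from the config, temporarily
    # converts DiskEmbedding back to a normal nn.Embedding (restore_normal_embedding
    # / to_embedding) so the embedding weights land in the checkpoint, then
    # re-applies DiskEmbedding after save_pretrained returns.
    model.save_low_bit("path/to/low-bit-model")                            # hypothetical path

    # load_low_bit re-applies DiskEmbedding and records the flag in the config.
    model = AutoModelForCausalLM.load_low_bit("path/to/low-bit-model",
                                              disk_embedding=True)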