add save_low_bit support for DiskEmbedding (#11621)
commit d020ad6397
parent 380717f50d

3 changed files with 57 additions and 20 deletions
ipex_llm/transformers/convert.py

@@ -310,7 +310,6 @@ def use_scale_search(model_config, qtype):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False,
-                                 disk_embedding=False,
                                  prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
@@ -472,8 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
             model._modules[name] = CPUEmbedding.from_embedding(module)
-        elif disk_embedding and type(module) == nn.Embedding:
-            model._modules[name] = DiskEmbedding.from_embedding(module)
         elif embedding_qtype is not None and type(module) == nn.Embedding:
             model._modules[name] = LowBitEmbedding.from_embedding(module,
                                                                   convert_shape_only,
@@ -486,7 +483,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
-                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -746,7 +742,7 @@ def _optimize_pre(model, qtype=None):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None,
-                         cpu_embedding=False, disk_embedding=False,
+                         cpu_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -788,7 +784,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding, disk_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
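With this change the conversion path in convert.py no longer knows about disk embeddings at all: the disk_embedding parameter and its elif branch are removed from _replace_with_low_bit_linear and ggml_convert_low_bit, and the swap is instead done once, after conversion, by walking the module tree with nn.Module.apply (see the model.py hunks below). A minimal, self-contained sketch of that replacement pattern in plain PyTorch; ToyEmbedding and the toy model are illustrative stand-ins, not part of ipex-llm:

import torch.nn as nn

class ToyEmbedding(nn.Embedding):
    # Stand-in for DiskEmbedding: an nn.Embedding subclass built from an existing layer.
    @classmethod
    def from_embedding(cls, emb: nn.Embedding):
        new = cls(emb.num_embeddings, emb.embedding_dim, padding_idx=emb.padding_idx)
        new.weight = emb.weight
        return new

def replace_embeddings(m: nn.Module):
    # nn.Module.apply(fn) calls fn on every sub-module; rewriting the parent's
    # _modules dict swaps the child in place, mirroring replace_normal_embedding.
    for name, child in m.named_children():
        if type(child) == nn.Embedding:
            m._modules[name] = ToyEmbedding.from_embedding(child)

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
model.apply(replace_embeddings)
print(type(model[0]).__name__)  # ToyEmbedding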
ipex_llm/transformers/embedding.py

@@ -110,7 +110,7 @@ class DiskEmbedding(torch.nn.Embedding):
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
         self.filename = "embeddings.bin"
-        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
         dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
         self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
 
@@ -126,15 +126,6 @@ class DiskEmbedding(torch.nn.Embedding):
         embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
         return embeds.view(*input_ids.size(), self.embedding_dim)
 
-    def restore(self):
-        with open(self.filename, 'rb') as f:
-            buffer = f.read()
-        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
-        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
-            device=self.weight.device, dtype=self.weight.dtype
-        )
-        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
-
     @classmethod
     def from_embedding(cls, embedding: torch.nn.Embedding):
         return cls(
@@ -151,6 +142,39 @@ class DiskEmbedding(torch.nn.Embedding):
             embedding.weight.dtype,
         )
 
+    def to_embedding(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        return torch.nn.Embedding(
+            self.num_embeddings,
+            self.embedding_dim,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+            embeds,
+            True,
+            embeds.device,
+            embeds.dtype,
+        )
+
+    @staticmethod
+    def replace_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == torch.nn.Embedding:
+                m._modules[name] = DiskEmbedding.from_embedding(module)
+
+    @staticmethod
+    def restore_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == DiskEmbedding:
+                m._modules[name] = module.to_embedding()
+
 
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
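The old restore() method, which loaded the weights back into the DiskEmbedding itself, is replaced by to_embedding(), which rebuilds a plain torch.nn.Embedding from the half-precision embeddings.bin dump; together with the two apply-compatible static methods this makes the conversion reversible, which is what save_low_bit needs. (The __init__ change above also moves the weight to CPU before converting to numpy, so the dump works when the source embedding lives on a non-CPU device.) A hedged round-trip sketch, assuming an ipex-llm build that contains this commit; the toy model is illustrative:

import torch.nn as nn
from ipex_llm.transformers.embedding import DiskEmbedding

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 8))

# Swap every nn.Embedding for a DiskEmbedding: the weights are written to
# "embeddings.bin" as fp16 and the in-memory tensor becomes an empty dummy.
model.apply(DiskEmbedding.replace_normal_embedding)
assert isinstance(model[0], DiskEmbedding)

# Rebuild plain nn.Embedding layers from the file, e.g. right before saving.
model.apply(DiskEmbedding.restore_normal_embedding)
assert type(model[0]) == nn.Embedding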
ipex_llm/transformers/model.py

@@ -71,6 +71,14 @@ def save_low_bit(self, *args, **kwargs):
 
     architectures = getattr(self.config, "architectures", None)
     model_type = getattr(self.config, "model_type", None)
-    self.save_pretrained(*args, **kwargs)
+    disk_embedding = getattr(self.config, "bigdl_disk_embedding", False)
+
+    if disk_embedding:
+        from ipex_llm.transformers.embedding import DiskEmbedding
+        self.apply(DiskEmbedding.restore_normal_embedding)
+        self.save_pretrained(*args, **kwargs)
+        self.apply(DiskEmbedding.replace_normal_embedding)
+    else:
+        self.save_pretrained(*args, **kwargs)
 
     if architectures:
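save_low_bit now checks the bigdl_disk_embedding flag recorded in the model config: it temporarily restores real nn.Embedding layers so that save_pretrained writes a complete checkpoint, then switches the live model back to DiskEmbedding. The same restore, save, replace dance could be phrased as a context manager; a hypothetical sketch, not part of this commit:

from contextlib import contextmanager

@contextmanager
def normal_embeddings(model):
    # Hypothetical helper: temporarily restore real nn.Embedding layers,
    # then re-enable DiskEmbedding when the block exits.
    from ipex_llm.transformers.embedding import DiskEmbedding
    model.apply(DiskEmbedding.restore_normal_embedding)
    try:
        yield model
    finally:
        model.apply(DiskEmbedding.replace_normal_embedding)

# Equivalent to the if-branch above:
#     with normal_embeddings(model):
#         model.save_pretrained(save_directory)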
@@ -511,14 +519,19 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
                                      enable_xetla=enable_xetla,
                                      mixed_precision=mixed_precision)
-        model.config.update({"bigdl_transformers_low_bit": q_k})
+
+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+
+        model.config.update({"bigdl_transformers_low_bit": q_k,
+                             "bigdl_disk_embedding": disk_embedding})
 
         # enable tie_word_embeddings for MPT
         # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
@@ -706,7 +719,6 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
 
@@ -749,6 +761,11 @@ class _BaseAutoModelClass:
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
 
+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+            model.config.update({"bigdl_disk_embedding": disk_embedding})
+
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
 
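End to end: from_pretrained and load_low_bit now apply DiskEmbedding.replace_normal_embedding after low-bit conversion instead of passing disk_embedding down into ggml_convert_low_bit, and they record the choice in the config as bigdl_disk_embedding so that save_low_bit can undo it around save_pretrained. A hedged usage sketch; the checkpoint id and output path are placeholders, and disk_embedding is assumed to be accepted as a keyword by ipex-llm's from_pretrained and load_low_bit, as the hunks above suggest:

from ipex_llm.transformers import AutoModelForCausalLM

# Load with low-bit weights and the embedding table kept on disk, not in memory.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder checkpoint
    load_in_4bit=True,
    disk_embedding=True,
)

# save_low_bit temporarily restores real nn.Embedding weights so the saved
# checkpoint is complete, then re-enables DiskEmbedding on the live model.
model.save_low_bit("./llama2-7b-4bit")

# When reloading, the load_low_bit path in this diff still takes the flag explicitly.
model = AutoModelForCausalLM.load_low_bit("./llama2-7b-4bit", disk_embedding=True)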