Add disk_embedding parameter to support putting the Embedding layer on disk (#11617)

parent 2478e2c14b
commit 0209427cf4

4 changed files with 86 additions and 66 deletions

python/llm/src/ipex_llm/transformers/convert.py

@@ -309,7 +309,9 @@ def use_scale_search(model_config, qtype):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
-                                 cpu_embedding=False, prefix_name='',
+                                 cpu_embedding=False,
+                                 disk_embedding=False,
+                                 prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,

@@ -319,7 +321,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
+    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False

     for name, module in model.named_children():

@@ -467,48 +469,15 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     model._modules[name].requires_grad_(False)

                     module.weight = None
+        # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
-            # skip user-defined Embedding layer
-            model._modules[name] = LLMEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-            )
-        elif type(module) == nn.Embedding and embedding_qtype is not None:
-            if torch_dtype == "auto":
-                torch_dtype = torch.float32
-            q_embedding = LowBitEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-                qtype=embedding_qtype,
-                torch_dtype=torch_dtype
-            )
-            device = module.weight.data.device
-            # Copy the weights
-            paramsLowBit = FP4Params(data=module.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     convert_shape_only=convert_shape_only,
-                                     qtype=embedding_qtype,
-                                     in_features=module.embedding_dim).to(device)
-            q_embedding._parameters['weight'] = paramsLowBit
-            model._modules[name] = q_embedding
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-            module.weight = None
-
+            model._modules[name] = CPUEmbedding.from_embedding(module)
+        elif disk_embedding and type(module) == nn.Embedding:
+            model._modules[name] = DiskEmbedding.from_embedding(module)
+        elif embedding_qtype is not None and type(module) == nn.Embedding:
+            model._modules[name] = LowBitEmbedding.from_embedding(module,
+                                                                  convert_shape_only,
+                                                                  embedding_qtype)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, _flag = _replace_with_low_bit_linear(

@@ -517,6 +486,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
+                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,

@@ -775,7 +745,8 @@ def _optimize_pre(model, qtype=None):

 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, cpu_embedding=False,
+                         modules_to_not_convert=None,
+                         cpu_embedding=False, disk_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,

@@ -817,7 +788,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding, disk_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
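
With this change the three embedding modes are mutually exclusive and checked in order: cpu_embedding first, then disk_embedding, then embedding_qtype. A minimal sketch of driving the converter directly with the new flag (the toy model is illustrative, and the sym_int4 lookup assumes the ggml_tensor_qtype table used elsewhere in ipex-llm):

import torch
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.convert import ggml_convert_low_bit

# toy module tree containing an Embedding layer
model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 64))
# quantize the Linear layer to sym_int4 and swap nn.Embedding for DiskEmbedding
model = ggml_convert_low_bit(model, ggml_tensor_qtype["sym_int4"],
                             optimize_model=False, disk_embedding=True)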

python/llm/src/ipex_llm/transformers/embedding.py

@@ -18,7 +18,6 @@
 import numpy
 import torch
 from torch import Tensor
-from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
 from ipex_llm.transformers.low_bit_linear import FP4Params

@@ -56,7 +55,7 @@ class CPUPinnedParam(Parameter):
         return super().to(*args, **kwargs)


-class LLMEmbedding(torch.nn.Embedding):
+class CPUEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,

@@ -67,15 +66,32 @@ class LLMEmbedding(torch.nn.Embedding):
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
-        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+                         sparse, _weight, True, device, dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=False)

     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)

+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
+

 class DiskEmbedding(torch.nn.Embedding):
     def __init__(self,
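
The new from_embedding classmethods make the replacement in convert.py a one-liner; note that _freeze is now effectively forced to True and the weight is always a frozen, pinned CPU parameter. A usage sketch (the vocabulary and hidden sizes are placeholders; DiskEmbedding.from_embedding is assumed to be defined in the unchanged part of this file, since convert.py calls it above):

import torch
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding

emb = torch.nn.Embedding(32000, 4096)       # e.g. taken from a loaded model
cpu_emb = CPUEmbedding.from_embedding(emb)  # weight becomes a pinned CPU parameter
ids = torch.tensor([[1, 2, 3]])
out = cpu_emb(ids)  # lookup runs on CPU, result is moved back to ids.device

disk_emb = DiskEmbedding.from_embedding(emb)  # rows served from disk instead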

@@ -89,7 +105,7 @@ class DiskEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None,
-                 dtype=None):
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)

@@ -147,30 +163,55 @@ class LowBitEmbedding(torch.nn.Embedding):
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None,
-                 qtype=None,
-                 torch_dtype=torch.float32) -> None:
+                 device=None,
+                 dtype=None,
+                 convert_shape_only=None,
+                 qtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
-        self.weight = FP4Params(self.weight.data,
-                                requires_grad=False,
-                                quantized=False, _shape=None, qtype=qtype)
+        self.qweight = FP4Params(self.weight.data,
+                                 requires_grad=False,
+                                 quantized=False,
+                                 _shape=None,
+                                 convert_shape_only=convert_shape_only,
+                                 qtype=qtype,
+                                 in_features=embedding_dim)
+        # this dummy_weight is used to record model's dtype and device
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)

         self.embedding_dim = embedding_dim
         self.num_embeddings = num_embeddings
-        self.torch_dtype = torch_dtype

     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
                           "`LowBitEmbedding` only supports GPU now.")
         try:
             import intel_extension_for_pytorch
             import xe_linear
         except ModuleNotFoundError:
             invalidInputError(False,
-                              "Please `pip install bigdl_core_xe` first.")
+                              "Please `pip install bigdl_core_xe_21` first.")

-        result = xe_linear.dequantize_rows(x.contiguous(), self.weight.data,
-                                           self.weight.qtype, self.embedding_dim,
+        result = xe_linear.dequantize_rows(x.contiguous(), self.qweight.data,
+                                           self.qweight.qtype, self.embedding_dim,
                                            self.num_embeddings)
-        return result.to(self.torch_dtype)
+        return result.to(self.weight.dtype)
+
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding, convert_shape_only, qtype):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+            convert_shape_only,
+            qtype,
+        )
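
The dummy-weight pattern above keeps the quantized payload in self.qweight, while an empty 0x0 self.weight parameter continues to track the module's dtype and device through .to()/.half() casts (forward casts its output to self.weight.dtype). A standalone sketch of the trick, with illustrative names:

import torch

class DtypeDeviceRecorder(torch.nn.Module):
    # an empty Parameter records dtype/device without holding real storage
    def __init__(self):
        super().__init__()
        dummy = torch.empty(0, 0, dtype=torch.float32)
        self.weight = torch.nn.Parameter(dummy, requires_grad=False)

m = DtypeDeviceRecorder().to(torch.float16)
print(m.weight.dtype)       # torch.float16 -- follows module-level casts
print(m.weight.nelement())  # 0 -- no data is actually stored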

python/llm/src/ipex_llm/transformers/low_bit_linear.py

@@ -483,7 +483,7 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            self.quantize(device)  # tensor is cpu now
+            self.quantize("cpu")  # tensor is cpu now
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)

python/llm/src/ipex_llm/transformers/model.py

@@ -144,6 +144,8 @@ class _BaseAutoModelClass:
                             Default to be ``False``.
         :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
+        :param disk_embedding: Whether to put the Embedding layer on disk to save memory.
+            Default to be ``False``.
         :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
         :param imatrix: str value, represent filename of importance matrix pretrained on
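
End to end, the new kwarg is passed the same way as cpu_embedding. A usage sketch (the checkpoint id is a placeholder):

from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True,
                                             disk_embedding=True)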

@@ -435,6 +437,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)

@@ -507,7 +510,9 @@ class _BaseAutoModelClass:
         model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,

@@ -563,6 +568,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)

@@ -699,7 +705,9 @@ class _BaseAutoModelClass:
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)

         if is_sharded: