diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index fbd66151..e7f0d4db 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -309,7 +309,9 @@ def use_scale_search(model_config, qtype):
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
-                                 cpu_embedding=False, prefix_name='',
+                                 cpu_embedding=False,
+                                 disk_embedding=False,
+                                 prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
@@ -319,7 +321,7 @@
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
+    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -467,48 +469,15 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model._modules[name].requires_grad_(False)
 
                 module.weight = None
+        # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
-            # skip user-defined Embedding layer
-            model._modules[name] = LLMEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-            )
-        elif type(module) == nn.Embedding and embedding_qtype is not None:
-            if torch_dtype == "auto":
-                torch_dtype = torch.float32
-            q_embedding = LowBitEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-                qtype=embedding_qtype,
-                torch_dtype=torch_dtype
-            )
-            device = module.weight.data.device
-            # Copy the weights
-            paramsLowBit = FP4Params(data=module.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     convert_shape_only=convert_shape_only,
-                                     qtype=embedding_qtype,
-                                     in_features=module.embedding_dim).to(device)
-            q_embedding._parameters['weight'] = paramsLowBit
-            model._modules[name] = q_embedding
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-            module.weight = None
-
+            model._modules[name] = CPUEmbedding.from_embedding(module)
+        elif disk_embedding and type(module) == nn.Embedding:
+            model._modules[name] = DiskEmbedding.from_embedding(module)
+        elif embedding_qtype is not None and type(module) == nn.Embedding:
+            model._modules[name] = LowBitEmbedding.from_embedding(module,
+                                                                   convert_shape_only,
+                                                                   embedding_qtype)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, _flag = _replace_with_low_bit_linear(
@@ -517,6 +486,7 @@
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
+                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -775,7 +745,8 @@ def _optimize_pre(model, qtype=None):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, cpu_embedding=False,
+                         modules_to_not_convert=None,
+                         cpu_embedding=False, disk_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -817,7 +788,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding, disk_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
diff --git a/python/llm/src/ipex_llm/transformers/embedding.py b/python/llm/src/ipex_llm/transformers/embedding.py
index 836c1eb7..773afb69 100644
--- a/python/llm/src/ipex_llm/transformers/embedding.py
+++ b/python/llm/src/ipex_llm/transformers/embedding.py
@@ -18,7 +18,6 @@
 import numpy
 import torch
 from torch import Tensor
-from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
 from ipex_llm.transformers.low_bit_linear import FP4Params
@@ -56,7 +55,7 @@
         return super().to(*args, **kwargs)
 
 
-class LLMEmbedding(torch.nn.Embedding):
+class CPUEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -67,15 +66,32 @@
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
-        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+                         sparse, _weight, True, device, dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=False)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
+
 
 class DiskEmbedding(torch.nn.Embedding):
     def __init__(self,
@@ -89,7 +105,7 @@
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None,
-                 dtype=None):
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
@@ -147,30 +163,55 @@
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None,
-                 qtype=None,
-                 torch_dtype=torch.float32) -> None:
+                 device=None,
+                 dtype=None,
+                 convert_shape_only=None,
+                 qtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, device, dtype)
-        self.weight = FP4Params(self.weight.data,
-                                requires_grad=False,
-                                quantized=False, _shape=None, qtype=qtype)
+        self.qweight = FP4Params(self.weight.data,
+                                 requires_grad=False,
+                                 quantized=False,
+                                 _shape=None,
+                                 convert_shape_only=convert_shape_only,
+                                 qtype=qtype,
+                                 in_features=embedding_dim)
+        # this dummy_weight is used to record model's dtype and device
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
         self.embedding_dim = embedding_dim
         self.num_embeddings = num_embeddings
-        self.torch_dtype = torch_dtype
 
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
                           "`LowBitEmbedding` only supports GPU now.")
         try:
-            import intel_extension_for_pytorch
             import xe_linear
         except ModuleNotFoundError:
             invalidInputError(False,
-                              "Please `pip install bigdl_core_xe` first.")
+                              "Please `pip install bigdl_core_xe_21` first.")
 
-        result = xe_linear.dequantize_rows(x.contiguous(), self.weight.data,
-                                           self.weight.qtype, self.embedding_dim,
+        result = xe_linear.dequantize_rows(x.contiguous(), self.qweight.data,
+                                           self.qweight.qtype, self.embedding_dim,
                                            self.num_embeddings)
-        return result.to(self.torch_dtype)
+        return result.to(self.weight.dtype)
+
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding, convert_shape_only, qtype):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+            convert_shape_only,
+            qtype,
+        )
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index e5d862f5..c30ca4a2 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -483,7 +483,7 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            self.quantize(device)  # tensor is cpu now
+            self.quantize("cpu")  # tensor is cpu now
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1), self.qtype)
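
Usage sketch (assumed API, based only on the classes and `from_embedding` helpers added above; the sizes below are hypothetical): an existing `nn.Embedding` can be converted in a single call instead of copying its fields by hand, e.g.

    import torch
    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding

    emb = torch.nn.Embedding(32000, 4096)          # hypothetical vocab/hidden sizes
    cpu_emb = CPUEmbedding.from_embedding(emb)     # keep the lookup on (pinned) CPU memory
    disk_emb = DiskEmbedding.from_embedding(emb)   # put the embedding weight on disk to save memory
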
diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 8b943ca4..1d6af54d 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -144,6 +144,8 @@ class _BaseAutoModelClass:
             Default to be ``False``.
         :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
+        :param disk_embedding: Whether to put the Embedding layer on disk to save memory.
+            Default to be ``False``.
         :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
         :param imatrix: str value, represent filename of importance matrix pretrained on
@@ -435,6 +437,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
@@ -507,7 +510,9 @@ class _BaseAutoModelClass:
             model = model.to("cpu")
             model = ggml_convert_low_bit(model, qtype, optimize_model,
                                          modules_to_not_convert=modules_to_not_convert,
-                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                         cpu_embedding=cpu_embedding,
+                                         disk_embedding=disk_embedding,
+                                         lightweight_bmm=lightweight_bmm,
                                          torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                          imatrix_data=imatrix_data,
                                          embedding_qtype=embedding_qtype,
@@ -563,6 +568,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
@@ -699,7 +705,9 @@ class _BaseAutoModelClass:
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype,
                                      torch_dtype=torch_dtype)
         if is_sharded:
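
Usage sketch (assumed call pattern; the checkpoint path is a placeholder): the new option is surfaced as a `from_pretrained` keyword argument next to `cpu_embedding`, e.g.

    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/model",          # placeholder checkpoint
        load_in_4bit=True,
        optimize_model=True,
        disk_embedding=True,      # put the Embedding layer on disk to save memory
    )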