Add disk_embedding parameter to support putting the Embedding layer on disk (#11617)

Yishuo Wang 2024-07-18 17:06:06 +08:00 committed by GitHub
parent 2478e2c14b
commit 0209427cf4
4 changed files with 86 additions and 66 deletions

View file

@@ -309,7 +309,9 @@ def use_scale_search(model_config, qtype):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
-                                 cpu_embedding=False, prefix_name='',
+                                 cpu_embedding=False,
+                                 disk_embedding=False,
+                                 prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
@@ -319,7 +321,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
+    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
     for name, module in model.named_children():
@@ -467,48 +469,15 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model._modules[name].requires_grad_(False)
                 module.weight = None
+        # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
-            # skip user-defined Embedding layer
-            model._modules[name] = LLMEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-            )
-        elif type(module) == nn.Embedding and embedding_qtype is not None:
-            if torch_dtype == "auto":
-                torch_dtype = torch.float32
-            q_embedding = LowBitEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-                qtype=embedding_qtype,
-                torch_dtype=torch_dtype
-            )
-            device = module.weight.data.device
-            # Copy the weights
-            paramsLowBit = FP4Params(data=module.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     convert_shape_only=convert_shape_only,
-                                     qtype=embedding_qtype,
-                                     in_features=module.embedding_dim).to(device)
-            q_embedding._parameters['weight'] = paramsLowBit
-            model._modules[name] = q_embedding
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-            module.weight = None
+            model._modules[name] = CPUEmbedding.from_embedding(module)
+        elif disk_embedding and type(module) == nn.Embedding:
+            model._modules[name] = DiskEmbedding.from_embedding(module)
+        elif embedding_qtype is not None and type(module) == nn.Embedding:
+            model._modules[name] = LowBitEmbedding.from_embedding(module,
+                                                                  convert_shape_only,
+                                                                  embedding_qtype)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, _flag = _replace_with_low_bit_linear(
@@ -517,6 +486,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
+                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -775,7 +745,8 @@ def _optimize_pre(model, qtype=None):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, cpu_embedding=False,
+                         modules_to_not_convert=None,
+                         cpu_embedding=False, disk_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -817,7 +788,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding, disk_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
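
Taken together, the changes above give the embedding replacements a fixed precedence inside `_replace_with_low_bit_linear`: `cpu_embedding` wins over `disk_embedding`, which wins over `embedding_qtype`. A hedged sketch of how the new flag flows through the internal conversion entry point (`model` and `qtype` are assumed to come from the caller; all other arguments are left at their defaults):

    # illustrative call, not part of the commit; signature taken from the hunk above
    model = ggml_convert_low_bit(model, qtype,
                                 optimize_model=True,
                                 cpu_embedding=False,
                                 disk_embedding=True)  # new in this commit: nn.Embedding -> DiskEmbedding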

View file

@@ -18,7 +18,6 @@
 import numpy
 import torch
 from torch import Tensor
-from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
 from ipex_llm.transformers.low_bit_linear import FP4Params
@@ -56,7 +55,7 @@ class CPUPinnedParam(Parameter):
         return super().to(*args, **kwargs)
-class LLMEmbedding(torch.nn.Embedding):
+class CPUEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -67,15 +66,32 @@ class LLMEmbedding(torch.nn.Embedding):
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
-        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+                         sparse, _weight, True, device, dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=False)
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
 class DiskEmbedding(torch.nn.Embedding):
     def __init__(self,
@@ -89,7 +105,7 @@ class DiskEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None,
-                 dtype=None):
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
@@ -147,30 +163,55 @@ class LowBitEmbedding(torch.nn.Embedding):
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None,
-                 qtype=None,
-                 torch_dtype=torch.float32) -> None:
+                 device=None,
+                 dtype=None,
+                 convert_shape_only=None,
+                 qtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
-        self.weight = FP4Params(self.weight.data,
-                                requires_grad=False,
-                                quantized=False, _shape=None, qtype=qtype)
+        self.qweight = FP4Params(self.weight.data,
+                                 requires_grad=False,
+                                 quantized=False,
+                                 _shape=None,
+                                 convert_shape_only=convert_shape_only,
+                                 qtype=qtype,
+                                 in_features=embedding_dim)
+        # this dummy_weight is used to record model's dtype and device
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
         self.embedding_dim = embedding_dim
         self.num_embeddings = num_embeddings
-        self.torch_dtype = torch_dtype
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
                           "`LowBitEmbedding` only supports GPU now.")
         try:
             import intel_extension_for_pytorch
             import xe_linear
         except ModuleNotFoundError:
             invalidInputError(False,
-                              "Please `pip install bigdl_core_xe` first.")
+                              "Please `pip install bigdl_core_xe_21` first.")
-        result = xe_linear.dequantize_rows(x.contiguous(), self.weight.data,
-                                           self.weight.qtype, self.embedding_dim,
+        result = xe_linear.dequantize_rows(x.contiguous(), self.qweight.data,
+                                           self.qweight.qtype, self.embedding_dim,
                                            self.num_embeddings)
-        return result.to(self.torch_dtype)
+        return result.to(self.weight.dtype)
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding, convert_shape_only, qtype):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+            convert_shape_only,
+            qtype,
+        )
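
For reference, the replacement classes can now be built directly from an existing `torch.nn.Embedding` through their `from_embedding` classmethods. A minimal CPU-only sketch (the import path comes from the convert.py hunk above; the sizes are made-up illustration values):

    import torch
    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding

    # a plain embedding as it would appear inside a Hugging Face model
    embedding = torch.nn.Embedding(num_embeddings=32000, embedding_dim=4096)

    # weight kept in pinned CPU memory; forward() moves indices to CPU and
    # returns the looked-up rows on the caller's original device
    cpu_embedding = CPUEmbedding.from_embedding(embedding)

    # weight offloaded to disk to save host memory (per the new docstring in this commit)
    disk_embedding = DiskEmbedding.from_embedding(embedding)

`LowBitEmbedding.from_embedding` additionally takes `convert_shape_only` and the embedding `qtype`, and its `forward` requires an XPU device plus the `xe_linear` extension, so it is not exercised in this sketch.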

View file

@@ -483,7 +483,7 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            self.quantize(device)  # tensor is cpu now
+            self.quantize("cpu")  # tensor is cpu now
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)

View file

@@ -144,6 +144,8 @@ class _BaseAutoModelClass:
             Default to be ``False``.
         :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
+        :param disk_embedding: Whether to put the Embedding layer on disk to save memory.
+            Default to be ``False``.
         :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
             to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
         :param imatrix: str value, represent filename of importance matrix pretrained on
@@ -435,6 +437,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
@@ -507,7 +510,9 @@ class _BaseAutoModelClass:
         model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
@@ -563,6 +568,7 @@ class _BaseAutoModelClass:
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
@@ -699,7 +705,9 @@ class _BaseAutoModelClass:
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
         if is_sharded:
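
From the user-facing API, the new argument is simply another keyword popped in `from_pretrained` and `load_low_bit`. A hedged usage sketch (the model id is a placeholder; `load_in_4bit` is the existing ipex-llm loading option, not something added by this commit):

    from ipex_llm.transformers import AutoModelForCausalLM

    # disk_embedding=True makes the low-bit conversion replace nn.Embedding
    # modules with DiskEmbedding, trading lookup speed for lower host memory
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        load_in_4bit=True,
        disk_embedding=True,
    )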