add save_low_bit support for DiskEmbedding (#11621)

This commit is contained in:
Yishuo Wang 2024-07-19 10:34:53 +08:00 committed by GitHub
parent 380717f50d
commit d020ad6397
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 57 additions and 20 deletions

View file

@ -310,7 +310,6 @@ def use_scale_search(model_config, qtype):
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
convert_shape_only=False,
cpu_embedding=False,
disk_embedding=False,
prefix_name='',
imatrix_data=None, embedding_qtype=None,
model_config=None, torch_dtype=torch.float32,
@ -472,8 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
# skip user-defined Embedding layer
elif cpu_embedding and type(module) == nn.Embedding:
model._modules[name] = CPUEmbedding.from_embedding(module)
elif disk_embedding and type(module) == nn.Embedding:
model._modules[name] = DiskEmbedding.from_embedding(module)
elif embedding_qtype is not None and type(module) == nn.Embedding:
model._modules[name] = LowBitEmbedding.from_embedding(module,
convert_shape_only,
@ -486,7 +483,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
modules_to_not_convert,
convert_shape_only,
cpu_embedding,
disk_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
@ -746,7 +742,7 @@ def _optimize_pre(model, qtype=None):
def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None,
cpu_embedding=False, disk_embedding=False,
cpu_embedding=False,
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None,
embedding_qtype=None,
@ -788,7 +784,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
# mixed quantization needs model_config to choose custom quantization strategy
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
convert_shape_only, cpu_embedding, disk_embedding,
convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_config=model_config,

View file

@ -110,7 +110,7 @@ class DiskEmbedding(torch.nn.Embedding):
max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, True, device, dtype)
self.filename = "embeddings.bin"
self.weight.data.flatten().half().numpy().tofile(self.filename)
self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
@ -126,15 +126,6 @@ class DiskEmbedding(torch.nn.Embedding):
embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
return embeds.view(*input_ids.size(), self.embedding_dim)
def restore(self):
with open(self.filename, 'rb') as f:
buffer = f.read()
embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
device=self.weight.device, dtype=self.weight.dtype
)
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
@classmethod
def from_embedding(cls, embedding: torch.nn.Embedding):
return cls(
@ -151,6 +142,39 @@ class DiskEmbedding(torch.nn.Embedding):
embedding.weight.dtype,
)
def to_embedding(self):
with open(self.filename, 'rb') as f:
buffer = f.read()
embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
device=self.weight.device, dtype=self.weight.dtype
)
return torch.nn.Embedding(
self.num_embeddings,
self.embedding_dim,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
embeds,
True,
embeds.device,
embeds.dtype,
)
@staticmethod
def replace_normal_embedding(m: torch.nn.Module):
for name, module in m.named_children():
if type(module) == torch.nn.Embedding:
m._modules[name] = DiskEmbedding.from_embedding(module)
@staticmethod
def restore_normal_embedding(m: torch.nn.Module):
for name, module in m.named_children():
if type(module) == DiskEmbedding:
m._modules[name] = module.to_embedding()
class LowBitEmbedding(torch.nn.Embedding):
def __init__(self,

View file

@ -71,7 +71,15 @@ def save_low_bit(self, *args, **kwargs):
architectures = getattr(self.config, "architectures", None)
model_type = getattr(self.config, "model_type", None)
self.save_pretrained(*args, **kwargs)
disk_embedding = getattr(self.config, "bigdl_disk_embedding", False)
if disk_embedding:
from ipex_llm.transformers.embedding import DiskEmbedding
self.apply(DiskEmbedding.restore_normal_embedding)
self.save_pretrained(*args, **kwargs)
self.apply(DiskEmbedding.replace_normal_embedding)
else:
self.save_pretrained(*args, **kwargs)
if architectures:
self.config.update({"architectures": architectures})
@ -511,14 +519,19 @@ class _BaseAutoModelClass:
model = ggml_convert_low_bit(model, qtype, optimize_model,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding,
disk_embedding=disk_embedding,
lightweight_bmm=lightweight_bmm,
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision)
model.config.update({"bigdl_transformers_low_bit": q_k})
if disk_embedding:
from ipex_llm.transformers.embedding import DiskEmbedding
model.apply(DiskEmbedding.replace_normal_embedding)
model.config.update({"bigdl_transformers_low_bit": q_k,
"bigdl_disk_embedding": disk_embedding})
# enable tie_word_embeddings for MPT
# refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
@ -706,7 +719,6 @@ class _BaseAutoModelClass:
model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding,
disk_embedding=disk_embedding,
lightweight_bmm=lightweight_bmm,
embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
@ -749,6 +761,11 @@ class _BaseAutoModelClass:
# make sure token embedding weights are still tied if needed
model.tie_weights()
if disk_embedding:
from ipex_llm.transformers.embedding import DiskEmbedding
model.apply(DiskEmbedding.replace_normal_embedding)
model.config.update({"bigdl_disk_embedding": disk_embedding})
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()