add save_low_bit support for DiskEmbedding (#11621)
This commit is contained in:
parent
380717f50d
commit
d020ad6397
3 changed files with 57 additions and 20 deletions
|
|
@ -310,7 +310,6 @@ def use_scale_search(model_config, qtype):
|
|||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||
convert_shape_only=False,
|
||||
cpu_embedding=False,
|
||||
disk_embedding=False,
|
||||
prefix_name='',
|
||||
imatrix_data=None, embedding_qtype=None,
|
||||
model_config=None, torch_dtype=torch.float32,
|
||||
|
|
@ -472,8 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
# skip user-defined Embedding layer
|
||||
elif cpu_embedding and type(module) == nn.Embedding:
|
||||
model._modules[name] = CPUEmbedding.from_embedding(module)
|
||||
elif disk_embedding and type(module) == nn.Embedding:
|
||||
model._modules[name] = DiskEmbedding.from_embedding(module)
|
||||
elif embedding_qtype is not None and type(module) == nn.Embedding:
|
||||
model._modules[name] = LowBitEmbedding.from_embedding(module,
|
||||
convert_shape_only,
|
||||
|
|
@ -486,7 +483,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
modules_to_not_convert,
|
||||
convert_shape_only,
|
||||
cpu_embedding,
|
||||
disk_embedding,
|
||||
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
|
||||
imatrix_data=imatrix_data,
|
||||
embedding_qtype=embedding_qtype,
|
||||
|
|
@ -746,7 +742,7 @@ def _optimize_pre(model, qtype=None):
|
|||
def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
||||
convert_shape_only=False, device="cpu",
|
||||
modules_to_not_convert=None,
|
||||
cpu_embedding=False, disk_embedding=False,
|
||||
cpu_embedding=False,
|
||||
lightweight_bmm=False, torch_dtype="auto",
|
||||
imatrix_data=None,
|
||||
embedding_qtype=None,
|
||||
|
|
@ -788,7 +784,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
# mixed quantization needs model_config to choose custom quantization strategy
|
||||
model, has_been_replaced = _replace_with_low_bit_linear(
|
||||
model, qtype, modules_to_not_convert,
|
||||
convert_shape_only, cpu_embedding, disk_embedding,
|
||||
convert_shape_only, cpu_embedding,
|
||||
imatrix_data=imatrix_data,
|
||||
embedding_qtype=embedding_qtype,
|
||||
model_config=model_config,
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class DiskEmbedding(torch.nn.Embedding):
|
|||
max_norm, norm_type, scale_grad_by_freq,
|
||||
sparse, _weight, True, device, dtype)
|
||||
self.filename = "embeddings.bin"
|
||||
self.weight.data.flatten().half().numpy().tofile(self.filename)
|
||||
self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
|
||||
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
|
||||
self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
|
||||
|
||||
|
|
@ -126,15 +126,6 @@ class DiskEmbedding(torch.nn.Embedding):
|
|||
embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
|
||||
return embeds.view(*input_ids.size(), self.embedding_dim)
|
||||
|
||||
def restore(self):
|
||||
with open(self.filename, 'rb') as f:
|
||||
buffer = f.read()
|
||||
embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
|
||||
embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
|
||||
device=self.weight.device, dtype=self.weight.dtype
|
||||
)
|
||||
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
|
||||
|
||||
@classmethod
|
||||
def from_embedding(cls, embedding: torch.nn.Embedding):
|
||||
return cls(
|
||||
|
|
@ -151,6 +142,39 @@ class DiskEmbedding(torch.nn.Embedding):
|
|||
embedding.weight.dtype,
|
||||
)
|
||||
|
||||
def to_embedding(self):
|
||||
with open(self.filename, 'rb') as f:
|
||||
buffer = f.read()
|
||||
embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
|
||||
embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
|
||||
device=self.weight.device, dtype=self.weight.dtype
|
||||
)
|
||||
return torch.nn.Embedding(
|
||||
self.num_embeddings,
|
||||
self.embedding_dim,
|
||||
self.padding_idx,
|
||||
self.max_norm,
|
||||
self.norm_type,
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
embeds,
|
||||
True,
|
||||
embeds.device,
|
||||
embeds.dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_normal_embedding(m: torch.nn.Module):
|
||||
for name, module in m.named_children():
|
||||
if type(module) == torch.nn.Embedding:
|
||||
m._modules[name] = DiskEmbedding.from_embedding(module)
|
||||
|
||||
@staticmethod
|
||||
def restore_normal_embedding(m: torch.nn.Module):
|
||||
for name, module in m.named_children():
|
||||
if type(module) == DiskEmbedding:
|
||||
m._modules[name] = module.to_embedding()
|
||||
|
||||
|
||||
class LowBitEmbedding(torch.nn.Embedding):
|
||||
def __init__(self,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,15 @@ def save_low_bit(self, *args, **kwargs):
|
|||
|
||||
architectures = getattr(self.config, "architectures", None)
|
||||
model_type = getattr(self.config, "model_type", None)
|
||||
self.save_pretrained(*args, **kwargs)
|
||||
disk_embedding = getattr(self.config, "bigdl_disk_embedding", False)
|
||||
|
||||
if disk_embedding:
|
||||
from ipex_llm.transformers.embedding import DiskEmbedding
|
||||
self.apply(DiskEmbedding.restore_normal_embedding)
|
||||
self.save_pretrained(*args, **kwargs)
|
||||
self.apply(DiskEmbedding.replace_normal_embedding)
|
||||
else:
|
||||
self.save_pretrained(*args, **kwargs)
|
||||
|
||||
if architectures:
|
||||
self.config.update({"architectures": architectures})
|
||||
|
|
@ -511,14 +519,19 @@ class _BaseAutoModelClass:
|
|||
model = ggml_convert_low_bit(model, qtype, optimize_model,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
cpu_embedding=cpu_embedding,
|
||||
disk_embedding=disk_embedding,
|
||||
lightweight_bmm=lightweight_bmm,
|
||||
torch_dtype=kwargs.get("torch_dtype", 'auto'),
|
||||
imatrix_data=imatrix_data,
|
||||
embedding_qtype=embedding_qtype,
|
||||
enable_xetla=enable_xetla,
|
||||
mixed_precision=mixed_precision)
|
||||
model.config.update({"bigdl_transformers_low_bit": q_k})
|
||||
|
||||
if disk_embedding:
|
||||
from ipex_llm.transformers.embedding import DiskEmbedding
|
||||
model.apply(DiskEmbedding.replace_normal_embedding)
|
||||
|
||||
model.config.update({"bigdl_transformers_low_bit": q_k,
|
||||
"bigdl_disk_embedding": disk_embedding})
|
||||
|
||||
# enable tie_word_embeddings for MPT
|
||||
# refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
|
||||
|
|
@ -706,7 +719,6 @@ class _BaseAutoModelClass:
|
|||
model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
cpu_embedding=cpu_embedding,
|
||||
disk_embedding=disk_embedding,
|
||||
lightweight_bmm=lightweight_bmm,
|
||||
embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
|
||||
|
||||
|
|
@ -749,6 +761,11 @@ class _BaseAutoModelClass:
|
|||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
if disk_embedding:
|
||||
from ipex_llm.transformers.embedding import DiskEmbedding
|
||||
model.apply(DiskEmbedding.replace_normal_embedding)
|
||||
model.config.update({"bigdl_disk_embedding": disk_embedding})
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue