add save_low_bit support for DiskEmbedding (#11621)
parent 380717f50d
commit d020ad6397
3 changed files with 57 additions and 20 deletions
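For orientation, a minimal usage sketch of the flow this commit fixes, written against the ipex_llm.transformers API as it appears in the hunks below; the model id and output directory are placeholders, and the exact keyword arguments are an assumption rather than part of this diff.

from ipex_llm.transformers import AutoModelForCausalLM

# Placeholder model id; disk_embedding=True keeps the token-embedding table
# in "embeddings.bin" on disk instead of holding it in memory.
model = AutoModelForCausalLM.from_pretrained("some/model-id",
                                             load_in_4bit=True,
                                             disk_embedding=True)

# With this commit, save_low_bit temporarily restores the real nn.Embedding
# weights from disk before save_pretrained runs, then swaps DiskEmbedding back,
# so the saved checkpoint contains the embedding weights.
model.save_low_bit("./model-low-bit")  # placeholder output path

# Passing disk_embedding=True again keeps the disk-backed embedding after reload.
model = AutoModelForCausalLM.load_low_bit("./model-low-bit", disk_embedding=True)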
@@ -310,7 +310,6 @@ def use_scale_search(model_config, qtype):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False,
-                                 disk_embedding=False,
                                  prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
@@ -472,8 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
             model._modules[name] = CPUEmbedding.from_embedding(module)
-        elif disk_embedding and type(module) == nn.Embedding:
-            model._modules[name] = DiskEmbedding.from_embedding(module)
         elif embedding_qtype is not None and type(module) == nn.Embedding:
             model._modules[name] = LowBitEmbedding.from_embedding(module,
                                                                   convert_shape_only,
@@ -486,7 +483,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
-                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -746,7 +742,7 @@ def _optimize_pre(model, qtype=None):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None,
-                         cpu_embedding=False, disk_embedding=False,
+                         cpu_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -788,7 +784,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding, disk_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
        embedding_qtype=embedding_qtype,
        model_config=model_config,
@@ -110,7 +110,7 @@ class DiskEmbedding(torch.nn.Embedding):
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
         self.filename = "embeddings.bin"
-        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
         dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
         self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)

@@ -126,15 +126,6 @@ class DiskEmbedding(torch.nn.Embedding):
         embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
         return embeds.view(*input_ids.size(), self.embedding_dim)

-    def restore(self):
-        with open(self.filename, 'rb') as f:
-            buffer = f.read()
-        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
-        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
-            device=self.weight.device, dtype=self.weight.dtype
-        )
-        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
-
     @classmethod
     def from_embedding(cls, embedding: torch.nn.Embedding):
         return cls(
@@ -151,6 +142,39 @@ class DiskEmbedding(torch.nn.Embedding):
             embedding.weight.dtype,
         )

+    def to_embedding(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        return torch.nn.Embedding(
+            self.num_embeddings,
+            self.embedding_dim,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+            embeds,
+            True,
+            embeds.device,
+            embeds.dtype,
+        )
+
+    @staticmethod
+    def replace_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == torch.nn.Embedding:
+                m._modules[name] = DiskEmbedding.from_embedding(module)
+
+    @staticmethod
+    def restore_normal_embedding(m: torch.nn.Module):
+        for name, module in m.named_children():
+            if type(module) == DiskEmbedding:
+                m._modules[name] = module.to_embedding()
+

 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
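A minimal round-trip sketch of the two static methods added above, assuming the ipex_llm.transformers.embedding module matches this diff; the Toy module is hypothetical and exists only for illustration.

import torch
from ipex_llm.transformers.embedding import DiskEmbedding

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(100, 16)

m = Toy()

# nn.Module.apply visits every submodule, so each plain nn.Embedding child is
# swapped for a DiskEmbedding; its weights are written to "embeddings.bin" in
# the working directory and only a dummy tensor stays in memory.
m.apply(DiskEmbedding.replace_normal_embedding)
assert type(m.embed) == DiskEmbedding

# restore_normal_embedding reads the file back through to_embedding() and swaps
# a regular nn.Embedding back in; save_low_bit relies on this before saving.
m.apply(DiskEmbedding.restore_normal_embedding)
assert type(m.embed) == torch.nn.Embedding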
@@ -71,6 +71,14 @@ def save_low_bit(self, *args, **kwargs):

     architectures = getattr(self.config, "architectures", None)
     model_type = getattr(self.config, "model_type", None)
+    disk_embedding = getattr(self.config, "bigdl_disk_embedding", False)
+
+    if disk_embedding:
+        from ipex_llm.transformers.embedding import DiskEmbedding
+        self.apply(DiskEmbedding.restore_normal_embedding)
+        self.save_pretrained(*args, **kwargs)
+        self.apply(DiskEmbedding.replace_normal_embedding)
+    else:
         self.save_pretrained(*args, **kwargs)

     if architectures:
@@ -511,14 +519,19 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
                                      enable_xetla=enable_xetla,
                                      mixed_precision=mixed_precision)
-        model.config.update({"bigdl_transformers_low_bit": q_k})
+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+
+        model.config.update({"bigdl_transformers_low_bit": q_k,
+                             "bigdl_disk_embedding": disk_embedding})

         # enable tie_word_embeddings for MPT
         # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
@@ -706,7 +719,6 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding,
-                                     disk_embedding=disk_embedding,
                                      lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)

@@ -749,6 +761,11 @@ class _BaseAutoModelClass:
         # make sure token embedding weights are still tied if needed
         model.tie_weights()

+        if disk_embedding:
+            from ipex_llm.transformers.embedding import DiskEmbedding
+            model.apply(DiskEmbedding.replace_normal_embedding)
+            model.config.update({"bigdl_disk_embedding": disk_embedding})
+
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
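As a quick check of what the new save path persists, a sketch that inspects the checkpoint written by save_low_bit; the path and the example value are placeholders, but the two config keys are the ones updated in the hunks above.

import json
import os

with open(os.path.join("./model-low-bit", "config.json")) as f:  # placeholder path
    cfg = json.load(f)

# "bigdl_transformers_low_bit" records the quantization type (e.g. "sym_int4");
# "bigdl_disk_embedding" is the flag save_low_bit reads back so that a later
# save restores the embedding weights from embeddings.bin before saving.
print(cfg.get("bigdl_transformers_low_bit"))
print(cfg.get("bigdl_disk_embedding"))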