#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy
import torch
from torch import Tensor
from torch.nn import Parameter
from typing import Optional

from ipex_llm.transformers.low_bit_linear import FP4Params
from ipex_llm.utils.common import invalidInputError


# To avoid running out of memory when moving the embedding from XPU back to CPU,
# we can pin the embedding to CPU when `cpu_embedding=True`.
class CPUPinnedParam(Parameter):
    # Override the `device` attribute of CPUPinnedParam so that it reports the device
    # passed to model.to(device); with this attribute, model.device matches the device
    # requested via model.to(device) even when `cpu_embedding=True`.
    @property
    def device(self):
        try:
            return self._device
        except AttributeError:
            return super().device

    @device.setter
    def device(self, to_device):
        self._device = to_device

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is None:
            return super().to(*args, **kwargs)
        elif device.type == 'xpu':
            # Record the requested XPU device but keep the actual data pinned on CPU.
            self.device = device
            if convert_to_format is not None and self.dim() in (4, 5):
                return super().to('cpu', dtype,
                                  non_blocking, memory_format=convert_to_format)
            return super().to('cpu', dtype, non_blocking)
        return super().to(*args, **kwargs)
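

# Illustrative sketch (not part of this module's API): this is how CPUEmbedding below
# uses CPUPinnedParam. Wrapping a weight in CPUPinnedParam keeps its data on CPU when
# the parameter is moved to 'xpu', while its reported device follows the requested one.
def _pin_weight_sketch(embedding: torch.nn.Embedding) -> torch.nn.Embedding:
    # Replace the weight with a frozen, CPU-pinned view of the same data.
    embedding.weight = CPUPinnedParam(embedding.weight.data, requires_grad=False)
    return embedding

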
class CPUEmbedding(torch.nn.Embedding):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 _weight: Optional[Tensor] = None,
                 _freeze: bool = False,
                 device=None,
                 dtype=None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq,
                         sparse, _weight, True, device, dtype)
        # Pin the weight to CPU; its reported device still follows model.to(device).
        self.weight = CPUPinnedParam(self.weight.data, requires_grad=False)

    def forward(self, x: Tensor):
        # Run the lookup on CPU, then move the result back to the input's device.
        return super().forward(x.to('cpu')).to(x.device)

    @classmethod
    def from_embedding(cls, embedding: torch.nn.Embedding):
        return cls(
            embedding.num_embeddings,
            embedding.embedding_dim,
            embedding.padding_idx,
            embedding.max_norm,
            embedding.norm_type,
            embedding.scale_grad_by_freq,
            embedding.sparse,
            embedding.weight.data,
            True,
            embedding.weight.device,
            embedding.weight.dtype,
        )
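

# Usage sketch (illustrative only, mirroring DiskEmbedding.replace_normal_embedding
# below): swap every direct torch.nn.Embedding child of a module for a CPUEmbedding,
# so token lookups run on CPU even after the rest of the model is moved to 'xpu'.
def _replace_with_cpu_embedding_sketch(m: torch.nn.Module) -> None:
    for name, module in m.named_children():
        if type(module) == torch.nn.Embedding:
            m._modules[name] = CPUEmbedding.from_embedding(module)

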
class DiskEmbedding(torch.nn.Embedding):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 _weight: Optional[Tensor] = None,
                 _freeze: bool = False,
                 device=None,
                 dtype=None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq,
                         sparse, _weight, True, device, dtype)
        # Dump the embedding weight to disk as float16 and keep only an empty
        # placeholder parameter in memory (it records the model's dtype and device).
        self.filename = "embeddings.bin"
        self.weight.data.flatten().to(device='cpu', dtype=torch.half).numpy().tofile(self.filename)
        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)

    def forward(self, input_ids: Tensor):
        ids = input_ids.cpu().flatten()

        embeds = []
        with open(self.filename, 'rb') as f:
            for idx in ids:
                # Each row is `embedding_dim` float16 values, i.e. 2 bytes per value.
                f.seek(idx * self.embedding_dim * 2)
                buffer = f.read(self.embedding_dim * 2)
                embeds.append(torch.frombuffer(buffer, dtype=torch.half))
        embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
        return embeds.view(*input_ids.size(), self.embedding_dim)

    @classmethod
    def from_embedding(cls, embedding: torch.nn.Embedding):
        return cls(
            embedding.num_embeddings,
            embedding.embedding_dim,
            embedding.padding_idx,
            embedding.max_norm,
            embedding.norm_type,
            embedding.scale_grad_by_freq,
            embedding.sparse,
            embedding.weight.data,
            True,
            embedding.weight.device,
            embedding.weight.dtype,
        )

    def to_embedding(self):
        # Load the full weight back from disk and rebuild a normal torch.nn.Embedding.
        with open(self.filename, 'rb') as f:
            buffer = f.read()
        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
            device=self.weight.device, dtype=self.weight.dtype
        )
        return torch.nn.Embedding(
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
            embeds,
            True,
            embeds.device,
            embeds.dtype,
        )

    @staticmethod
    def replace_normal_embedding(m: torch.nn.Module):
        for name, module in m.named_children():
            if type(module) == torch.nn.Embedding:
                m._modules[name] = DiskEmbedding.from_embedding(module)

    @staticmethod
    def restore_normal_embedding(m: torch.nn.Module):
        for name, module in m.named_children():
            if type(module) == DiskEmbedding:
                m._modules[name] = module.to_embedding()
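

# Usage sketch (illustrative only): offload a model's embedding table to disk during a
# memory-constrained phase and restore it afterwards. `model` is any torch.nn.Module;
# model.apply() visits nested submodules as well. Note that every DiskEmbedding writes
# to the same "embeddings.bin" file, so this sketch assumes a single embedding table.
def _disk_embedding_round_trip_sketch(model: torch.nn.Module) -> None:
    # Dump each embedding weight to disk, keeping only a placeholder parameter in memory.
    model.apply(DiskEmbedding.replace_normal_embedding)
    # ... run the memory-constrained workload here ...
    # Read the weights back from disk and restore normal torch.nn.Embedding layers.
    model.apply(DiskEmbedding.restore_normal_embedding)

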
class LowBitEmbedding(torch.nn.Embedding):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 _weight: Optional[Tensor] = None,
                 _freeze: bool = False,
                 device=None,
                 dtype=None,
                 convert_shape_only=None,
                 qtype=None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq,
                         sparse, _weight, True, device, dtype)
        # Quantize the embedding weight into a low-bit FP4Params buffer.
        self.qweight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False,
                                 _shape=None,
                                 convert_shape_only=convert_shape_only,
                                 qtype=qtype,
                                 in_features=embedding_dim)
        # This dummy weight is only used to record the model's dtype and device.
        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

    def forward(self, x: Tensor):
        invalidInputError(x.device.type == "xpu",
                          "`LowBitEmbedding` only supports GPU now.")
        try:
            import xe_linear
        except ModuleNotFoundError:
            invalidInputError(False,
                              "Please `pip install bigdl_core_xe_21` first.")

        # Dequantize only the requested rows directly on the XPU device.
        result = xe_linear.dequantize_rows(x.contiguous(), self.qweight.data,
                                           self.qweight.qtype, self.embedding_dim,
                                           self.num_embeddings)
        return result.to(self.weight.dtype)

    @classmethod
    def from_embedding(cls, embedding: torch.nn.Embedding, convert_shape_only, qtype):
        return cls(
            embedding.num_embeddings,
            embedding.embedding_dim,
            embedding.padding_idx,
            embedding.max_norm,
            embedding.norm_type,
            embedding.scale_grad_by_freq,
            embedding.sparse,
            embedding.weight.data,
            True,
            embedding.weight.device,
            embedding.weight.dtype,
            convert_shape_only,
            qtype,
        )
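

# Usage sketch (illustrative only): quantize a model's embedding table into low-bit
# storage. `qtype` should be a qtype value accepted by ipex_llm's FP4Params, and
# passing convert_shape_only=False here is an assumption; in practice this conversion
# is normally driven by ipex_llm's own optimization entry points. After conversion the
# module (and its inputs) must live on an XPU device, since forward() is XPU-only.
def _low_bit_embedding_sketch(m: torch.nn.Module, qtype) -> None:
    for name, module in m.named_children():
        if type(module) == torch.nn.Embedding:
            m._modules[name] = LowBitEmbedding.from_embedding(module,
                                                              convert_shape_only=False,
                                                              qtype=qtype)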