parent 963a5c8d79
commit e7e0cd3b5e
1 changed file with 32 additions and 3 deletions
@@ -17,11 +17,40 @@
 import torch
 from torch import Tensor
+from torch.nn import functional as F
+from torch.nn import Parameter
+from typing import Optional
+
+
+# To prevent insufficient available memory when moving embedding from XPU back to CPU,
+# we can pin the embedding to CPU if `cpu_embedding==True`.
+class CPUPinnedParam(Parameter):
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if device.type == 'xpu':
+            if convert_to_format is not None and self.dim() in (4, 5):
+                return super().to('cpu', dtype,
+                                  non_blocking, memory_format=convert_to_format)
+            return super().to('cpu', dtype, non_blocking)
+        return super().to(*args, **kwargs)
 
 
 class LLMEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq, sparse,
+                         _weight, device, dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+
     def forward(self, x: Tensor):
+        if self.weight.device != 'cpu':
+            self.to('cpu')
+            torch.xpu.empty_cache()
         return super().forward(x.to('cpu')).to(x.device)
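The redirection works because torch.nn.Module.to() converts each parameter by calling the tensor-level to() on it, so a Parameter subclass that overrides to() can intercept the move. torch._C._nn._parse_to is the private PyTorch helper (the same one Module.to uses) that normalizes the flexible argument forms of to(...) into a fixed (device, dtype, non_blocking, memory_format) tuple. A minimal sketch of what the override sees; since _parse_to is private API, the behavior shown is an assumption based on current PyTorch releases, not a stable contract:

import torch

# _parse_to normalizes the many call forms of Tensor.to() into a fixed tuple.
# Private API: the outputs shown in the comments are an assumption.
device, dtype, non_blocking, memory_format = \
    torch._C._nn._parse_to('xpu', torch.float16, non_blocking=True)

print(device.type)    # 'xpu' -- the condition CPUPinnedParam.to() checks for
print(dtype)          # torch.float16
print(non_blocking)   # True
print(memory_format)  # None (no memory_format requested)

Because CPUPinnedParam only rewrites the device and forwards dtype, non_blocking, and memory_format unchanged, a call like module.to('xpu', dtype=torch.float16) still converts the pinned weight's dtype; only the device move is suppressed.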
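A hypothetical end-to-end usage sketch (the sizes are illustrative, LLMEmbedding from the diff above is assumed to be in scope, and an XPU-enabled PyTorch build, e.g. with intel_extension_for_pytorch, is assumed so that 'xpu' is a valid device): swapping LLMEmbedding in for torch.nn.Embedding keeps the embedding table in host memory even after the module is moved to XPU, while lookups still return results on the caller's device.

import torch

# Illustrative LLM-scale table: 32000 x 4096 in fp32 is roughly 500 MB,
# which is the kind of allocation this change keeps off the XPU.
emb = LLMEmbedding(num_embeddings=32000, embedding_dim=4096)

# Module.to('xpu') ends up calling CPUPinnedParam.to() on the weight, which
# redirects the move to 'cpu', so the table never lands in device memory.
emb = emb.to('xpu')
print(emb.weight.device)  # cpu

# forward() does the lookup on CPU, then moves only the gathered rows back
# to the indices' device.
ids = torch.tensor([[1, 2, 3]], device='xpu')
out = emb(ids)
print(out.shape, out.device)  # torch.Size([1, 3, 4096]) xpu:0

The trade-off is a CPU-side lookup plus a host-to-device copy of the gathered rows on every forward call, which is typically small compared to holding the whole table in XPU memory.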