Empty cache after embedding to cpu (#9477)

This commit is contained in:
Yuwen Hu 2023-11-16 10:52:30 +08:00 committed by GitHub
parent c487b53f21
commit 731b0aaade

View file

@ -23,4 +23,5 @@ class LLMEmbedding(torch.nn.Embedding):
def forward(self, x: Tensor): def forward(self, x: Tensor):
if self.weight.device != 'cpu': if self.weight.device != 'cpu':
self.to('cpu') self.to('cpu')
torch.xpu.empty_cache()
return super().forward(x.to('cpu')).to(x.device) return super().forward(x.to('cpu')).to(x.device)