class LLMEmbedding(torch.nn.Embedding):
    """Embedding layer whose lookup is always executed on the CPU."""

    def forward(self, x: Tensor):
        """Look up ``x`` on the CPU and return the result on ``x``'s device.

        If the module's weight is not on the CPU, the module is first moved
        there and cached XPU memory is released.
        NOTE(review): presumably this keeps the embedding table out of
        accelerator memory -- confirm against the callers.

        Args:
            x: Index tensor, on any device.

        Returns:
            The embeddings for ``x``, moved back to ``x.device``.
        """
        # Compare the device *type*: a ``torch.device`` never compares equal
        # to a plain ``str``, so ``device != 'cpu'`` would be true on every
        # call and trigger a module move plus a cache flush per forward pass.
        if self.weight.device.type != 'cpu':
            self.to('cpu')
            # Free the cached blocks left behind by the move. The ``xpu``
            # namespace is absent on some torch builds, so look it up safely.
            xpu = getattr(torch, 'xpu', None)
            if xpu is not None:
                xpu.empty_cache()
        return super().forward(x.to('cpu')).to(x.device)