diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py
index 2764d01e..38aa4db7 100644
--- a/python/llm/src/bigdl/llm/transformers/embedding.py
+++ b/python/llm/src/bigdl/llm/transformers/embedding.py
@@ -21,5 +21,6 @@ from torch import Tensor
 
 class LLMEmbedding(torch.nn.Embedding):
     def forward(self, x: Tensor):
-        x_shape = x.shape
-        return self.weight[x.reshape(-1)].reshape(*x_shape, -1)
+        if self.weight.device != 'cpu':
+            self.to('cpu')
+        return super().forward(x.to('cpu')).to(x.device)
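
For context, here is a minimal standalone sketch of what the patched forward does: the embedding table is kept on CPU, the token ids are moved to CPU for the lookup, and the result is moved back to the ids' original device. The CPUPinnedEmbedding name and the demo at the bottom are illustrative only, not part of the patch. Note that a torch.device generally does not compare equal to a plain string, so the sketch checks device.type rather than the `!= 'cpu'` comparison used in the diff.

import torch
from torch import Tensor


class CPUPinnedEmbedding(torch.nn.Embedding):
    """Sketch of the patched LLMEmbedding.forward: keep the (large)
    embedding table on CPU and only move the token ids and the
    looked-up rows across devices."""

    def forward(self, x: Tensor) -> Tensor:
        # A torch.device usually does not compare equal to a plain string,
        # so check the device type instead of `self.weight.device != 'cpu'`.
        if self.weight.device.type != 'cpu':
            self.to('cpu')
        # Do the lookup on CPU, then return the result on the caller's device.
        return super().forward(x.to('cpu')).to(x.device)


if __name__ == "__main__":
    emb = CPUPinnedEmbedding(num_embeddings=32000, embedding_dim=64)
    ids = torch.randint(0, 32000, (2, 8))  # token ids, here on CPU
    out = emb(ids)
    print(out.shape, out.device)  # torch.Size([2, 8, 64]) cpu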