LLM: Fix DummyLayer.weight device in Pipeline Parallel (#11612)

Author: Xiangyu Tian, 2024-07-18 13:39:34 +08:00 (committed by GitHub)
Commit: 4594a3dd6c
Parent: 4da93709b1


@@ -53,7 +53,7 @@ class DummyLayer(nn.Module):
         super().__init__()
         # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
         # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
-        self.weight = torch.randn(1,)
+        self.weight = nn.Parameter(torch.empty(0,), requires_grad=False)
 
     def forward(self, x):
         return x
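
Context for the change: a plain tensor assigned as a module attribute (the old torch.randn(1,)) is not registered with the nn.Module, so model.to(device) leaves it on the CPU and the self.weight.device lookup referenced in the comment above can report the wrong device under pipeline parallelism. Wrapping an empty tensor in nn.Parameter(..., requires_grad=False) registers it with the module, so it follows device and dtype moves. The sketch below illustrates the patched layer and the device behaviour; the "cuda" target is only an example, and the class body here is a minimal reconstruction rather than the exact ipex-llm source.

    import torch
    import torch.nn as nn

    class DummyLayer(nn.Module):
        # Pass-through placeholder for layers owned by another pipeline rank (sketch).
        def __init__(self, *args):
            super().__init__()
            # nn.Parameter registers the tensor with the module, so Module.to() /
            # .half() move it together with the real weights; a bare tensor
            # attribute would silently stay on the CPU.
            self.weight = nn.Parameter(torch.empty(0,), requires_grad=False)

        def forward(self, x):
            return x

    # Device check ("cuda" is just an illustrative target; the original code
    # targets Intel XPU through ipex-llm):
    layer = DummyLayer()
    print(layer.weight.device)      # cpu
    if torch.cuda.is_available():
        layer = layer.to("cuda")
        print(layer.weight.device)  # cuda:0 -- the dummy weight now follows the module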