Update pp llama.py to save memory (#11233)

commit 6f2684e5c9
parent ef8e9b2ecd
Author: Wang, Jian4
Date:   2024-06-07 13:18:16 +08:00 (committed by GitHub)


@@ -50,7 +50,9 @@ class LlamaModel(LlamaPreTrainedModel):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
-       self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-       self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+       if self.pp_config.is_head:
+           self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+       if self.pp_config.is_tail:
+           self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
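
With pipeline parallelism each rank holds only its own slice of decoder layers, so gating the token embedding on pp_config.is_head and the final RMSNorm on pp_config.is_tail means those weights are only allocated on the ranks that actually use them. A minimal sketch of the pattern, assuming only the is_head/is_tail flags visible in the diff (all other names and sizes are illustrative, not the repository's API):

    import torch.nn as nn

    class PPConfig:
        """Illustrative stand-in for the pipeline-parallel config assumed by the diff."""
        def __init__(self, rank: int, world_size: int):
            self.is_head = (rank == 0)               # first pipeline stage
            self.is_tail = (rank == world_size - 1)  # last pipeline stage

    class StageShell(nn.Module):
        """Sketch of per-stage module creation; not the repository's actual class."""
        def __init__(self, pp_config: PPConfig, vocab_size: int = 32000,
                     hidden_size: int = 4096, pad_token_id: int = 0):
            super().__init__()
            self.pp_config = pp_config
            # Only the first stage embeds input ids, so only it pays for the
            # vocab_size x hidden_size embedding table.
            if self.pp_config.is_head:
                self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
            # Only the last stage normalizes hidden states before the LM head,
            # so only it keeps the final norm weights (LayerNorm here stands in
            # for LlamaRMSNorm).
            if self.pp_config.is_tail:
                self.norm = nn.LayerNorm(hidden_size, eps=1e-6)

For a Llama-7B-style configuration the embedding table alone is roughly 32000 x 4096 x 2 bytes, about 250 MiB in fp16, so a middle rank that skips it avoids a non-trivial allocation.
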
@@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
        if self.pp_config.is_tail:
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
-           logits = logits.float()
            loss = None
            if labels is not None:
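
Dropping logits = logits.float() means the tail rank no longer materializes an extra fp32 copy of the full-vocabulary logits next to the half-precision tensor returned by lm_head; that upcast mainly serves numerically stable loss computation, so skipping it during inference trades a little precision for memory. A rough estimate of the avoided allocation, using assumed shapes rather than values from the repository:

    # Back-of-the-envelope size of the fp32 logits copy that .float() would
    # create; batch, sequence and vocabulary sizes are assumptions.
    batch, seq_len, vocab = 1, 1024, 32000
    fp16_logits_mib = batch * seq_len * vocab * 2 / 1024**2   # lm_head output in fp16
    fp32_copy_mib = batch * seq_len * vocab * 4 / 1024**2     # extra tensor from .float()
    print(f"fp16 logits: {fp16_logits_mib:.0f} MiB, avoided fp32 copy: {fp32_copy_mib:.0f} MiB")

With these assumed shapes the avoided fp32 copy is about 125 MiB on top of roughly 62 MiB of fp16 logits, and it grows linearly with batch size and sequence length.
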