Update pp llama.py to save memory (#11233)
commit 6f2684e5c9 (parent ef8e9b2ecd)
1 changed file with 4 additions and 3 deletions
@@ -50,7 +50,9 @@ class LlamaModel(LlamaPreTrainedModel):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        if self.pp_config.is_head:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        if self.pp_config.is_tail:
            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
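The hunk above gates module construction on the stage's role, so each pipeline rank only materializes the weights it actually uses. A minimal sketch of the idea follows; PPConfig and StageModel are hypothetical names for illustration, and only the pp_config.is_head / pp_config.is_tail checks come from the diff itself.

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class PPConfig:
    """Hypothetical pipeline-parallel stage descriptor."""
    stage_idx: int
    num_stages: int

    @property
    def is_head(self) -> bool:
        # First stage owns the token embedding table.
        return self.stage_idx == 0

    @property
    def is_tail(self) -> bool:
        # Last stage owns the final norm (and, in the causal LM, lm_head).
        return self.stage_idx == self.num_stages - 1


class StageModel(nn.Module):
    """Illustrative stage: allocate embed/norm only where needed."""

    def __init__(self, pp_config: PPConfig, vocab_size: int, hidden_size: int):
        super().__init__()
        self.pp_config = pp_config
        if pp_config.is_head:
            # Only the head stage pays for the vocab_size x hidden_size table.
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if pp_config.is_tail:
            self.norm = nn.LayerNorm(hidden_size)

For LLaMA-7B shapes (vocab 32000, hidden 4096) the embedding table alone is 32000 × 4096 × 2 bytes ≈ 250 MiB in fp16, so skipping it on the middle stages is a meaningful per-rank saving.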
@@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
        if self.pp_config.is_tail:
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()

        loss = None
        if labels is not None:
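In the second hunk, only the tail stage runs lm_head, which avoids allocating a (batch, seq_len, vocab_size) float logits tensor on every other rank. Below is a sketch of how that tail-gated path might look in full; the outputs, lm_head, and labels names follow the diff, while the early return and the next-token loss wiring are assumptions based on the standard Hugging Face causal-LM pattern.

import torch.nn.functional as F

def tail_forward(self, outputs, labels=None):
    # Mirrors the gated block in LlamaForCausalLM.forward
    # (a sketch, not the actual method body).
    if not self.pp_config.is_tail:
        # Non-tail stages just pass hidden states to the next stage.
        return outputs

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    logits = logits.float()  # compute the loss in fp32 for stability

    loss = None
    if labels is not None:
        # Standard causal-LM shift: predict token t+1 from position t.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    return loss, logits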