Update pp llama.py to save memory (#11233)
parent ef8e9b2ecd
commit 6f2684e5c9
1 changed file with 4 additions and 3 deletions
@@ -50,8 +50,10 @@ class LlamaModel(LlamaPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        if self.pp_config.is_head:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_config.is_tail:
+            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self):
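Net effect of this hunk: the constructor allocates only what a given pipeline stage actually runs, with the token embedding built on the head stage and the final RMSNorm on the tail stage. A minimal, self-contained sketch of that pattern (not the repository's actual class; only pp_config.is_head/pp_config.is_tail and the config fields shown above are taken from the diff):

import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    # Compact RMSNorm with the constructor signature used in the diff above.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

class StageSlice(nn.Module):
    # Sketch: each pipeline-parallel stage builds just the modules it executes,
    # so middle stages hold neither the embedding nor the final norm.
    def __init__(self, config, pp_config):
        super().__init__()
        self.pp_config = pp_config
        if pp_config.is_head:
            # Only the first stage maps token ids to hidden states.
            self.embed_tokens = nn.Embedding(
                config.vocab_size, config.hidden_size, config.pad_token_id
            )
        if pp_config.is_tail:
            # Only the last stage normalizes hidden states before the LM head.
            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

For a 7B-sized config (vocab_size 32000, hidden_size 4096) the embedding alone is about 131M parameters, roughly 262 MB in fp16, which every non-head stage previously allocated without ever using.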
@@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if self.pp_config.is_tail:
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            logits = logits.float()
 
             loss = None
             if labels is not None:
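The second hunk drops the float32 upcast of the logits, so the [batch, seq_len, vocab_size] tensor produced by lm_head stays in the model's compute dtype instead of being copied into a twice-as-large buffer. Rough arithmetic, with illustrative batch and sequence sizes that are not from this commit:

batch, seq_len, vocab_size = 4, 2048, 32000  # only vocab_size is Llama-like
elements = batch * seq_len * vocab_size

half_mib = elements * 2 / 2**20  # logits kept in fp16: 500 MiB
full_mib = elements * 4 / 2**20  # the removed logits.float() copy: 1000 MiB
print(f"fp16 logits: {half_mib:.0f} MiB, fp32 upcast: {full_mib:.0f} MiB")

The trade-off is that any loss computed from these logits now runs in half precision unless the loss path upcasts on its own.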