Update pp llama.py to save memory (#11233)
This commit is contained in:
		
							parent
							
								
									ef8e9b2ecd
								
							
						
					
					
						commit
						6f2684e5c9
					
				
					 1 changed files with 4 additions and 3 deletions
				
			
		| 
						 | 
					@ -50,7 +50,9 @@ class LlamaModel(LlamaPreTrainedModel):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.padding_idx = config.pad_token_id
 | 
					        self.padding_idx = config.pad_token_id
 | 
				
			||||||
        self.vocab_size = config.vocab_size
 | 
					        self.vocab_size = config.vocab_size
 | 
				
			||||||
 | 
					        if self.pp_config.is_head:
 | 
				
			||||||
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
					            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
				
			||||||
 | 
					        if self.pp_config.is_tail:
 | 
				
			||||||
            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 | 
					            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 | 
				
			||||||
        if self.pp_config.is_tail:
 | 
					        if self.pp_config.is_tail:
 | 
				
			||||||
            hidden_states = outputs[0]
 | 
					            hidden_states = outputs[0]
 | 
				
			||||||
            logits = self.lm_head(hidden_states)
 | 
					            logits = self.lm_head(hidden_states)
 | 
				
			||||||
            logits = logits.float()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            loss = None
 | 
					            loss = None
 | 
				
			||||||
            if labels is not None:
 | 
					            if labels is not None:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue