From 6f2684e5c900eedfab7a7a3fcb0b1c705b9050cb Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Fri, 7 Jun 2024 13:18:16 +0800
Subject: [PATCH] Update pp llama.py to save memory (#11233)

---
 .../example/GPU/Pipeline-Parallel-FastAPI/llama_models.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
index 4244a735..29bac7a7 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
@@ -50,8 +50,10 @@ class LlamaModel(LlamaPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_config.is_head:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        if self.pp_config.is_tail:
+            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
     def get_input_embeddings(self):
@@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if self.pp_config.is_tail:
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            logits = logits.float()
 
             loss = None
             if labels is not None:
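
For context, the memory saving in this patch comes from two changes: each pipeline stage now only materializes the modules it actually uses (the token embedding on the head stage, the final RMSNorm on the tail stage), and the tail stage no longer casts the logits to float32, which previously allocated an extra full-precision copy of the logits tensor. Below is a minimal, self-contained sketch of the first pattern; the PPConfig class and its is_head/is_tail fields are hypothetical stand-ins modeled on the names that appear in the diff, not the actual ipex-llm implementation.

# Minimal sketch (not the ipex-llm code): build only the modules a pipeline
# stage actually needs, so intermediate ranks never allocate those weights.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class PPConfig:
    # Hypothetical stand-in for the pp_config object referenced in the patch.
    pp_rank: int
    pp_world_size: int

    @property
    def is_head(self) -> bool:
        return self.pp_rank == 0

    @property
    def is_tail(self) -> bool:
        return self.pp_rank == self.pp_world_size - 1


class StageModel(nn.Module):
    def __init__(self, pp_config: PPConfig, vocab_size: int, hidden_size: int):
        super().__init__()
        self.pp_config = pp_config
        # Only the first stage holds the token embedding table.
        if pp_config.is_head:
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Only the last stage holds the final norm (RMSNorm in the real model;
        # LayerNorm is used here just to keep the sketch self-contained).
        if pp_config.is_tail:
            self.norm = nn.LayerNorm(hidden_size)


# Example: rank 1 of 4 allocates neither the embedding nor the final norm.
model = StageModel(PPConfig(pp_rank=1, pp_world_size=4), vocab_size=32000, hidden_size=4096)
print(hasattr(model, "embed_tokens"), hasattr(model, "norm"))  # False False

Since the vocab_size x hidden_size embedding table is one of the largest weight matrices outside the decoder layers, skipping it on every non-head rank gives a noticeable per-device memory reduction.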