diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
index 4244a735..29bac7a7 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
@@ -50,8 +50,10 @@ class LlamaModel(LlamaPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_config.is_head:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        if self.pp_config.is_tail:
+            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
     def get_input_embeddings(self):
@@ -259,7 +261,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if self.pp_config.is_tail:
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            logits = logits.float()
 
             loss = None
             if labels is not None:
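Note: the diff relies on a `pp_config` attribute whose definition is not shown here. The sketch below is a hypothetical, minimal illustration (not the example's actual `PPConfig` or model classes) of how such a stage config could derive `is_head` / `is_tail` from the stage rank, and how a pipeline stage might construct only the modules it owns, mirroring the conditional construction introduced in the patch.

# Hypothetical sketch: stage config and head/tail-conditional module creation.
# Names (PPConfig, ToyStage) are illustrative assumptions, not part of the diff.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class PPConfig:
    """Describes this process's position in the pipeline."""
    pp_rank: int        # index of this pipeline stage
    pp_world_size: int  # total number of pipeline stages

    @property
    def is_head(self) -> bool:
        # Only the first stage owns the token embedding.
        return self.pp_rank == 0

    @property
    def is_tail(self) -> bool:
        # Only the last stage owns the final norm and lm_head.
        return self.pp_rank == self.pp_world_size - 1


class ToyStage(nn.Module):
    """Minimal illustration of building only the layers this stage needs."""

    def __init__(self, hidden_size: int, vocab_size: int, pp_config: PPConfig):
        super().__init__()
        self.pp_config = pp_config
        if pp_config.is_head:
            # First stage: embed token ids into hidden states.
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if pp_config.is_tail:
            # Last stage: normalize and project hidden states to the vocabulary.
            self.norm = nn.LayerNorm(hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)


# Usage: a 2-stage pipeline; stage 0 holds only the embedding,
# stage 1 holds only the norm and lm_head, so neither stage wastes
# memory on weights it never uses.
stage0 = ToyStage(hidden_size=64, vocab_size=1000, pp_config=PPConfig(0, 2))
stage1 = ToyStage(hidden_size=64, vocab_size=1000, pp_config=PPConfig(1, 2))
assert hasattr(stage0, "embed_tokens") and not hasattr(stage0, "norm")
assert hasattr(stage1, "norm") and not hasattr(stage1, "embed_tokens")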