diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 2e6e99b5..de51cbca 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -959,9 +959,13 @@ def llama_causallm_forward_4_37_lowmem(
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
             logits = torch.cat(logits, dim=-1)
         else:
-            torch.xpu.empty_cache()
+            # Only empty cache for first token
+            if hidden_states.shape[1] > 1:
+                torch.xpu.empty_cache()
             logits = self.lm_head(hidden_states)
-            torch.xpu.empty_cache()
+            # Only empty cache for first token
+            if hidden_states.shape[1] > 1:
+                torch.xpu.empty_cache()
         # logits = logits.float()
         # ipex-llm change ends
 
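
Note (not part of the patch): hidden_states.shape[1] is the sequence length, which is greater than 1 only on the prefill pass that produces the first token; later decode passes feed a single token, so this change limits the relatively expensive torch.xpu.empty_cache() calls to once per request around the large lm_head matmul instead of once per generated token. Below is a minimal sketch of the same check, assuming an XPU-enabled PyTorch build and a [batch, seq_len, hidden_size] hidden-state layout; the helper name maybe_empty_cache is illustrative only and does not appear in the patch.

import torch

def maybe_empty_cache(hidden_states: torch.Tensor) -> None:
    # hidden_states has shape [batch, seq_len, hidden_size].
    # seq_len > 1 only on the prefill (first-token) pass, so the cache
    # flush runs once per request rather than once per decoded token.
    if hidden_states.shape[1] > 1:
        torch.xpu.empty_cache()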