Empty cache only for the 1st token, but not the rest of the tokens, to speed up (#11665)
This commit is contained in:
		
							parent
							
								
									fc7f8feb83
								
							
						
					
					
						commit
						ba01b85c13
					
				
					 1 changed file with 6 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -959,9 +959,13 @@ def llama_causallm_forward_4_37_lowmem(
 | 
			
		|||
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
 | 
			
		||||
        logits = torch.cat(logits, dim=-1)
 | 
			
		||||
    else:
 | 
			
		||||
        torch.xpu.empty_cache()
 | 
			
		||||
        # Only empty cache for first token
 | 
			
		||||
        if hidden_states.shape[1] > 1:
 | 
			
		||||
            torch.xpu.empty_cache()
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
        torch.xpu.empty_cache()
 | 
			
		||||
        # Only empty cache for first token
 | 
			
		||||
        if hidden_states.shape[1] > 1:
 | 
			
		||||
            torch.xpu.empty_cache()
 | 
			
		||||
    # logits = logits.float()
 | 
			
		||||
 | 
			
		||||
    # ipex-llm change ends
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue