empty cache only for 1st token, not rest tokens, to speed up (#11665)
This commit is contained in:
parent fc7f8feb83
commit ba01b85c13
1 changed file with 6 additions and 2 deletions
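In brief: the low-memory Llama forward (llama_causallm_forward_4_37_lowmem) previously called torch.xpu.empty_cache() unconditionally before and after the lm_head projection. This change guards both calls with hidden_states.shape[1] > 1, so the cache flush runs only during the prefill pass that produces the first token and is skipped on the single-token decode steps that produce the rest.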
@@ -959,8 +959,12 @@ def llama_causallm_forward_4_37_lowmem(
         logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
         logits = torch.cat(logits, dim=-1)
     else:
-        torch.xpu.empty_cache()
+        # Only empty cache for first token
+        if hidden_states.shape[1] > 1:
+            torch.xpu.empty_cache()
         logits = self.lm_head(hidden_states)
-        torch.xpu.empty_cache()
+        # Only empty cache for first token
+        if hidden_states.shape[1] > 1:
+            torch.xpu.empty_cache()
     # logits = logits.float()
 
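For readers skimming the diff: hidden_states.shape[1] is the sequence length of the current forward pass. It exceeds 1 only during prefill, when the whole prompt is processed to produce the first token; every later decode step feeds a single token. Below is a minimal sketch of the guarded-flush pattern, assuming a PyTorch build with Intel XPU support where torch.xpu.empty_cache() is available; the helper name is hypothetical and not part of this commit.

import torch

def lm_head_with_first_token_flush(hidden_states, lm_head):
    # Prefill ("1st token"): hidden_states has shape
    # [batch, prompt_len, hidden] with prompt_len > 1, and the
    # lm_head projection over the whole prompt is memory-heavy,
    # so releasing cached XPU blocks around it is worthwhile.
    # Decode ("rest tokens"): shape is [batch, 1, hidden];
    # flushing on every such step would slow generation.
    is_first_token = hidden_states.shape[1] > 1
    if is_first_token:
        torch.xpu.empty_cache()
    logits = lm_head(hidden_states)
    if is_first_token:
        torch.xpu.empty_cache()
    return logits

The trade-off mirrors torch.cuda.empty_cache(): returning the allocator's cached blocks to the device lowers peak memory but is comparatively expensive, so the cost is paid once per request around the memory-heavy first-token projection rather than on every generated token.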