Empty cache only for the 1st token, not for subsequent tokens, to speed up generation (#11665)
This commit is contained in:
parent
fc7f8feb83
commit
ba01b85c13
1 changed files with 6 additions and 2 deletions
|
|
@ -959,9 +959,13 @@ def llama_causallm_forward_4_37_lowmem(
|
|||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
torch.xpu.empty_cache()
|
||||
# Only empty cache for first token
|
||||
if hidden_states.shape[1] > 1:
|
||||
torch.xpu.empty_cache()
|
||||
logits = self.lm_head(hidden_states)
|
||||
torch.xpu.empty_cache()
|
||||
# Only empty cache for first token
|
||||
if hidden_states.shape[1] > 1:
|
||||
torch.xpu.empty_cache()
|
||||
# logits = logits.float()
|
||||
|
||||
# ipex-llm change ends
|
||||
|
|
|
|||
Loading…
Reference in a new issue