Empty cache only for the first token, not for subsequent tokens, to speed up decoding (#11665)
This commit is contained in:
parent
fc7f8feb83
commit
ba01b85c13
1 changed files with 6 additions and 2 deletions
|
|
@ -959,9 +959,13 @@ def llama_causallm_forward_4_37_lowmem(
|
||||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa
|
||||||
logits = torch.cat(logits, dim=-1)
|
logits = torch.cat(logits, dim=-1)
|
||||||
else:
|
else:
|
||||||
torch.xpu.empty_cache()
|
# Only empty cache for first token
|
||||||
|
if hidden_states.shape[1] > 1:
|
||||||
|
torch.xpu.empty_cache()
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
torch.xpu.empty_cache()
|
# Only empty cache for first token
|
||||||
|
if hidden_states.shape[1] > 1:
|
||||||
|
torch.xpu.empty_cache()
|
||||||
# logits = logits.float()
|
# logits = logits.float()
|
||||||
|
|
||||||
# ipex-llm change ends
|
# ipex-llm change ends
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue