From 30f111cd32e30a0e1dd263edfbacdf1afa940a24 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Thu, 21 Mar 2024 17:11:43 +0800 Subject: [PATCH] lm_head empty_cache for more models (#10490) * modify constraint * fix style --- .../src/bigdl/llm/transformers/low_bit_linear.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 0d39290e..702e7da5 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -533,19 +533,22 @@ class LowBitLinear(nn.Linear): self.enable_xetla = enable_xetla self.optimize_lm_head = optimize_lm_head self.device = None # detected only once in the first forward - # empty cache before and after lm_head at first token (by default on arc) for models - # with large vocabulary (e.g. baichuan/qwen) when given long input at inference time. + # empty cache before and after lm_head at first token (by default on arc) + # especially for baichuan/qwen when given long input at inference time. # The condition makes sure that empty cache only takes effect if this layer is lm_head. - # TODO: may modify the value constraints for other models. - self.low_memory_mode = self.in_len * self.out_len >= 70000*4096 + # For other models like llama, lm_cache will be applied as well + # since performance isn't impacted. + self.is_lm_head = self.in_len * self.out_len >= 30000 * 4096 + self.low_memory_mode = self.is_lm_head def forward(self, x: torch.Tensor): # empty cache before and after lm_head at first token when input > 1024 # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time. if self.device is None: self.device = get_xpu_device_type(self.weight.data) + # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary self.low_memory_mode = \ - self.low_memory_mode and\ + self.low_memory_mode and \ (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1") # Due to inconsistent training status in some models like Baichuan-7b-Chat, # we should check both self.training and torch.is_inference_mode_enabled().