lm_head empty_cache for more models (#10490)

* modify constraint

* fix style
Kai Huang 2024-03-21 17:11:43 +08:00 committed by GitHub
parent 1579ee4421
commit 30f111cd32

@@ -533,19 +533,22 @@ class LowBitLinear(nn.Linear):
         self.enable_xetla = enable_xetla
         self.optimize_lm_head = optimize_lm_head
         self.device = None # detected only once in the first forward
-        # empty cache before and after lm_head at first token (by default on arc) for models
-        # with large vocabulary (e.g. baichuan/qwen) when given long input at inference time.
+        # empty cache before and after lm_head at first token (by default on arc)
+        # especially for baichuan/qwen when given long input at inference time.
         # The condition makes sure that empty cache only takes effect if this layer is lm_head.
-        # TODO: may modify the value constraints for other models.
-        self.low_memory_mode = self.in_len * self.out_len >= 70000*4096
+        # For other models like llama, lm_cache will be applied as well
+        # since performance isn't impacted.
+        self.is_lm_head = self.in_len * self.out_len >= 30000 * 4096
+        self.low_memory_mode = self.is_lm_head

     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
         # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
+            # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
             self.low_memory_mode = \
-                self.low_memory_mode and\
+                self.low_memory_mode and \
                 (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
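Note on the relaxed constraint: the old threshold of 70000 * 4096 (= 286,720,000 elements) only flags very large vocabularies such as Qwen's, so a Llama-style lm_head (32000 vocab x 4096 hidden) was never treated as lm_head. The new 30000 * 4096 threshold (= 122,880,000) catches it, which is what 'for more models' in the title refers to. A quick sanity check of the arithmetic (the vocab and hidden sizes below are assumptions taken from the models' public configs, not from this diff; the thresholds are the ones in the diff):

    # Illustrative threshold check; model dimensions are assumed from
    # public configs, thresholds come from this commit.
    llama2_7b_head = 32000 * 4096   # = 131,072,000 elements
    qwen_7b_head = 151936 * 4096    # = 622,329,856 elements

    old_threshold = 70000 * 4096    # = 286,720,000; misses llama2_7b_head
    new_threshold = 30000 * 4096    # = 122,880,000; catches it

    assert new_threshold <= llama2_7b_head < old_threshold
    assert qwen_7b_head >= old_threshold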
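The diff only shows where low_memory_mode is decided; per the comments above, the flag is consumed later in forward, around the lm_head matmul at the first token when the input exceeds 1024 tokens. A minimal sketch of that pattern, assuming the IPEX XPU backend for torch.xpu.empty_cache() and a plain float matmul standing in for the real low-bit kernel (illustration only, not the actual ipex-llm forward):

    import torch

    def lm_head_forward(self, x: torch.Tensor) -> torch.Tensor:
        # At the first (prefill) token with a long prompt, projecting hidden
        # states onto a large vocabulary briefly spikes memory usage.
        long_first_token = x.dim() == 3 and x.shape[1] > 1024
        if self.low_memory_mode and long_first_token:
            torch.xpu.empty_cache()  # free cached blocks before the big matmul
            result = torch.matmul(x, self.weight.t())
            torch.xpu.empty_cache()  # release the temporary spike afterwards
            return result
        return torch.matmul(x, self.weight.t())

Clearing the allocator cache twice trades a little first-token speed for memory headroom, which is why it is gated to the lm_head layer and to Arc (or an explicit BIGDL_LOW_MEMORY_MODE=1).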