From 76e30d8ec8cda18a6d464321e40ccb1c8c3c91d1 Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Wed, 13 Mar 2024 20:31:53 +0800
Subject: [PATCH] Empty cache for lm_head (#10317)

* empty cache

* add comments
---
 .../bigdl/llm/transformers/low_bit_linear.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index c73eec79..0d39290e 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -532,8 +532,21 @@ class LowBitLinear(nn.Linear):
         self.compute_dtype = None  # only for training
         self.enable_xetla = enable_xetla
         self.optimize_lm_head = optimize_lm_head
+        self.device = None  # detected only once, in the first forward
+        # Empty the cache before and after lm_head at first token (by default on arc) for
+        # models with a large vocabulary (e.g. baichuan/qwen) given long input at inference.
+        # This condition ensures emptying the cache only takes effect for the lm_head layer.
+        # TODO: may modify the value constraints for other models.
+        self.low_memory_mode = self.in_len * self.out_len >= 70000 * 4096

     def forward(self, x: torch.Tensor):
+        # Empty the cache before and after lm_head at first token when input > 1024,
+        # on arc or when BIGDL_LOW_MEMORY_MODE is set to 1, at inference time.
+        if self.device is None:
+            self.device = get_xpu_device_type(self.weight.data)
+            self.low_memory_mode = \
+                self.low_memory_mode and \
+                (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
@@ -599,6 +612,10 @@ class LowBitLinear(nn.Linear):
             # current workaround to reduce first token latency of fp32 input
             # sometimes fp16 cause nan and training instability
             # disable the conversion when training
+            # TODO: may modify the input length condition for empty cache.
+            do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
+            if do_empty_cache:
+                torch.xpu.empty_cache()
             if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
                     not use_xmx(x_2d, self.weight.qtype):
                 x_2d = x_2d.half()
@@ -608,6 +625,8 @@ class LowBitLinear(nn.Linear):
             else:
                 result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                                  input_seq_size)
+            if do_empty_cache:
+                torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
                 from deepspeed import comm as dist
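
For reference, below is a minimal standalone sketch of the empty-cache pattern this
patch adds, written against plain PyTorch rather than BigDL. The LowMemoryLMHead
class, its use of torch.nn.Linear in place of LowBitLinear, and the reuse of the
patch's 70000 * 4096 and 1024 thresholds are illustrative assumptions, not part of
the patch; torch.xpu.empty_cache() requires an XPU-enabled PyTorch build (or IPEX).

import os

import torch
import torch.nn as nn


class LowMemoryLMHead(nn.Linear):
    """Hypothetical lm_head that frees cached allocator blocks around its matmul."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__(in_features, out_features, bias=bias)
        # Heuristic borrowed from the patch: only vocabulary-sized projections
        # (e.g. ~70000 x 4096) are treated as lm_head candidates.
        self.low_memory_mode = in_features * out_features >= 70000 * 4096

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_2d = x.reshape(-1, x.shape[-1])
        # First-token prefill with a long prompt is where peak memory spikes,
        # so only empty the cache for large 2D inputs at inference on XPU.
        do_empty_cache = (
            self.low_memory_mode
            and x_2d.shape[0] >= 1024
            and x.device.type == "xpu"
            and not self.training
            and (os.environ.get("BIGDL_LOW_MEMORY_MODE", "1") == "1")
        )
        if do_empty_cache:
            torch.xpu.empty_cache()  # release cached blocks before the big matmul
        result = super().forward(x_2d)
        if do_empty_cache:
            torch.xpu.empty_cache()  # and after, so later allocations see free memory
        return result.view(*x.shape[:-1], self.out_features)

On non-XPU devices this head behaves exactly like nn.Linear. The empty-cache calls
trade a small synchronization cost for lower peak memory, which is why the patch
gates them to the first (long-input) token rather than running them on every step.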