From a6342cc068a331ce84ed793686ce74c362d74acc Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Thu, 9 May 2024 19:50:04 +0800
Subject: [PATCH] Empty cache after phi first attention to support 4k input
 (#10972)

* empty cache

* fix style
---
 python/llm/src/ipex_llm/transformers/convert.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index fc0766ae..e6e0e009 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -188,6 +188,13 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
 
 
+def empty_cache_post(module, input, output):
+    try:
+        torch.xpu.empty_cache()
+    except:  # cpu
+        pass
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
@@ -1517,6 +1524,8 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi3 import model_forward_wrapper
         model_forward = model_forward_wrapper(module.Phi3Model.forward)
         convert_forward(model, module.Phi3Model, model_forward)
+        # Empty cache after the first attention to run long context.
+        model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
     elif model.config.model_type == 'yuan':
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
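
Note (illustration only, not part of the patch): a minimal sketch of the hook mechanism
this change relies on. torch.nn.Module.register_forward_hook() invokes the callback after
every forward pass of the hooked module, so the allocator cache is released right after
the first attention layer finishes. The toy nn.Linear below is a stand-in for
model.model.layers[0].self_attn; on builds without torch.xpu the call simply falls
through, mirroring the patch's CPU fallback.

    import torch
    import torch.nn as nn

    def empty_cache_post(module, input, output):
        # Release cached device memory once this module's forward pass is done.
        try:
            torch.xpu.empty_cache()
        except Exception:  # no XPU runtime available, e.g. a CPU-only build
            pass

    toy_attn = nn.Linear(16, 16)  # stand-in for the first self_attn block
    toy_attn.register_forward_hook(empty_cache_post)

    out = toy_attn(torch.randn(2, 16))  # hook fires immediately after this call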