Empty cache after phi first attention to support 4k input (#10972)

* empty cache

* fix style
This commit is contained in:
Kai Huang 2024-05-09 19:50:04 +08:00 committed by GitHub
parent e753125880
commit a6342cc068
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -188,6 +188,13 @@ def is_linear_module(module):
return result, (in_features, out_features, mp_group)
def empty_cache_post(module, input, output):
try:
torch.xpu.empty_cache()
except: # cpu
pass
def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
@ -1517,6 +1524,8 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
model_forward = model_forward_wrapper(module.Phi3Model.forward)
convert_forward(model, module.Phi3Model, model_forward)
# Empty cache after the first attention to run long context.
model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
elif model.config.model_type == 'yuan':
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)