From e1bc18f8eba52b87c1bcf7be749a14bc7c324fe9 Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Wed, 1 Nov 2023 11:31:34 +0800
Subject: [PATCH] fix import ipex problem (#9323)

* fix import ipex problem

* fix style
---
 .../src/bigdl/llm/transformers/models/llama.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 94515ea0..798facec 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -59,21 +59,23 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+_ipex_version = None
+
+
 def get_ipex_version():
-    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
-        import intel_extension_for_pytorch as ipex
-        return ipex.__version__
-    else:
-        return None
+    global _ipex_version
+    if _ipex_version is not None:
+        return _ipex_version
 
-
-ipex_version = get_ipex_version()
+    import intel_extension_for_pytorch as ipex
+    _ipex_version = ipex.__version__
+    return _ipex_version
 
 
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if ipex_version == "2.0.110+xpu":
+        if get_ipex_version() == "2.0.110+xpu":
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)],
                                                              self.weight)
         else:
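
The patch swaps an eager, import-time probe of intel_extension_for_pytorch for a lazy, memoized lookup: previously the module-level `ipex_version = get_ipex_version()` ran whenever llama.py was imported, pulling in IPEX (or returning None when it was absent). After the patch, IPEX is imported only on the first call to get_ipex_version(), which happens inside the XPU-only branch of llama_rms_norm_forward, where IPEX can be assumed to be present. Below is a minimal standalone sketch of the resulting pattern, assuming intel_extension_for_pytorch is installed wherever the version is first requested; the call site at the bottom is illustrative only and is not part of the patch.

_ipex_version = None


def get_ipex_version():
    """Return the IPEX version string, importing IPEX on first use only."""
    global _ipex_version
    if _ipex_version is not None:
        return _ipex_version

    # Deferred import: runs once, on the first call, instead of when the
    # surrounding module is imported.
    import intel_extension_for_pytorch as ipex
    _ipex_version = ipex.__version__
    return _ipex_version


# Illustrative (hypothetical) call site, mirroring the version check in
# llama_rms_norm_forward; in the real code this only runs on XPU paths,
# where IPEX is guaranteed to be importable.
if get_ipex_version() == "2.0.110+xpu":
    pass  # take the fused torch_ipex.rms_norm fast path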