From c093f7d9807f26403c5385b85c5ae9d712e5a670 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 7 Aug 2024 09:39:46 +0800 Subject: [PATCH] fix phi3 (#11729) --- python/llm/src/ipex_llm/transformers/models/phi3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index e2ab2539..6378b6fe 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -61,7 +61,8 @@ def pre_compute_inv_freq(module: torch.nn.Module): module.base ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim) ) - else: + elif module.__class__.__name__ in ["Phi3SuScaledRotaryEmbedding", + "Phi3LongRoPEScaledRotaryEmbedding"]: inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32) module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)