diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index e2ab2539..6378b6fe 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -61,7 +61,8 @@ def pre_compute_inv_freq(module: torch.nn.Module): module.base ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim) ) - else: + elif module.__class__.__name__ in ["Phi3SuScaledRotaryEmbedding", + "Phi3LongRoPEScaledRotaryEmbedding"]: inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32) module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)