From 929675aa6bfb5aadeee48afa2bbc2a4e77c5ac88 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 6 Aug 2024 15:52:55 +0800 Subject: [PATCH] support latest phi3 (#11721) --- python/llm/src/ipex_llm/transformers/models/phi3.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 9247ea94..e2ab2539 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -36,9 +36,7 @@ import torch import warnings from torch import nn -from ipex_llm.transformers.models.utils import ( - rotate_half, should_use_fuse_rope, -) +from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache @@ -63,7 +61,7 @@ def pre_compute_inv_freq(module: torch.nn.Module): module.base ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim) ) - elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding": + else: inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32) module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)