support latest phi3 (#11721)

This commit is contained in:
Yishuo Wang 2024-08-06 15:52:55 +08:00 committed by GitHub
parent 11650b6f81
commit 929675aa6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -36,9 +36,7 @@ import torch
import warnings
from torch import nn
from ipex_llm.transformers.models.utils import (
rotate_half, should_use_fuse_rope,
)
from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@ -63,7 +61,7 @@ def pre_compute_inv_freq(module: torch.nn.Module):
module.base **
(torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
)
elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding":
else:
inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)