support latest phi3 (#11721)
This commit is contained in:
		
							parent
							
								
									11650b6f81
								
							
						
					
					
						commit
						929675aa6b
					
				
					 1 changed files with 2 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -36,9 +36,7 @@ import torch
 | 
			
		|||
import warnings
 | 
			
		||||
from torch import nn
 | 
			
		||||
 | 
			
		||||
from ipex_llm.transformers.models.utils import (
 | 
			
		||||
    rotate_half, should_use_fuse_rope,
 | 
			
		||||
)
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
| 
						 | 
				
			
			@ -63,7 +61,7 @@ def pre_compute_inv_freq(module: torch.nn.Module):
 | 
			
		|||
            module.base **
 | 
			
		||||
            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
 | 
			
		||||
        )
 | 
			
		||||
    elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding":
 | 
			
		||||
    else:
 | 
			
		||||
        inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
 | 
			
		||||
        short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
 | 
			
		||||
        module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue