fix phi3 (#11729)
This commit is contained in:
		
							parent
							
								
									e32d13d78c
								
							
						
					
					
						commit
						c093f7d980
					
				
					 1 changed files with 2 additions and 1 deletions
				
			
		| 
						 | 
					@ -61,7 +61,8 @@ def pre_compute_inv_freq(module: torch.nn.Module):
 | 
				
			||||||
            module.base **
 | 
					            module.base **
 | 
				
			||||||
            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
 | 
					            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    else:
 | 
					    elif module.__class__.__name__ in ["Phi3SuScaledRotaryEmbedding",
 | 
				
			||||||
 | 
					                                       "Phi3LongRoPEScaledRotaryEmbedding"]:
 | 
				
			||||||
        inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
 | 
					        inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
 | 
				
			||||||
        short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
 | 
					        short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
 | 
				
			||||||
        module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)
 | 
					        module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue