diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 5abff616..747b7ddd 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -100,9 +100,10 @@ def chatglm2_model_forward( if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype: rot_dim = self.rotary_pos_emb.dim - inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype) / rot_dim)) + base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1) + inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype) / rot_dim)) self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False) self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype