From a24666b8f3931045515b9c7ca75c1d544a7a9b88 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 13 Jun 2024 16:01:34 +0800 Subject: [PATCH] fix chatglm3-6b-32k (#11303) --- python/llm/src/ipex_llm/transformers/models/chatglm2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 5abff616..747b7ddd 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -100,9 +100,10 @@ def chatglm2_model_forward( if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype: rot_dim = self.rotary_pos_emb.dim - inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype) / rot_dim)) + base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1) + inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype) / rot_dim)) self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False) self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype