fix chatglm3-6b-32k (#11303)
This commit is contained in:
parent
9760ffc256
commit
a24666b8f3
1 changed files with 4 additions and 3 deletions
|
|
@@ -100,7 +100,8 @@ def chatglm2_model_forward(
|
|||
|
||||
if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
|
||||
rot_dim = self.rotary_pos_emb.dim
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2,
|
||||
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype) / rot_dim))
|
||||
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue