fix chatglm3-6b-32k (#11303)

This commit is contained in:
Yishuo Wang 2024-06-13 16:01:34 +08:00 committed by GitHub
parent 9760ffc256
commit a24666b8f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -100,7 +100,8 @@ def chatglm2_model_forward(
 if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
     rot_dim = self.rotary_pos_emb.dim
-    inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2,
+    base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
+    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
                                              device=inputs_embeds.device,
                                              dtype=inputs_embeds.dtype) / rot_dim))
     self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)