support chatglm4 in lookup (#11855)

Yina Chen 2024-08-21 10:53:17 +03:00 committed by GitHub
parent 0236de3ac2
commit cc27321441

@@ -493,12 +493,19 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
             for k, v in past_key_values
         ]
     elif self.config.model_type == "chatglm":
-        # for chatglm, cache shape is [sl, bs, nh, hn]
-        past_key_values = [
-            (k[:-(new_cache_size), :, :, :],
-             v[:-(new_cache_size), :, :, :])
-            for k, v in past_key_values
-        ]
+        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+            past_key_values = [
+                (k[:, :, :-(new_cache_size), :],
+                 v[:, :, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        else:
+            # for chatglm, cache shape is [sl, bs, nh, hn]
+            past_key_values = [
+                (k[:-(new_cache_size), :, :, :],
+                 v[:-(new_cache_size), :, :, :])
+                for k, v in past_key_values
+            ]
     elif self.config.model_type in ["baichuan", "gptj"]:
         past_key_values = [
             (k[:, :, :-(new_cache_size), :],