support chatglm4 in lookup (#11855)
This commit is contained in:
parent
0236de3ac2
commit
cc27321441
1 changed file with 13 additions and 6 deletions
|
|
@@ -493,12 +493,19 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
elif self.config.model_type == "chatglm":
|
elif self.config.model_type == "chatglm":
|
||||||
# for chatglm, cache shape is [sl, bs, nh, hn]
|
if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:-(new_cache_size), :, :, :],
|
(k[:, :, :-(new_cache_size), :],
|
||||||
v[:-(new_cache_size), :, :, :])
|
v[:, :, :-(new_cache_size), :])
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
# for chatglm, cache shape is [sl, bs, nh, hn]
|
||||||
|
past_key_values = [
|
||||||
|
(k[:-(new_cache_size), :, :, :],
|
||||||
|
v[:-(new_cache_size), :, :, :])
|
||||||
|
for k, v in past_key_values
|
||||||
|
]
|
||||||
elif self.config.model_type in ["baichuan", "gptj"]:
|
elif self.config.model_type in ["baichuan", "gptj"]:
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:, :, :-(new_cache_size), :],
|
(k[:, :, :-(new_cache_size), :],
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue