diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 8667da8d..4600e99f 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -493,12 +493,19 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa for k, v in past_key_values ] elif self.config.model_type == "chatglm": - # for chatglm, cache shape is [sl, bs, nh, hn] - past_key_values = [ - (k[:-(new_cache_size), :, :, :], - v[:-(new_cache_size), :, :, :]) - for k, v in past_key_values - ] + if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'): + past_key_values = [ + (k[:, :, :-(new_cache_size), :], + v[:, :, :-(new_cache_size), :]) + for k, v in past_key_values + ] + else: + # for chatglm, cache shape is [sl, bs, nh, hn] + past_key_values = [ + (k[:-(new_cache_size), :, :, :], + v[:-(new_cache_size), :, :, :]) + for k, v in past_key_values + ] elif self.config.model_type in ["baichuan", "gptj"]: past_key_values = [ (k[:, :, :-(new_cache_size), :],