From cc27321441cafde7a8fca5f7c45aa5c8c2cfce66 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Wed, 21 Aug 2024 10:53:17 +0300
Subject: [PATCH] support chatglm4 in lookup (#11855)

---
 .../src/ipex_llm/transformers/speculative.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 8667da8d..4600e99f 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -493,12 +493,19 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
             for k, v in past_key_values
         ]
     elif self.config.model_type == "chatglm":
-        # for chatglm, cache shape is [sl, bs, nh, hn]
-        past_key_values = [
-            (k[:-(new_cache_size), :, :, :],
-             v[:-(new_cache_size), :, :, :])
-            for k, v in past_key_values
-        ]
+        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+            past_key_values = [
+                (k[:, :, :-(new_cache_size), :],
+                 v[:, :, :-(new_cache_size), :])
+                for k, v in past_key_values
+            ]
+        else:
+            # for chatglm, cache shape is [sl, bs, nh, hn]
+            past_key_values = [
+                (k[:-(new_cache_size), :, :, :],
+                 v[:-(new_cache_size), :, :, :])
+                for k, v in past_key_values
+            ]
     elif self.config.model_type in ["baichuan", "gptj"]:
         past_key_values = [
             (k[:, :, :-(new_cache_size), :],
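
Note (not part of the patch): a minimal, self-contained sketch of what the new branch does. GLM-4 style checkpoints, detected here by num_layers == 40 plus a rope_ratio config attribute as in the diff, appear to keep their KV cache as [batch, num_heads, seq_len, head_dim], so the speculative-decoding rollback crops along dimension 2; older ChatGLM caches are [seq_len, batch, num_heads, head_dim] per the retained comment and are cropped along dimension 0. The helper name and dummy tensors below are illustrative assumptions, not ipex-llm API.

# Illustrative sketch only; crop_chatglm_cache is a hypothetical helper, not part of ipex-llm.
import torch

def crop_chatglm_cache(past_key_values, new_cache_size, is_glm4):
    """Drop the last `new_cache_size` cached positions from every layer's (key, value) pair."""
    if is_glm4:
        # GLM-4 layout assumed from the slicing in the patch: [batch, num_heads, seq_len, head_dim]
        return [(k[:, :, :-new_cache_size, :], v[:, :, :-new_cache_size, :])
                for k, v in past_key_values]
    # Older ChatGLM layout (from the patch comment): [seq_len, batch, num_heads, head_dim]
    return [(k[:-new_cache_size, :, :, :], v[:-new_cache_size, :, :, :])
            for k, v in past_key_values]

if __name__ == "__main__":
    bs, nh, sl, hn = 1, 2, 16, 8
    glm4_cache = [(torch.zeros(bs, nh, sl, hn), torch.zeros(bs, nh, sl, hn))]
    old_cache = [(torch.zeros(sl, bs, nh, hn), torch.zeros(sl, bs, nh, hn))]
    print(crop_chatglm_cache(glm4_cache, 3, True)[0][0].shape)    # torch.Size([1, 2, 13, 8])
    print(crop_chatglm_cache(old_cache, 3, False)[0][0].shape)    # torch.Size([13, 1, 2, 8])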