diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index c167ed3c..60501e0b 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -468,7 +468,7 @@ def speculative_generate(self, position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long() forward_args["position_ids"] = position_ids elif self.config.model_type == "gptj": - past_length = draft_past_key_values[0][0].size(1) + past_length = draft_past_key_values[0][0].size(2) position_ids = torch.Tensor([[past_length]]).long().to(self.device) forward_args["position_ids"] = position_ids draft_output = draft_model(**forward_args) @@ -563,7 +563,7 @@ def speculative_generate(self, position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len forward_args["position_ids"] = position_ids elif self.config.model_type == "gptj": - past_length = past_key_values[0][0].size(1) + past_length = past_key_values[0][0].size(2) input_len = drafted_input_ids.shape[1] position_ids = torch.arange(past_length, input_len + past_length, dtype=torch.long, device=drafted_input_ids.device) @@ -644,7 +644,7 @@ def speculative_generate(self, past_key_values = [[tmp, key_cache, value_cache, beam_idx] for _, key_cache, value_cache, beam_idx in past_key_values] else: - if self.config.model_type in ["qwen", "gptj"]: + if self.config.model_type in ["qwen"]: past_key_values = [ (k[:, :-(max_of_max_matched - max_matched), :], v[:, :-(max_of_max_matched - max_matched), :]) @@ -657,7 +657,7 @@ def speculative_generate(self, v[:-(max_of_max_matched - max_matched), :, :, :]) for k, v in past_key_values ] - elif self.config.model_type == "baichuan": + elif self.config.model_type in ["baichuan", "gptj"]: past_key_values = [ (k[:, :, :-(max_of_max_matched - max_matched), :], v[:, :, :-(max_of_max_matched - max_matched), :])