From 47cabe8fcc42639e570d1c4e97d188ba2793a77b Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Sun, 7 Apr 2024 15:27:58 +0800 Subject: [PATCH] LLM: Fix no return_last_logit running bigdl_ipex chatglm3 (#10678) * fix no return_last_logits * update only for chatglm --- .../src/ipex_llm/transformers/convert_ipex.py | 2 +- .../src/ipex_llm/transformers/speculative.py | 23 +++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert_ipex.py b/python/llm/src/ipex_llm/transformers/convert_ipex.py index 4d6764bf..6368d13a 100644 --- a/python/llm/src/ipex_llm/transformers/convert_ipex.py +++ b/python/llm/src/ipex_llm/transformers/convert_ipex.py @@ -163,7 +163,7 @@ def _ipex_jit(model): get_dummy_input(model, return_dict=True) ) if "return_last_logit" in sample_inputs: - del sample_inputs["return_last_logit"] + sample_inputs["return_last_logit"] = torch.tensor(False) with torch.no_grad(), torch.cpu.amp.autocast( enabled=True ): diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 2585890d..cc9f627e 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -691,12 +691,21 @@ def speculative_generate(self, past_key_value_len = draft_past_key_values[0][0].shape[2] position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long() position_ids = position_ids[:, :-draft_current_input_ids.size(0)] - draft_output = draft_model.trace_graph( - input_ids=draft_current_input_ids, - attention_mask=draft_attention_mask, - position_ids=position_ids, - past_key_values=draft_past_key_values, - ) + if self.config.model_type == "chatglm": + draft_output = draft_model.trace_graph( + input_ids=draft_current_input_ids, + attention_mask=draft_attention_mask, + position_ids=position_ids, + return_last_logit=torch.tensor(False), + 
+                        past_key_values=draft_past_key_values,
                    )