LLM: Fix no return_last_logit running bigdl_ipex chatglm3 (#10678)

* fix no return_last_logits

* update only for chatglm
This commit is contained in:
Wang, Jian4 2024-04-07 15:27:58 +08:00 committed by GitHub
parent 33f90beda0
commit 47cabe8fcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 8 deletions

View file

@ -163,7 +163,7 @@ def _ipex_jit(model):
get_dummy_input(model, return_dict=True)
)
if "return_last_logit" in sample_inputs:
del sample_inputs["return_last_logit"]
sample_inputs["return_last_logit"] = torch.tensor(False)
with torch.no_grad(), torch.cpu.amp.autocast(
enabled=True
):

View file

@ -691,12 +691,21 @@ def speculative_generate(self,
past_key_value_len = draft_past_key_values[0][0].shape[2]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
if self.config.model_type == "chatglm":
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
return_last_logit=torch.tensor(False),
past_key_values=draft_past_key_values,
)
else:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
elif self.config.model_type == "baichuan":
if self.config.hidden_size == 4096:
past_key_value_len = draft_past_key_values[0][0].shape[2]
@ -820,7 +829,7 @@ def speculative_generate(self,
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
position_ids=position_ids,
# return_last_logit=torch.tensor(False),
return_last_logit=torch.tensor(False),
past_key_values=past_key_values,)
elif "qwen" in self.config.model_type:
output = self.trace_graph(input_ids=drafted_input_ids,