LLM: Fix no return_last_logit running bigdl_ipex chatglm3 (#10678)
* fix no return_last_logits * update only for chatglm
This commit is contained in:
parent
33f90beda0
commit
47cabe8fcc
2 changed files with 17 additions and 8 deletions
|
|
@ -163,7 +163,7 @@ def _ipex_jit(model):
|
|||
get_dummy_input(model, return_dict=True)
|
||||
)
|
||||
if "return_last_logit" in sample_inputs:
|
||||
del sample_inputs["return_last_logit"]
|
||||
sample_inputs["return_last_logit"] = torch.tensor(False)
|
||||
with torch.no_grad(), torch.cpu.amp.autocast(
|
||||
enabled=True
|
||||
):
|
||||
|
|
|
|||
|
|
@ -691,12 +691,21 @@ def speculative_generate(self,
|
|||
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
||||
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
|
||||
draft_output = draft_model.trace_graph(
|
||||
input_ids=draft_current_input_ids,
|
||||
attention_mask=draft_attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=draft_past_key_values,
|
||||
)
|
||||
if self.config.model_type == "chatglm":
|
||||
draft_output = draft_model.trace_graph(
|
||||
input_ids=draft_current_input_ids,
|
||||
attention_mask=draft_attention_mask,
|
||||
position_ids=position_ids,
|
||||
return_last_logit=torch.tensor(False),
|
||||
past_key_values=draft_past_key_values,
|
||||
)
|
||||
else:
|
||||
draft_output = draft_model.trace_graph(
|
||||
input_ids=draft_current_input_ids,
|
||||
attention_mask=draft_attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=draft_past_key_values,
|
||||
)
|
||||
elif self.config.model_type == "baichuan":
|
||||
if self.config.hidden_size == 4096:
|
||||
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
||||
|
|
@ -820,7 +829,7 @@ def speculative_generate(self,
|
|||
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||
attention_mask=cur_attention_mask,
|
||||
position_ids=position_ids,
|
||||
# return_last_logit=torch.tensor(False),
|
||||
return_last_logit=torch.tensor(False),
|
||||
past_key_values=past_key_values,)
|
||||
elif "qwen" in self.config.model_type:
|
||||
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||
|
|
|
|||
Loading…
Reference in a new issue