LLM: Fix missing `return_last_logit` when running chatglm3 with bigdl_ipex (#10678)
* Pass `return_last_logit` through instead of dropping it * Apply the change only for chatglm models
This commit is contained in:
parent
33f90beda0
commit
47cabe8fcc
2 changed files with 17 additions and 8 deletions
|
|
@ -163,7 +163,7 @@ def _ipex_jit(model):
|
||||||
get_dummy_input(model, return_dict=True)
|
get_dummy_input(model, return_dict=True)
|
||||||
)
|
)
|
||||||
if "return_last_logit" in sample_inputs:
|
if "return_last_logit" in sample_inputs:
|
||||||
del sample_inputs["return_last_logit"]
|
sample_inputs["return_last_logit"] = torch.tensor(False)
|
||||||
with torch.no_grad(), torch.cpu.amp.autocast(
|
with torch.no_grad(), torch.cpu.amp.autocast(
|
||||||
enabled=True
|
enabled=True
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -691,6 +691,15 @@ def speculative_generate(self,
|
||||||
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
||||||
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||||
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
|
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
|
||||||
|
if self.config.model_type == "chatglm":
|
||||||
|
draft_output = draft_model.trace_graph(
|
||||||
|
input_ids=draft_current_input_ids,
|
||||||
|
attention_mask=draft_attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
return_last_logit=torch.tensor(False),
|
||||||
|
past_key_values=draft_past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
draft_output = draft_model.trace_graph(
|
draft_output = draft_model.trace_graph(
|
||||||
input_ids=draft_current_input_ids,
|
input_ids=draft_current_input_ids,
|
||||||
attention_mask=draft_attention_mask,
|
attention_mask=draft_attention_mask,
|
||||||
|
|
@ -820,7 +829,7 @@ def speculative_generate(self,
|
||||||
output = self.trace_graph(input_ids=drafted_input_ids,
|
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||||
attention_mask=cur_attention_mask,
|
attention_mask=cur_attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
# return_last_logit=torch.tensor(False),
|
return_last_logit=torch.tensor(False),
|
||||||
past_key_values=past_key_values,)
|
past_key_values=past_key_values,)
|
||||||
elif "qwen" in self.config.model_type:
|
elif "qwen" in self.config.model_type:
|
||||||
output = self.trace_graph(input_ids=drafted_input_ids,
|
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue