LLM: Fix missing return_last_logit when running chatglm3 with bigdl_ipex (#10678)
* fix missing return_last_logit argument * apply the change only for chatglm models
This commit is contained in:
		
							parent
							
								
									33f90beda0
								
							
						
					
					
						commit
						47cabe8fcc
					
				
					 2 changed files with 17 additions and 8 deletions
				
			
		| 
						 | 
					@ -163,7 +163,7 @@ def _ipex_jit(model):
 | 
				
			||||||
        get_dummy_input(model, return_dict=True)
 | 
					        get_dummy_input(model, return_dict=True)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if "return_last_logit" in sample_inputs:
 | 
					    if "return_last_logit" in sample_inputs:
 | 
				
			||||||
        del sample_inputs["return_last_logit"]
 | 
					        sample_inputs["return_last_logit"] = torch.tensor(False)
 | 
				
			||||||
    with torch.no_grad(), torch.cpu.amp.autocast(
 | 
					    with torch.no_grad(), torch.cpu.amp.autocast(
 | 
				
			||||||
        enabled=True
 | 
					        enabled=True
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -691,12 +691,21 @@ def speculative_generate(self,
 | 
				
			||||||
                        past_key_value_len = draft_past_key_values[0][0].shape[2]
 | 
					                        past_key_value_len = draft_past_key_values[0][0].shape[2]
 | 
				
			||||||
                        position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
 | 
					                        position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
 | 
				
			||||||
                        position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
 | 
					                        position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
 | 
				
			||||||
                        draft_output = draft_model.trace_graph(
 | 
					                        if self.config.model_type == "chatglm":
 | 
				
			||||||
                            input_ids=draft_current_input_ids,
 | 
					                            draft_output = draft_model.trace_graph(
 | 
				
			||||||
                            attention_mask=draft_attention_mask,
 | 
					                                input_ids=draft_current_input_ids,
 | 
				
			||||||
                            position_ids=position_ids,
 | 
					                                attention_mask=draft_attention_mask,
 | 
				
			||||||
                            past_key_values=draft_past_key_values,
 | 
					                                position_ids=position_ids,
 | 
				
			||||||
                        )
 | 
					                                return_last_logit=torch.tensor(False),
 | 
				
			||||||
 | 
					                                past_key_values=draft_past_key_values,
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        else:
 | 
				
			||||||
 | 
					                            draft_output = draft_model.trace_graph(
 | 
				
			||||||
 | 
					                                input_ids=draft_current_input_ids,
 | 
				
			||||||
 | 
					                                attention_mask=draft_attention_mask,
 | 
				
			||||||
 | 
					                                position_ids=position_ids,
 | 
				
			||||||
 | 
					                                past_key_values=draft_past_key_values,
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
                    elif self.config.model_type == "baichuan":
 | 
					                    elif self.config.model_type == "baichuan":
 | 
				
			||||||
                        if self.config.hidden_size == 4096:
 | 
					                        if self.config.hidden_size == 4096:
 | 
				
			||||||
                            past_key_value_len = draft_past_key_values[0][0].shape[2]
 | 
					                            past_key_value_len = draft_past_key_values[0][0].shape[2]
 | 
				
			||||||
| 
						 | 
					@ -820,7 +829,7 @@ def speculative_generate(self,
 | 
				
			||||||
                    output = self.trace_graph(input_ids=drafted_input_ids,
 | 
					                    output = self.trace_graph(input_ids=drafted_input_ids,
 | 
				
			||||||
                                              attention_mask=cur_attention_mask,
 | 
					                                              attention_mask=cur_attention_mask,
 | 
				
			||||||
                                              position_ids=position_ids,
 | 
					                                              position_ids=position_ids,
 | 
				
			||||||
                                              # return_last_logit=torch.tensor(False),
 | 
					                                              return_last_logit=torch.tensor(False),
 | 
				
			||||||
                                              past_key_values=past_key_values,)
 | 
					                                              past_key_values=past_key_values,)
 | 
				
			||||||
                elif "qwen" in self.config.model_type:
 | 
					                elif "qwen" in self.config.model_type:
 | 
				
			||||||
                    output = self.trace_graph(input_ids=drafted_input_ids,
 | 
					                    output = self.trace_graph(input_ids=drafted_input_ids,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue