LLM: Fix no return_last_logit running bigdl_ipex chatglm3 (#10678)

* fix no return_last_logit
* update only for chatglm
parent 33f90beda0
commit 47cabe8fcc

2 changed files with 17 additions and 8 deletions
@@ -163,7 +163,7 @@ def _ipex_jit(model):
         get_dummy_input(model, return_dict=True)
     )
     if "return_last_logit" in sample_inputs:
-        del sample_inputs["return_last_logit"]
+        sample_inputs["return_last_logit"] = torch.tensor(False)
     with torch.no_grad(), torch.cpu.amp.autocast(
         enabled=True
     ):
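This hunk changes how _ipex_jit prepares sample inputs before tracing: instead of deleting return_last_logit, it pins it to torch.tensor(False) so the flag stays in the traced signature. A minimal sketch of why that matters, using a toy module (not the actual ipex-llm tracing path) and assuming PyTorch >= 2.0 for example_kwarg_inputs: torch.jit.trace freezes the graph to the example inputs it sees, so a kwarg absent at trace time cannot be supplied at call time.

import torch

class Toy(torch.nn.Module):
    # Hypothetical stand-in for chatglm3's forward; the flag is a tensor so it
    # can become a graph input rather than a baked-in Python constant.
    def forward(self, x, return_last_logit=torch.tensor(False)):
        return x + return_last_logit.to(x.dtype)

model, x = Toy(), torch.randn(1, 4)

# Old behavior: the kwarg was deleted from the sample inputs before tracing,
# so the traced graph only accepts `x`.
traced_without = torch.jit.trace(model, example_kwarg_inputs={"x": x})
try:
    traced_without(x, return_last_logit=torch.tensor(True))
except (RuntimeError, TypeError):
    print("traced graph rejects return_last_logit")

# Fixed behavior: keep the kwarg in the sample inputs, so later callers
# (e.g. speculative_generate below) can still pass it.
traced_with = torch.jit.trace(
    model, example_kwarg_inputs={"x": x, "return_last_logit": torch.tensor(False)}
)
out = traced_with(x, return_last_logit=torch.tensor(True))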
@@ -691,12 +691,21 @@ def speculative_generate(self,
                         past_key_value_len = draft_past_key_values[0][0].shape[2]
                         position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                         position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
-                        draft_output = draft_model.trace_graph(
-                            input_ids=draft_current_input_ids,
-                            attention_mask=draft_attention_mask,
-                            position_ids=position_ids,
-                            past_key_values=draft_past_key_values,
-                        )
+                        if self.config.model_type == "chatglm":
+                            draft_output = draft_model.trace_graph(
+                                input_ids=draft_current_input_ids,
+                                attention_mask=draft_attention_mask,
+                                position_ids=position_ids,
+                                return_last_logit=torch.tensor(False),
+                                past_key_values=draft_past_key_values,
+                            )
+                        else:
+                            draft_output = draft_model.trace_graph(
+                                input_ids=draft_current_input_ids,
+                                attention_mask=draft_attention_mask,
+                                position_ids=position_ids,
+                                past_key_values=draft_past_key_values,
+                            )
                     elif self.config.model_type == "baichuan":
                         if self.config.hidden_size == 4096:
                             past_key_value_len = draft_past_key_values[0][0].shape[2]
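Review note: the chatglm and non-chatglm branches above duplicate the whole trace_graph call and differ only in one kwarg. A hedged alternative shape (hypothetical refactor, not part of this commit) would build the kwargs once and keep a single call site:

# Hypothetical refactor sketch, not part of this commit.
kwargs = dict(
    input_ids=draft_current_input_ids,
    attention_mask=draft_attention_mask,
    position_ids=position_ids,
    past_key_values=draft_past_key_values,
)
if self.config.model_type == "chatglm":
    # The traced chatglm graph now takes return_last_logit as a real input
    # (see the _ipex_jit hunk above), so it must be supplied on every call.
    kwargs["return_last_logit"] = torch.tensor(False)
draft_output = draft_model.trace_graph(**kwargs)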
@@ -820,7 +829,7 @@ def speculative_generate(self,
                     output = self.trace_graph(input_ids=drafted_input_ids,
                                               attention_mask=cur_attention_mask,
                                               position_ids=position_ids,
-                                              # return_last_logit=torch.tensor(False),
+                                              return_last_logit=torch.tensor(False),
                                               past_key_values=past_key_values,)
                 elif "qwen" in self.config.model_type:
                     output = self.trace_graph(input_ids=drafted_input_ids,
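The last hunk enables the previously commented-out return_last_logit=torch.tensor(False) on the target model's verification pass. With the flag False, chatglm3 returns logits for every drafted position rather than only the last one, which verification needs in order to compare the draft tokens against the target model's own choices. A simplified greedy sketch of that comparison (hypothetical helper, glossing over the one-position shift between inputs and predictions):

import torch

def verify_drafted_tokens(target_logits, drafted_tokens):
    # target_logits: [1, n_drafted, vocab]; drafted_tokens: [1, n_drafted].
    # Assumes target_logits[:, i] scores the token drafted at slot i.
    target_choice = target_logits.argmax(dim=-1)
    # Accept drafted tokens up to (not past) the first disagreement.
    agree = (target_choice == drafted_tokens).long().cumprod(dim=-1)
    n_accepted = int(agree.sum())
    return drafted_tokens[:, :n_accepted]

# With return_last_logit=True the target would return a [1, 1, vocab] tensor,
# and the per-position comparison above would be impossible.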