Fix IPEX Baichuan Speculative (#10162)
* Fix IPEX Baichuan Speculative * compatible with 13B * Update speculative.py
This commit is contained in:
		
							parent
							
								
									6952847f68
								
							
						
					
					
						commit
						3e2af5ec0a
					
				
					 1 changed files with 24 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -441,11 +441,12 @@ def speculative_generate(self,
 | 
			
		|||
    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
 | 
			
		||||
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
 | 
			
		||||
    if _enable_ipex:
 | 
			
		||||
        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
 | 
			
		||||
                ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
 | 
			
		||||
                ("mistral" in self.config.model_type)):
 | 
			
		||||
        if not ((self.config.model_type == 'baichuan') or
 | 
			
		||||
                ('llama' in self.config.model_type) or
 | 
			
		||||
                ("mistral" in self.config.model_type) or
 | 
			
		||||
                ("chatglm" in self.config.model_type)):
 | 
			
		||||
            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
 | 
			
		||||
                                      Llama, Baichuan2-13b and Mistral models currently.")
 | 
			
		||||
                                      Llama, Baichuan2, Mistral and ChatGLM models currently.")
 | 
			
		||||
        if "chatglm" in self.config.model_type:
 | 
			
		||||
            global query_group_size
 | 
			
		||||
            query_group_size = draft_model.config.num_attention_heads // \
 | 
			
		||||
| 
						 | 
				
			
			@ -597,6 +598,21 @@ def speculative_generate(self,
 | 
			
		|||
                cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
 | 
			
		||||
            if _enable_ipex and hasattr(self, "trace_graph"):
 | 
			
		||||
                if self.config.model_type == "baichuan":
 | 
			
		||||
                    if self.config.hidden_size == 4096:
 | 
			
		||||
                        past_key_value_len = past_key_values[0][0].shape[2]
 | 
			
		||||
                        seq_len = drafted_input_ids.shape[1]
 | 
			
		||||
                        seq_len_with_past = seq_len + past_key_value_len
 | 
			
		||||
                        position_ids = torch.arange(past_key_value_len,
 | 
			
		||||
                                                    seq_len_with_past,
 | 
			
		||||
                                                    dtype=torch.long,
 | 
			
		||||
                                                    device=drafted_input_ids.device)
 | 
			
		||||
                        position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
 | 
			
		||||
                        output = self.trace_graph(input_ids=drafted_input_ids,
 | 
			
		||||
                                                  attention_mask=cur_attention_mask,
 | 
			
		||||
                                                  past_key_values=past_key_values,
 | 
			
		||||
                                                  position_ids=position_ids,
 | 
			
		||||
                                                  )
 | 
			
		||||
                    elif self.config.hidden_size == 5120:
 | 
			
		||||
                        output = self.trace_graph(input_ids=drafted_input_ids,
 | 
			
		||||
                                                  attention_mask=cur_attention_mask,
 | 
			
		||||
                                                  past_key_values=past_key_values,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue