diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index 9dc028e9..aa9324f1 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -441,11 +441,12 @@ def speculative_generate(self,
     _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
     _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
     if _enable_ipex:
-        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
-                ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
-                ("mistral" in self.config.model_type)):
+        if not ((self.config.model_type == 'baichuan') or
+                ('llama' in self.config.model_type) or
+                ("mistral" in self.config.model_type) or
+                ("chatglm" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
-                              Llama, Baichuan2-13b and Mistral models currently.")
+                              Llama, Baichuan2, Mistral and ChatGLM models currently.")
         if "chatglm" in self.config.model_type:
             global query_group_size
             query_group_size = draft_model.config.num_attention_heads // \
@@ -597,10 +598,25 @@ def speculative_generate(self,
                 cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
             if _enable_ipex and hasattr(self, "trace_graph"):
                 if self.config.model_type == "baichuan":
-                    output = self.trace_graph(input_ids=drafted_input_ids,
-                                              attention_mask=cur_attention_mask,
-                                              past_key_values=past_key_values,
-                                              )
+                    if self.config.hidden_size == 4096:
+                        past_key_value_len = past_key_values[0][0].shape[2]
+                        seq_len = drafted_input_ids.shape[1]
+                        seq_len_with_past = seq_len + past_key_value_len
+                        position_ids = torch.arange(past_key_value_len,
+                                                    seq_len_with_past,
+                                                    dtype=torch.long,
+                                                    device=drafted_input_ids.device)
+                        position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  past_key_values=past_key_values,
+                                                  position_ids=position_ids,
+                                                  )
+                    elif self.config.hidden_size == 5120:
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  past_key_values=past_key_values,
+                                                  )
                 elif "llama" in self.config.model_type:
                     past_key_value_len = past_key_values[0][0].shape[2]
                     position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
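
For reference, a minimal standalone sketch of the `position_ids` construction introduced in the `hidden_size == 4096` (Baichuan2-7B) branch above. The concrete lengths here are illustrative stand-ins; in the patch they come from `past_key_values[0][0].shape[2]` and `drafted_input_ids.shape[1]`:

```python
import torch

# Illustrative stand-ins for values the patch reads at runtime.
past_key_value_len = 5   # tokens already held in the KV cache
seq_len = 3              # newly drafted tokens in this step
seq_len_with_past = seq_len + past_key_value_len

# Positions for the drafted tokens continue where the cache left off,
# i.e. [past_key_value_len, ..., seq_len_with_past - 1].
position_ids = torch.arange(past_key_value_len, seq_len_with_past,
                            dtype=torch.long)
# Reshape to (batch, seq_len) as expected by the traced graph.
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)

print(position_ids)        # tensor([[5, 6, 7]])
print(position_ids.shape)  # torch.Size([1, 3])
```

Note that only the 7B branch passes `position_ids` explicitly; the 13B (`hidden_size == 5120`) branch keeps the original `trace_graph` call signature unchanged.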