diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index 278830d9..bb0d1a1d 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -372,9 +372,10 @@ def speculative_generate(self,
     _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
     if _enable_ipex:
         if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
-                ('llama' in self.config.model_type)):
+                ('llama' in self.config.model_type) or
+                ("mistral" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
-                              Llama and Baichuan2-13b models currently.")
+                              Llama, Baichuan2-13b and Mistral models currently.")

     tmp_matchness = 0
     e2e_tic = 0.0
@@ -531,6 +532,19 @@ def speculative_generate(self,
                                               position_ids=position_ids,
                                               past_key_values=past_key_values,
                                               )
+                elif "mistral" in self.config.model_type:
+                    past_key_value_len = past_key_values[0][0].shape[2]
+                    seq_len = drafted_input_ids.shape[1]
+                    position_ids = torch.arange(past_key_value_len,
+                                                seq_len + past_key_value_len,
+                                                dtype=torch.long,
+                                                device=drafted_input_ids.device)
+                    position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              past_key_values=past_key_values,
+                                              position_ids=position_ids,
+                                              )
                 logits = output[0]
                 past_key_values = output[1]
             else:
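
Note: the core of the new Mistral branch is the position_ids construction, which must continue from the current KV-cache length so the drafted tokens are verified at the right positions. Below is a minimal, self-contained sketch (not part of the patch) of that arithmetic; the concrete values for past_key_value_len and seq_len are made-up stand-ins for past_key_values[0][0].shape[2] and drafted_input_ids.shape[1], and the device argument is omitted since this runs on CPU.

import torch

past_key_value_len = 48   # stand-in for past_key_values[0][0].shape[2] (cached sequence length)
seq_len = 5               # stand-in for drafted_input_ids.shape[1] (number of draft tokens)

# Positions for the draft tokens pick up where the KV cache left off.
position_ids = torch.arange(past_key_value_len,
                            seq_len + past_key_value_len,
                            dtype=torch.long)
# Reshape to (batch, seq_len) as the model's forward pass expects.
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)

print(position_ids)  # tensor([[48, 49, 50, 51, 52]])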