From 968e70544d4cb048af428792c3f770317488a9dd Mon Sep 17 00:00:00 2001 From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com> Date: Thu, 1 Feb 2024 10:48:16 +0800 Subject: [PATCH] Enable IPEX Mistral in Speculative (#10059) --- .../src/bigdl/llm/transformers/speculative.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index 278830d9..bb0d1a1d 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -372,9 +372,10 @@ def speculative_generate(self, _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true") if _enable_ipex: if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or - ('llama' in self.config.model_type)): + ('llama' in self.config.model_type) or + ("mistral" in self.config.model_type)): invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \ - Llama and Baichuan2-13b models currently.") + Llama, Baichuan2-13b and Mistral models currently.") tmp_matchness = 0 e2e_tic = 0.0 @@ -531,6 +532,19 @@ def speculative_generate(self, position_ids=position_ids, past_key_values=past_key_values, ) + elif "mistral" in self.config.model_type: + past_key_value_len = past_key_values[0][0].shape[2] + seq_len = drafted_input_ids.shape[1] + position_ids = torch.arange(past_key_value_len, + seq_len + past_key_value_len, + dtype=torch.long, + device=drafted_input_ids.device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_len) + output = self.trace_graph(input_ids=drafted_input_ids, + attention_mask=cur_attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) logits = output[0] past_key_values = output[1] else: