Enable IPEX Mistral in Speculative (#10059)
This commit is contained in:
parent
3ca03d4e97
commit
968e70544d
1 changed files with 16 additions and 2 deletions
|
|
@ -372,9 +372,10 @@ def speculative_generate(self,
|
||||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||||
if _enable_ipex:
|
if _enable_ipex:
|
||||||
if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
|
if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
|
||||||
('llama' in self.config.model_type)):
|
('llama' in self.config.model_type) or
|
||||||
|
("mistral" in self.config.model_type)):
|
||||||
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
|
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
|
||||||
Llama and Baichuan2-13b models currently.")
|
Llama, Baichuan2-13b and Mistral models currently.")
|
||||||
|
|
||||||
tmp_matchness = 0
|
tmp_matchness = 0
|
||||||
e2e_tic = 0.0
|
e2e_tic = 0.0
|
||||||
|
|
@ -531,6 +532,19 @@ def speculative_generate(self,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
elif "mistral" in self.config.model_type:
|
||||||
|
past_key_value_len = past_key_values[0][0].shape[2]
|
||||||
|
seq_len = drafted_input_ids.shape[1]
|
||||||
|
position_ids = torch.arange(past_key_value_len,
|
||||||
|
seq_len + past_key_value_len,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=drafted_input_ids.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
|
||||||
|
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||||
|
attention_mask=cur_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
logits = output[0]
|
logits = output[0]
|
||||||
past_key_values = output[1]
|
past_key_values = output[1]
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue