Fix IPEX Baichuan Speculative (#10162)

* Fix IPEX Baichuan Speculative

* compatible with 13B

* Update speculative.py
Heyang Sun 2024-02-19 15:27:34 +08:00 committed by GitHub
parent 6952847f68
commit 3e2af5ec0a
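
The first hunk below widens the supported-model check to cover all Baichuan sizes; the second splits the Baichuan trace-graph call by hidden_size: 4096 (Baichuan2-7B) now builds explicit position_ids for the traced graph, while 5120 (Baichuan2-13B) keeps the original call without them (the 13B variant uses ALiBi rather than rotary position embeddings, so its graph takes no position_ids). A minimal sketch of the 7B position_ids construction, with hypothetical stand-ins for the runtime KV cache and drafted tokens:

import torch

# Hypothetical stand-ins: at runtime these come from
# past_key_values[0][0].shape[2] and the draft model's output.
past_key_value_len = 16
drafted_input_ids = torch.zeros(1, 4, dtype=torch.long)

seq_len = drafted_input_ids.shape[1]
seq_len_with_past = seq_len + past_key_value_len
# Positions continue from the end of the cached sequence: tensor([[16, 17, 18, 19]])
position_ids = torch.arange(past_key_value_len, seq_len_with_past,
                            dtype=torch.long,
                            device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)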


@@ -441,11 +441,12 @@ def speculative_generate(self,
     _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
     _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
     if _enable_ipex:
-        if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
-                ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
-                ("mistral" in self.config.model_type)):
+        if not ((self.config.model_type == 'baichuan') or
+                ('llama' in self.config.model_type) or
+                ("mistral" in self.config.model_type) or
+                ("chatglm" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
-                              Llama, Baichuan2-13b and Mistral models currently.")
+                              Llama, Baichuan2, Mistral and ChatGLM models currently.")
         if "chatglm" in self.config.model_type:
             global query_group_size
             query_group_size = draft_model.config.num_attention_heads // \
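
For reference, the whole branch above is gated by the BIGDL_OPT_IPEX environment variable read at the top of the hunk; a minimal, illustrative way to enable the path before calling speculative_generate:

import os

# Variable name taken from the diff above; placement here is illustrative only.
os.environ["BIGDL_OPT_IPEX"] = "true"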
@@ -597,10 +598,25 @@ def speculative_generate(self,
         cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
     if _enable_ipex and hasattr(self, "trace_graph"):
         if self.config.model_type == "baichuan":
-            output = self.trace_graph(input_ids=drafted_input_ids,
-                                      attention_mask=cur_attention_mask,
-                                      past_key_values=past_key_values,
-                                      )
+            if self.config.hidden_size == 4096:
+                past_key_value_len = past_key_values[0][0].shape[2]
+                seq_len = drafted_input_ids.shape[1]
+                seq_len_with_past = seq_len + past_key_value_len
+                position_ids = torch.arange(past_key_value_len,
+                                            seq_len_with_past,
+                                            dtype=torch.long,
+                                            device=drafted_input_ids.device)
+                position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+                output = self.trace_graph(input_ids=drafted_input_ids,
+                                          attention_mask=cur_attention_mask,
+                                          past_key_values=past_key_values,
+                                          position_ids=position_ids,
+                                          )
+            elif self.config.hidden_size == 5120:
+                output = self.trace_graph(input_ids=drafted_input_ids,
+                                          attention_mask=cur_attention_mask,
+                                          past_key_values=past_key_values,
+                                          )
     elif "llama" in self.config.model_type:
         past_key_value_len = past_key_values[0][0].shape[2]
         position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,