diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 9d4688df..21d86b06 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -89,9 +89,11 @@ def baichuan_model_7b_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
@@ -110,7 +112,8 @@ def baichuan_model_7b_forward(
 
     # retrieve input_ids and inputs_embeds
     if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
+                          the same time")
     elif input_ids is not None:
         batch_size, seq_length = input_ids.shape
     elif inputs_embeds is not None:
@@ -130,9 +133,8 @@ def baichuan_model_7b_forward(
 
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-        )
+        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
+                                    dtype=torch.long, device=device)
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     else:
         position_ids = position_ids.view(-1, seq_length).long()
@@ -216,7 +218,8 @@ def baichuan_model_7b_forward(
 
     next_cache = next_decoder_cache if use_cache else None
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                     if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,