fix typos

This commit is contained in:
Huang, Xinshengzi 2024-08-22 11:23:22 +08:00
parent 4adadddbbc
commit 2a0aa9271b

View file

@@ -89,9 +89,11 @@ def baichuan_model_7b_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = output_attentions if output_attentions is not None \
else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -110,7 +112,8 @@ def baichuan_model_7b_forward(
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
@@ -130,9 +133,8 @@ def baichuan_model_7b_forward(
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@@ -216,7 +218,8 @@ def baichuan_model_7b_forward(
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,