fix typos

Author: Huang, Xinshengzi
Date:   2024-08-22 11:23:22 +08:00
parent 4adadddbbc
commit 2a0aa9271b


@@ -89,9 +89,11 @@ def baichuan_model_7b_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
@@ -110,7 +112,8 @@ def baichuan_model_7b_forward(
 
     # retrieve input_ids and inputs_embeds
     if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
+            the same time")
     elif input_ids is not None:
         batch_size, seq_length = input_ids.shape
     elif inputs_embeds is not None:
@@ -130,9 +133,8 @@ def baichuan_model_7b_forward(
 
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-        )
+        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
+                                    dtype=torch.long, device=device)
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     else:
         position_ids = position_ids.view(-1, seq_length).long()
@@ -216,7 +218,8 @@ def baichuan_model_7b_forward(
 
     next_cache = next_decoder_cache if use_cache else None
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                     if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
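
For context, a minimal standalone sketch of the position-ids construction that the third hunk rewraps. This is plain PyTorch; the concrete cache and sequence lengths below are illustrative values, not taken from the commit:

    import torch

    # Illustrative values only: 4 tokens already in the KV cache, 3 new tokens.
    past_key_values_length = 4
    seq_length = 3
    device = torch.device("cpu")

    # Same construction as the hunk: position ids continue from the cache
    # length, one id per newly processed token, then gain a batch dimension.
    position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
                                dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    print(position_ids)  # tensor([[4, 5, 6]])

The rewrapped torch.arange call is behavior-identical to the original multi-line form; only the line wrapping changes, consistent with the style-only intent of the commit.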