fix typos
This commit is contained in:
parent 4adadddbbc
commit 2a0aa9271b
1 changed file with 10 additions and 7 deletions
@@ -89,9 +89,11 @@ def baichuan_model_7b_forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
@@ -110,7 +112,8 @@ def baichuan_model_7b_forward(
 
     # retrieve input_ids and inputs_embeds
     if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at \
+            the same time")
     elif input_ids is not None:
         batch_size, seq_length = input_ids.shape
     elif inputs_embeds is not None:
@@ -130,9 +133,8 @@ def baichuan_model_7b_forward(
 
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-        )
+        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
+                                    dtype=torch.long, device=device)
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     else:
         position_ids = position_ids.view(-1, seq_length).long()
@@ -216,7 +218,8 @@ def baichuan_model_7b_forward(
 
     next_cache = next_decoder_cache if use_cache else None
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                     if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
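Side note: a minimal, self-contained sketch (not part of the commit; the example values are assumptions) showing that the rewrapped torch.arange call builds the same position_ids tensor as the original three-line form:

import torch

# Assumed example values; in the real forward pass these come from the inputs.
past_key_values_length = 4
seq_length = 3
device = torch.device("cpu")

# Rewrapped call from the commit: identical arguments, identical result.
position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length,
                            dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[4, 5, 6]])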