small fix (#12397)
This commit is contained in:
parent
00fce5c940
commit
59b01fa7d2
1 changed file with 2 additions and 2 deletions
|
|
@@ -934,9 +934,9 @@ class PrefillRunner:
|
||||||
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
|
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
|
||||||
position_ids = F.pad(position_ids, (0, pad_len), value=0)
|
position_ids = F.pad(position_ids, (0, pad_len), value=0)
|
||||||
attention_mask = F.pad(
|
attention_mask = F.pad(
|
||||||
attention_mask.to(torch.int64),
|
attention_mask.to(torch.float16),
|
||||||
(0, pad_len, 0, pad_len),
|
(0, pad_len, 0, pad_len),
|
||||||
value=torch.iinfo(torch.int64).min,
|
value=torch.finfo(torch.float16).min,
|
||||||
)
|
)
|
||||||
|
|
||||||
args = (hidden_states, position_ids, attention_mask, past_key_value)
|
args = (hidden_states, position_ids, attention_mask, past_key_value)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue