optimize phi3 memory usage again (#11848)
parent 3cd4e87168
commit 9490781aec

1 changed file with 4 additions and 4 deletions
@@ -177,12 +177,12 @@ def attention_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+    # use inplaced div, add and softmax to avoid double attn_weights memory usage
+    attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+        attn_weights.add_(attention_mask)

     attn_weights = attention_softmax(attn_weights, self.training)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
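The change splits the fused matmul-and-divide expression into an out-of-place matmul followed by in-place div_ and add_, so only one attn_weights-sized buffer is live before the softmax instead of two. Below is a minimal standalone sketch of the before/after behaviour; the tensor shapes are illustrative assumptions, and torch.nn.functional.softmax stands in for the repo's attention_softmax helper.

# Minimal sketch, not the patched file: shapes and the plain softmax call
# are assumptions standing in for the repo's attention_softmax helper.
import math
import torch

bsz, num_heads, q_len, kv_len, head_dim = 1, 8, 128, 128, 64
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
attention_mask = torch.zeros(bsz, 1, q_len, kv_len)

# Before: each out-of-place op (matmul / sqrt, + mask) returns a fresh
# (bsz, num_heads, q_len, kv_len) tensor, so two score-sized buffers can
# be alive at the same time.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask

# After: div_ and add_ mutate the matmul output in place, keeping peak
# usage at a single score-sized allocation before the softmax.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights.div_(math.sqrt(head_dim))
if attention_mask is not None:
    attn_weights.add_(attention_mask)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)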