optimize phi3 memory usage again (#11848)

Yishuo Wang 2024-08-19 17:26:59 +08:00 committed by GitHub
parent 3cd4e87168
commit 9490781aec


@@ -177,12 +177,12 @@ def attention_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+    # use inplaced div, add and softmax to avoid double attn_weights memory usage
+    attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+        attn_weights.add_(attention_mask)
     attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
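
For illustration only, below is a minimal standalone sketch (not the repository code) of the pattern this commit applies: scaling and masking the attention scores with in-place div_/add_ so that only the matmul output buffer is allocated, instead of creating a second attn_weights-sized tensor for the out-of-place division and addition. The function names, tensor shapes, and the use of torch.softmax as a stand-in for the repo's attention_softmax helper are assumptions made for this example.

import math
import torch

def scores_out_of_place(query_states, key_states, attention_mask, head_dim):
    # out-of-place: the division and the mask addition each materialize a new
    # tensor the same size as attn_weights, roughly doubling its peak memory
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    return torch.softmax(attn_weights, dim=-1)  # stand-in for attention_softmax

def scores_in_place(query_states, key_states, attention_mask, head_dim):
    # in-place: only the matmul output is allocated; div_ and add_ reuse that buffer
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
    attn_weights.div_(math.sqrt(head_dim))
    if attention_mask is not None:
        attn_weights.add_(attention_mask)
    return torch.softmax(attn_weights, dim=-1)  # stand-in for attention_softmax

if __name__ == "__main__":
    # illustrative shapes: (batch, num_heads, seq_len, head_dim)
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    mask = torch.zeros(1, 1, 128, 128)
    a = scores_out_of_place(q, k, mask, 64)
    b = scores_in_place(q, k, mask, 64)
    # both paths compute the same result; the in-place one has lower peak memory
    print(torch.allclose(a, b, atol=1e-6))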