diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 6c8e6dfa..5c630681 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -177,12 +177,12 @@ def attention_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+    # use in-place div, add and softmax to avoid doubling attn_weights memory usage
+    attn_weights.div_(math.sqrt(self.head_dim))
 
     if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
-
+        attn_weights.add_(attention_mask)
     attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
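
For context, the memory saving in this patch comes from mutating the single attention-score tensor produced by the matmul instead of allocating a fresh tensor for the scaled and masked scores. A minimal standalone sketch of the same pattern is below; the shapes and names (`q`, `k`, `mask`, `scores`) are illustrative only and not taken from the repository.

```python
import math
import torch

# Illustrative shapes: (batch, heads, seq_len, head_dim); chosen arbitrarily.
bsz, n_heads, seq_len, head_dim = 1, 32, 1024, 96
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
mask = torch.zeros(bsz, 1, seq_len, seq_len)

# Out-of-place version: the `/` and `+` each materialize a new
# (bsz, n_heads, seq_len, seq_len) tensor while the old one is still alive,
# so peak memory briefly holds two score tensors.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + mask

# In-place version (the pattern used by the patch): the matmul output is
# reused for the scaling and the mask add, so only one score tensor exists.
scores_inplace = torch.matmul(q, k.transpose(2, 3))
scores_inplace.div_(math.sqrt(head_dim))
scores_inplace.add_(mask)

# Numerically the two paths agree.
assert torch.allclose(scores, scores_inplace)
```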