From aa861df066dce8510a3204c0b8347692c7f92ec8 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 13 Aug 2024 14:48:11 +0800
Subject: [PATCH] use new fp32 softmax kernel (#11776)

---
 python/llm/src/ipex_llm/transformers/models/minicpmv.py | 5 +++--
 python/llm/src/ipex_llm/transformers/models/phi3.py     | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index c92226a0..159e61b6 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -42,8 +42,9 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)
+
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 443a9921..64d998cd 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -184,9 +184,9 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)
+
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
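
Note (not part of the patch): below is a minimal sketch of the pattern the patch adopts, assuming an ipex-llm XPU build where the xe_addons extension is importable. The softmax_fp32 helper name and the try/except fallback are illustrative assumptions; the fallback simply mirrors the upcast-to-fp32 softmax that the patch removes.

import torch


def softmax_fp32(attn_weights: torch.Tensor) -> torch.Tensor:
    """Apply an fp32-accurate softmax over the last dim of attention scores."""
    try:
        # Fused in-place softmax kernel used by the patch (assumes an XPU build
        # of ipex-llm that ships the xe_addons extension).
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    except ImportError:
        # Fallback: the original behavior, upcast to fp32 for the softmax and
        # cast back to the working dtype afterwards.
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)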