diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index e044cd72..43dca325 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -195,22 +195,26 @@ def qwen_attention_forward(
                       None,
                       Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
-    if not self.use_cache_quantization and SUPPORT_TORCH2:
-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-            if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(~causal_mask,
-                                                            torch.finfo(query.dtype).min)
-        else:
-            attention_mask = causal_mask
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask
-        ).transpose(1, 2)
-        attn_weight = None
-    else:
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
+    # Removed for an efficiency issue on Arc; may be added back later.
+    # if not self.use_cache_quantization and SUPPORT_TORCH2:
+    #     if attention_mask is not None:
+    #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+    #         if causal_mask is not None:
+    #             attention_mask = attention_mask.masked_fill(~causal_mask,
+    #                                                         torch.finfo(query.dtype).min)
+    #     else:
+    #         attention_mask = causal_mask
+    #     attn_output = F.scaled_dot_product_attention(
+    #         query, key, value, attn_mask=attention_mask
+    #     ).transpose(1, 2)
+    #     attn_weight = None
+    # else:
+    #     attn_output, attn_weight = self._attn(
+    #         query, key, value, causal_mask, attention_mask, head_mask
+    #     )
+    attn_output, attn_weight = self._attn(
+        query, key, value, causal_mask, attention_mask, head_mask
+    )
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
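
For reference, the commented-out branch builds an additive attention mask from the boolean causal mask and runs torch 2's F.scaled_dot_product_attention; the patch keeps only the _attn fallback. Below is a minimal, self-contained sketch of that masking step; the tensor shapes are illustrative assumptions, not values taken from qwen.py.

# Minimal sketch (assumption-based, not part of the patch) of the masking logic
# in the commented-out SDPA path: the padding mask is broadcast over the query
# length, then positions outside the causal triangle are filled with the
# dtype's minimum so they receive ~zero weight after softmax.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8   # illustrative sizes
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Additive padding mask (0 = keep), expanded across the query dimension
# the same way the removed code expands attention_mask.
attention_mask = torch.zeros(batch, 1, 1, kv_len)
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)

# Lower-triangular causal mask; False entries are masked out.
causal_mask = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool)).view(1, 1, q_len, kv_len)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask
).transpose(1, 2)
print(attn_output.shape)  # torch.Size([1, 4, 2, 8]): (batch, q_len, heads, head_dim)

The _attn path kept by the patch computes the same attention with explicit matmul/mask/softmax steps; per the in-code comment, this sidesteps an efficiency issue observed with the SDPA path on Arc.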