diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 2ef6062c..e044cd72 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
+# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/be72f02dd47087f9035ee9bb5dea571b84785d27/modeling_qwen.py
 #
 # Copyright (c) Alibaba Cloud.
 #
@@ -38,7 +38,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half
-from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
 apply_rotary_emb_func = None
 
@@ -48,6 +48,7 @@ flash_attn_unpadded_func = None
 
 logger = logging.get_logger(__name__)
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
 def apply_rotary_pos_emb(t, freqs):
@@ -159,7 +160,7 @@ def qwen_attention_forward(
         else:
             seq_start = key.size(1) - query.size(1)
             seq_end = key.size(1)
-        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
 
     if (
@@ -169,12 +170,12 @@ def qwen_attention_forward(
         and query.is_cuda
     ):
         q, k, v = query, key, value
-        context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-
+        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
     else:
+        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
         if query.size(1) == key_size:
             causal_mask = torch.tril(
-                torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
+                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
             ).view(1, 1, key_size, key_size)
         else:
             causal_mask = None
@@ -189,13 +190,30 @@ def qwen_attention_forward(
             and not self.is_fp32
             and not query.is_cuda
         ):
-            invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+            invalidOperationError(False,
+                                  None,
+                                  None,
+                                  Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
+
+        if not self.use_cache_quantization and SUPPORT_TORCH2:
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+                if causal_mask is not None:
+                    attention_mask = attention_mask.masked_fill(~causal_mask,
+                                                                torch.finfo(query.dtype).min)
+            else:
+                attention_mask = causal_mask
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask
+            ).transpose(1, 2)
+            attn_weight = None
+        else:
+            attn_output, attn_weight = self._attn(
+                query, key, value, causal_mask, attention_mask, head_mask
+            )
+        context_layer = self._merge_heads(
+            attn_output, self.num_heads, self.head_dim
+        )
 
     attn_output = self.c_proj(context_layer)
 
@@ -206,7 +224,11 @@ def qwen_attention_forward(
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
         ):
-            invalidInputError(False, "Cannot output attentions while using flash-attn")
+            invalidInputError(False,
+                              f"Cannot output attentions while using flash-attn")
+        elif not self.use_cache_quantization and SUPPORT_TORCH2:
+            invalidInputError(False,
+                              f"Cannot output attentions while using scaled_dot_product_attention")
         else:
             outputs += (attn_weight,)
 
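
For reference, below is a minimal standalone sketch (not part of the patch) of the mask handling that the new SUPPORT_TORCH2 branch performs: the padding mask is broadcast over the query length, the causal mask is folded in as the dtype's most negative value, and the result is passed to torch.nn.functional.scaled_dot_product_attention. The helper name sdpa_with_masks and the (batch, heads, seq, head_dim) tensor layout are illustrative assumptions, not part of the bigdl-llm API.

# Illustrative sketch only; mirrors the masking logic of the SUPPORT_TORCH2 branch above.
import torch
import torch.nn.functional as F


def sdpa_with_masks(query, key, value, attention_mask=None, causal_mask=None):
    # query/key/value: (batch, heads, seq, head_dim)
    # attention_mask: additive float mask of shape (batch, 1, 1, kv_len), or None
    # causal_mask: boolean mask of shape (1, 1, q_len, kv_len), True = "may attend"
    if attention_mask is not None:
        # Broadcast the padding mask over the query length.
        attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
        if causal_mask is not None:
            # Block future positions with the most negative representable value.
            attention_mask = attention_mask.masked_fill(~causal_mask,
                                                        torch.finfo(query.dtype).min)
    else:
        attention_mask = causal_mask
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)


if __name__ == "__main__":
    batch, heads, seq, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq)
    print(sdpa_with_masks(q, k, v, causal_mask=causal).shape)  # torch.Size([1, 2, 4, 8])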