LLM: update qwen attention forward. (#9695)
* feat: update qwen attention forward.
* fix: style.

parent b8437a1c1e
commit adbef56001

1 changed file with 36 additions and 14 deletions

@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
+# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/be72f02dd47087f9035ee9bb5dea571b84785d27/modeling_qwen.py
 #
 # Copyright (c) Alibaba Cloud.
 #
@@ -38,7 +38,7 @@ except ImportError:
 
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half
-from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
 apply_rotary_emb_func = None
@@ -48,6 +48,7 @@ flash_attn_unpadded_func = None
 logger = logging.get_logger(__name__)
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
 
 def apply_rotary_pos_emb(t, freqs):
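
Context for the new constant: SUPPORT_TORCH2 only looks at the installed PyTorch major version, and it gates the scaled_dot_product_attention path added further down. A minimal, illustrative check of how the flag behaves (not part of the patch):

import torch

# Same test as the patch: any 2.x build of torch counts as "torch 2".
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2

if SUPPORT_TORCH2:
    print("PyTorch 2.x detected:", torch.__version__)  # SDPA path becomes available
else:
    print("Older PyTorch:", torch.__version__)          # code would fall back to self._attn
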
@@ -159,7 +160,7 @@ def qwen_attention_forward(
         else:
             seq_start = key.size(1) - query.size(1)
             seq_end = key.size(1)
-        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
 
     if (
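
The .type_as(query) added here keeps the logn scaling in the query's dtype; without it, multiplying a float32 buffer into a half-precision query would promote the result to float32. A small standalone sketch of that effect, with shapes and names that are illustrative only, not taken from the model:

import torch

# Illustrative layout: query is (batch, seq, heads, head_dim) in fp16,
# while the logn buffer is a broadcastable (1, seq, 1, 1) tensor kept in fp32.
query = torch.randn(1, 4, 2, 8, dtype=torch.float16)
logn_tensor = torch.ones(1, 4, 1, 1, dtype=torch.float32)

scaled = query * logn_tensor.type_as(query).expand_as(query)
print(scaled.dtype)  # torch.float16, matching the query
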
@@ -169,12 +170,12 @@ def qwen_attention_forward(
         and query.is_cuda
     ):
         q, k, v = query, key, value
-        context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-
+        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+
     else:
         key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
         if query.size(1) == key_size:
             causal_mask = torch.tril(
                 torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
-                torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
+                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
             ).view(1, 1, key_size, key_size)
         else:
             causal_mask = None
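
The mask construction touched above only changes which device the boolean mask is allocated on (query.device instead of key.device, since key may now be a quantized cache tuple). For reference, a standalone sketch of the same lower-triangular causal mask, with illustrative sizes:

import torch

key_size = 6                    # illustrative sequence length
device = torch.device("cpu")    # stands in for query.device in the patch

causal_mask = torch.tril(
    torch.ones((key_size, key_size), dtype=torch.bool, device=device)
).view(1, 1, key_size, key_size)

print(causal_mask[0, 0])  # True on and below the diagonal, False above it
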
@@ -189,13 +190,30 @@ def qwen_attention_forward(
             and not self.is_fp32
             and not query.is_cuda
         ):
-            invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+            invalidOperationError(False,
+                                  None,
+                                  None,
+                                  Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
+
+        if not self.use_cache_quantization and SUPPORT_TORCH2:
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+                if causal_mask is not None:
+                    attention_mask = attention_mask.masked_fill(~causal_mask,
+                                                                torch.finfo(query.dtype).min)
+            else:
+                attention_mask = causal_mask
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask
+            ).transpose(1, 2)
+            attn_weight = None
+        else:
+            attn_output, attn_weight = self._attn(
+                query, key, value, causal_mask, attention_mask, head_mask
+            )
+    context_layer = self._merge_heads(
+        attn_output, self.num_heads, self.head_dim
+    )
 
     attn_output = self.c_proj(context_layer)
 
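
The main change in this hunk: when cache quantization is off and PyTorch 2.x is available, attention goes through torch.nn.functional.scaled_dot_product_attention instead of self._attn, with the boolean causal mask folded into an additive mask. A self-contained sketch of that masking pattern; the tensor names, shapes and the zero-initialized attention_mask are illustrative assumptions, not the module's real inputs, and it requires PyTorch >= 2.0:

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
# (batch, heads, seq, head_dim): the layout scaled_dot_product_attention expects here
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Lower-triangular boolean mask, as built earlier in the forward pass.
causal_mask = torch.tril(
    torch.ones((kv_len, kv_len), dtype=torch.bool)
).view(1, 1, kv_len, kv_len)

# Fold the causal mask into an additive attention mask, mirroring the masked_fill
# in the patch; a real attention_mask would come from padding information.
attention_mask = torch.zeros(batch, 1, q_len, kv_len)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask
).transpose(1, 2)  # back to (batch, seq, heads, head_dim) before head merging

print(attn_output.shape)  # torch.Size([1, 4, 2, 8])
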
@@ -206,7 +224,11 @@ def qwen_attention_forward(
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
         ):
-            invalidInputError(False, "Cannot output attentions while using flash-attn")
+            invalidInputError(False,
+                              f"Cannot output attentions while using flash-attn")
+        elif not self.use_cache_quantization and SUPPORT_TORCH2:
+            invalidInputError(False,
+                              f"Cannot output attentions while using scaled_dot_product_attention")
         else:
             outputs += (attn_weight,)
 
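
The new elif mirrors the flash-attn case: the fused SDPA kernel never returns per-head attention weights, so output_attentions cannot be honored on that path. A quick illustrative check (requires PyTorch >= 2.0):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)
out = F.scaled_dot_product_attention(q, k, v)
# Only the attention output comes back; there is no weights tensor to append to outputs.
print(type(out), out.shape)  # <class 'torch.Tensor'> torch.Size([1, 2, 4, 8])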