simplify qwen attention (#9747)
parent 984697afe2
commit 426660b88e

1 changed file with 28 additions and 69 deletions
@@ -74,6 +74,9 @@ def qwen_attention_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
 ):
+    invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+                      "flash attn and kv_cache quantization are not supported")
+
     mixed_x_layer = self.c_attn(hidden_states)
     query, key, value = mixed_x_layer.split(self.split_size, dim=2)
 
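For context on the unchanged lines above: Qwen's c_attn is a fused projection that produces query, key and value in a single matmul, and split(self.split_size, dim=2) carves the result into the three tensors. A minimal sketch of that step, with illustrative shapes (in the real model split_size is the hidden size):

import torch
import torch.nn as nn

bsz, seq_len, hidden_size = 2, 5, 32               # illustrative sizes only
c_attn = nn.Linear(hidden_size, 3 * hidden_size)   # fused QKV projection
hidden_states = torch.randn(bsz, seq_len, hidden_size)

mixed_x_layer = c_attn(hidden_states)              # (2, 5, 96)
query, key, value = mixed_x_layer.split(hidden_size, dim=2)
assert query.shape == key.shape == value.shape == (bsz, seq_len, hidden_size)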
@@ -119,12 +122,10 @@ def qwen_attention_forward(
     bsz, _, n_heads, head_dim = key.size()
 
     if layer_past is not None:
-        kv_seq_len += layer_past[0].shape[1]
-        # past_key, past_value = layer_past[0], layer_past[1]
-        # key = torch.cat((past_key, key), dim=1)
-        # value = torch.cat((past_value, value), dim=1)
-        cache_k = layer_past[0].transpose(1, 2)
-        cache_v = layer_past[1].transpose(1, 2)
+        cache_k, cache_v = layer_past[0], layer_past[1]
+        cache_k = cache_k.transpose(1, 2)
+        cache_v = cache_v.transpose(1, 2)
+        kv_seq_len += cache_k.shape[2]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
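A note on the stride test kept above: init_kv_cache/extend_kv_cache appear to allocate the K/V buffers with spare slots along the sequence dimension, so a partially filled view has a per-head stride larger than seq_len * head_dim; once the two are equal the buffer is full and a larger one must be allocated. A rough sketch of that detection, assuming this allocation scheme (names here are illustrative, not the library's API):

import torch

def has_free_room(cache: torch.Tensor) -> bool:
    # cache: a (bsz, n_heads, seq_len, head_dim) view over a buffer that may
    # have been allocated with extra sequence positions
    bsz, n_heads, seq_len, head_dim = cache.shape
    # the per-head stride exceeds seq_len * head_dim only if the underlying
    # storage is longer than the current view along the sequence dim
    return cache.stride()[1] > seq_len * head_dim

buf = torch.empty(1, 2, 16, 4)     # room for 16 positions
view = buf[:, :, :10, :]           # currently holding 10
assert has_free_room(view)         # 6 spare slots remain
assert not has_free_room(buf)      # full view: must extend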
@@ -141,8 +142,8 @@ def qwen_attention_forward(
 
         key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                    key.transpose(1, 2), value.transpose(1, 2))
-        key = key_states.transpose(1, 2)
-        value = value_states.transpose(1, 2)
+        key = key_states
+        value = value_states
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(bsz,
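The dropped transposes reflect the central layout change in this commit: after append_kv_cache, key and value now stay in the cache's (bsz, n_heads, seq, head_dim) layout rather than being flipped back to (bsz, seq, n_heads, head_dim), which is what lets the later permutes disappear. An illustrative stand-in for what an append step does under that layout (assumed semantics, not the library's actual implementation; `used` is passed explicitly here for clarity):

import torch

def append_kv_sketch(buf_k, buf_v, used, new_k, new_v):
    # buf_k/buf_v: preallocated (bsz, n_heads, max_seq, head_dim) buffers
    # with the first `used` positions filled; new_k/new_v hold the new tokens
    n_new = new_k.size(2)
    buf_k[:, :, used:used + n_new, :] = new_k
    buf_v[:, :, used:used + n_new, :] = new_v
    # return views covering old + new positions, still head-major
    return buf_k[:, :, :used + n_new, :], buf_v[:, :, :used + n_new, :]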
@@ -154,80 +155,38 @@ def qwen_attention_forward(
                                                          device=hidden_states.device)
         new_key_states[:] = key.transpose(1, 2)
         new_value_states[:] = value.transpose(1, 2)
-        key = new_key_states.transpose(1, 2)
-        value = new_value_states.transpose(1, 2)
+        key = new_key_states
+        value = new_value_states
 
-    if use_cache:
-        present = (key, value)
-    else:
-        present = None
-
-    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    query_size, key_size = query.size(1), key.size(2)
     if key_size > self.seq_length and self.use_logn_attn and not self.training:
-        if self.use_cache_quantization:
-            seq_start = key[0].size(2) - query.size(1)
-            seq_end = key[0].size(2)
-        else:
-            seq_start = key.size(1) - query.size(1)
-            seq_end = key.size(1)
+        seq_start = key_size - query_size
+        seq_end = key_size
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
 
-    if (
-        self.use_flash_attn
-        and flash_attn_unpadded_func is not None
-        and not self.is_fp32
-        and query.is_cuda
-    ):
-        q, k, v = query, key, value
-        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+    if query_size == key_size:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
     else:
-        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(
-                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-            ).view(1, 1, key_size, key_size)
-        else:
-            causal_mask = None
-        query = query.permute(0, 2, 1, 3)
-        if not self.use_cache_quantization:
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
-        if (
-            causal_mask is None
-            and self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and not query.is_cuda
-        ):
-            invalidOperationError(False,
-                                  None,
-                                  None,
-                                  Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
+        causal_mask = None
+    query = query.transpose(1, 2)
 
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
+    attn_output, attn_weight = self._attn(
+        query, key, value, causal_mask, attention_mask, head_mask
+    )
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
 
     attn_output = self.c_proj(context_layer)
 
-    outputs = (attn_output, present)
+    if use_cache:
+        outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
+    else:
+        outputs = (attn_output, None)
     if output_attentions:
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            invalidInputError(False,
-                              f"Cannot output attentions while using flash-attn")
-        elif not self.use_cache_quantization and SUPPORT_TORCH2:
-            invalidInputError(False,
-                              f"Cannot output attentions while using scaled_dot_product_attention")
-        else:
-            outputs += (attn_weight,)
+        outputs += (attn_weight,)
 
     return outputs
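The simplified mask logic above leans on a standard decoding property: during prefill (query_size == key_size) a lower-triangular mask enforces causality, while during decode a single new query may attend to every cached position, so no mask is needed. A small sketch of the prefill mask, mirroring the torch.tril call in the diff:

import torch

key_size = 4  # illustrative prefill length
causal_mask = torch.tril(
    torch.ones((key_size, key_size), dtype=torch.bool)
).view(1, 1, key_size, key_size)
# causal_mask[0, 0]:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]
# row i (a query position) may attend only to key positions j <= i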