[LLM] Fix transformer qwen size mismatch and rename causal_mask (#9655)
* Fix size mismatch caused by context_layer
* Change registered_causal_mask to causal_mask
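In short, the mismatch comes from the flash-attention path: core_attention_flash returns context_layer laid out as (batch, seq, heads, head_dim) (the "b s h d" in the diff comment), and the flatten(2, 3) in the patched code path merges the last two dimensions into (batch, seq, heads * head_dim) before the rest of the forward pass uses it. A minimal sketch of that reshape; the sizes below are illustrative, not the real Qwen configuration:

import torch

# Illustrative sizes only -- not the actual Qwen configuration.
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64

# Stand-in for the flash-attention output, laid out as b s h d.
context_layer = torch.randn(batch, seq_len, num_heads, head_dim)

# b s h d -> b s (h d), as in the patched code path.
context_layer = context_layer.flatten(2, 3).contiguous()

assert context_layer.shape == (batch, seq_len, num_heads * head_dim)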
parent 2fe38b4b9b
commit 8931f2eb62

1 changed file with 5 additions and 7 deletions
@@ -64,7 +64,6 @@ def qwen_attention_forward(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
     rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
-    registered_causal_mask: Optional[torch.Tensor] = None,
     layer_past: Optional[Tuple[torch.Tensor]] = None,
     attention_mask: Optional[torch.FloatTensor] = None,
     head_mask: Optional[torch.FloatTensor] = None,
@@ -171,20 +170,19 @@ def qwen_attention_forward(
         q, k, v = query, key, value
         context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)

         # b s h d -> b s (h d)
         context_layer = context_layer.flatten(2, 3).contiguous()

     else:
         if query.size(1) == key_size:
-            registered_causal_mask = torch.tril(
+            causal_mask = torch.tril(
                 torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
             ).view(1, 1, key_size, key_size)
         else:
             causal_mask = None
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
             value = value.permute(0, 2, 1, 3)
         if (
-            registered_causal_mask is None
+            causal_mask is None
             and self.use_flash_attn
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
@@ -192,7 +190,7 @@ def qwen_attention_forward(
         ):
             invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
         attn_output, attn_weight = self._attn(
-            query, key, value, registered_causal_mask, attention_mask, head_mask
+            query, key, value, causal_mask, attention_mask, head_mask
         )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
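For reference, the causal_mask built in the non-flash branch is a lower-triangular boolean matrix viewed as (1, 1, key_size, key_size) so it broadcasts over batch and heads. Below is a rough sketch of how such a mask is typically consumed by an attention helper; the masked_fill/softmax step is an assumption about a generic _attn implementation, not code taken from this repository:

import torch

# Illustrative sizes only.
key_size, num_heads, head_dim = 8, 2, 4
query = torch.randn(1, num_heads, key_size, head_dim)  # (batch, heads, seq, head_dim)
key = torch.randn(1, num_heads, key_size, head_dim)

# Same construction as in the diff above: lower-triangular bool mask,
# viewed as (1, 1, key_size, key_size) to broadcast over batch and heads.
causal_mask = torch.tril(
    torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
).view(1, 1, key_size, key_size)

# Assumed usage (generic scaled dot-product attention, not necessarily
# the exact _attn in this repository): future positions are blocked
# before the softmax.
attn_scores = query @ key.transpose(-1, -2) / (head_dim ** 0.5)
attn_scores = attn_scores.masked_fill(~causal_mask, torch.finfo(attn_scores.dtype).min)
attn_weights = torch.softmax(attn_scores, dim=-1)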