small improvement (#12359)
parent 71ea539351
commit ad68c56573

1 changed file with 2 additions and 6 deletions
@@ -167,12 +167,8 @@ def qwen2_model_forward(
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask

-    # ipex-llm changes start: don't generate `attention_mask` in specific cases
-    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
-        seq_length, seq_length + past_key_values_length,
-        self.config.hidden_size // self.config.num_attention_heads,
-        inputs_embeds, self.training
-    ):
+    # ipex-llm changes start: don't generate `attention_mask` in decode phase
+    if seq_length == 1:
         attention_mask = None
     # ipex-llm changes end
     elif self._attn_implementation == "flash_attention_2":
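For readability, here is a minimal sketch (not the ipex-llm source) of the effective mask-skipping condition before and after this commit; the parameter names mirror the diff, and use_sdp_causal is treated as an opaque predicate supplied by the caller.

# Sketch only: illustrates the boolean condition that decides whether
# attention_mask is set to None, before and after this change.

def skip_mask_before(seq_length, batch_size, past_key_values_length,
                     head_dim, inputs_embeds, training, use_sdp_causal):
    # Python precedence: `and` binds tighter than `or`, so the old check meant
    # seq_length == 1  OR  (batch_size == 1 AND use_sdp_causal(...)).
    return seq_length == 1 or (
        batch_size == 1
        and use_sdp_causal(seq_length, seq_length + past_key_values_length,
                           head_dim, inputs_embeds, training)
    )

def skip_mask_after(seq_length):
    # After the change, attention_mask is only skipped in the decode phase,
    # i.e. when exactly one new token is being processed.
    return seq_length == 1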