Reduce Mistral softmax memory only in low memory mode (#11775)
* Reduce Mistral softmax memory only in low memory mode
parent aa861df066
commit a88c132e54
1 changed file with 5 additions and 6 deletions
@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )
 
         attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't'):
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
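Read as a whole, the hunk above narrows the fp16 softmax fast path: it is now taken only when the IPEX_LLM_LOW_MEM environment variable is enabled and the workload is large (kv_seq_len >= 2048 or bsz >= 64); with the variable unset, attention is upcast to fp32 regardless of size. Below is a minimal sketch of that gating condensed into one helper. The helper name, the exact fp32 fallback arguments (that line is truncated in the hunk), and the behaviour when the variable is set but the workload is small (not visible in the hunk, assumed here to fall back to fp32) are assumptions for illustration, not code from the patched file.

    import os

    import torch
    import torch.nn as nn

    def softmax_with_low_mem_gate(attn_weights, kv_seq_len, bsz):
        # Illustrative helper mirroring the gating in the hunk above,
        # not the patched function itself.
        low_mem = os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't')
        if low_mem and (kv_seq_len >= 2048 or bsz >= 64):
            # low memory mode + long sequence or large batch:
            # keep the softmax in the original (e.g. fp16) dtype to save memory
            return nn.functional.softmax(attn_weights, dim=-1)
        # default: upcast attention to fp32 for numerical stability (assumed fallback)
        return nn.functional.softmax(attn_weights, dim=-1,
                                     dtype=torch.float32).to(attn_weights.dtype)

Keeping fp32 as the default trades memory for numerical stability; setting the variable opts back into the lighter fp16 path for long sequences and large batches.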
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                         f" but is {attention_mask.size()}"
                     )
                 attn_weights = attn_weights + attention_mask
-
             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
                 # for long sequences or large batches
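As a usage note, low-memory mode is opt-in via the environment variable read in the first hunk; any of 'true', '1', or 't' (case-insensitive) enables it. For example, it can be set from Python before running inference (shown with os.environ; exporting the variable in the shell works the same way):

    import os

    # Enable the low-memory softmax path checked by the patched code.
    os.environ["IPEX_LLM_LOW_MEM"] = "1"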