Reduce Mistral softmax memory only in low memory mode (#11775)
* Reduce Mistral softmax memory only in low memory mode
parent aa861df066
commit a88c132e54

1 changed file with 5 additions and 6 deletions
@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )

         attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't'):
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
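Read as a whole, the hunk above changes the gate on the fp16 softmax: it now fires only when the IPEX_LLM_LOW_MEM environment variable is truthy and the workload is long-sequence or large-batch; everything else keeps the fp32 upcast. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the helper name and tensor handling are illustrative, and the diff's nested if is flattened into a single condition.

import os

import torch
import torch.nn as nn


def softmax_with_optional_upcast(attn_weights, kv_seq_len, bsz):
    # Hypothetical helper mirroring the commit's gating: the no-upcast
    # softmax runs only when IPEX_LLM_LOW_MEM is enabled.
    low_mem = os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't')
    if low_mem and (kv_seq_len >= 2048 or bsz >= 64):
        # for memory considerations, do not upcast attention to fp32
        # for long sequences or large batches
        return nn.functional.softmax(attn_weights, dim=-1)
    # default: upcast to fp32 for numerical stability, then cast back
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)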
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                         f" but is {attention_mask.size()}"
                     )
                 attn_weights = attn_weights + attention_mask
-
             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
                 # for long sequences or large batches
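For completeness, a quick way to exercise the flag with the softmax_with_optional_upcast sketch above (setting the variable in-process is for illustration only; the output dtype shows which path ran):

import os

import torch

os.environ["IPEX_LLM_LOW_MEM"] = "1"  # opt in; the default '0' keeps the fp32 upcast

bsz, num_heads, q_len, kv_seq_len = 2, 32, 1, 4096
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len, dtype=torch.float16)
out = softmax_with_optional_upcast(attn_weights, kv_seq_len, bsz)
print(out.dtype)  # torch.float16: the low-memory path skipped the fp32 upcast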