Fix Llama 3.2 & 3.1 on LNL (#12196)
parent 516b578104
commit f8d1adc573

2 changed files with 4 additions and 1 deletion
@@ -1268,7 +1268,7 @@ def _optimize_post(model, lightweight_bmm=False):
     from ipex_llm.transformers.models.llama import llama_mlp_forward
 
     if model.config.model_type == "llama" and model.config.rope_scaling is not None:
-        # llama 3.2
+        # llama 3.2 & llama 3.1
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
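Context for the relabelled comment: this branch keys off model.config.rope_scaling, which Llama 3.1 and Llama 3.2 text checkpoints both set in config.json (older Llama configs leave it null), so the same test routes both generations to this path. A minimal, self-contained sketch of that dispatch; the config values below are illustrative, not taken from this diff:

# Sketch (not repo code) of the dispatch above. Llama 3.1 and 3.2
# checkpoints ship a rope_scaling block in config.json, while Llama 2 /
# 3.0 leave it null, so one test covers both newer generations.
from types import SimpleNamespace

llama31_cfg = SimpleNamespace(
    model_type="llama",
    # Representative of a Llama 3.1/3.2 config; exact values vary by model.
    rope_scaling={"rope_type": "llama3", "factor": 8.0,
                  "original_max_position_embeddings": 8192},
)
llama2_cfg = SimpleNamespace(model_type="llama", rope_scaling=None)

for cfg in (llama31_cfg, llama2_cfg):
    if cfg.model_type == "llama" and cfg.rope_scaling is not None:
        print("llama 3.2 & llama 3.1 path")
    else:
        print("legacy llama path")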
@@ -1279,6 +1279,7 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.LlamaMLP, mlp_silu_forward)
         convert_forward(model, module.LlamaModel, llama_model_forward)
         convert_forward(model, module.LlamaAttention, llama_attention_forward)
+        convert_forward(model, module.LlamaSdpaAttention, llama_attention_forward)
     elif model.config.model_type == "mllama":
         # llama 3.2 vision
         modeling_module_name = model.__class__.__module__
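The added convert_forward call is the heart of the fix for the text models: with the default attn_implementation="sdpa", transformers builds decoder layers from LlamaSdpaAttention rather than LlamaAttention, so converting only the latter leaves the default attention path unoptimized. convert_forward itself is not shown in this diff; below is a minimal sketch of such a helper, assuming it rebinds forward on modules whose class matches the target exactly, which is exactly what would make the explicit subclass conversion necessary:

# Minimal sketch (not the ipex_llm source) of a convert_forward-style
# helper, assuming an exact class match: a subclass such as
# LlamaSdpaAttention is then skipped unless converted explicitly.
import torch


def convert_forward(m: torch.nn.Module, target_m: type, new_forward) -> None:
    if m.__class__ == target_m:                      # exact match, no subclasses
        bound = new_forward.__get__(m, m.__class__)  # bind like a method
        setattr(m, "forward", bound)                 # instance attr shadows class
    for sub_m in m.children():                       # recurse into the tree
        convert_forward(sub_m, target_m, new_forward)


class Attention(torch.nn.Module):       # stand-in for LlamaAttention
    def forward(self, x):
        return x


class SdpaAttention(Attention):         # stand-in for LlamaSdpaAttention
    pass


def optimized_forward(self, x):
    return x * 2


model = torch.nn.Sequential(Attention(), SdpaAttention())
convert_forward(model, Attention, optimized_forward)
print(model(torch.ones(1)))   # tensor([2.]): SdpaAttention was not matched
convert_forward(model, SdpaAttention, optimized_forward)  # the added call
print(model(torch.ones(1)))   # tensor([4.]): both layers now optimized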
@@ -204,6 +204,8 @@ def llama_attention_forward(
     kv_seq_len = key_states.size(2)
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, :kv_seq_len]
+    else:
+        causal_mask = None
 
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
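The new else branch keeps causal_mask bound when the caller passes no attention mask at all; before, that path would raise an UnboundLocalError at the first use of causal_mask. A small sketch of the slicing with made-up shapes, using torch's built-in scaled_dot_product_attention as a stand-in for the fused path behind use_sdp (the shapes and names below are illustrative, not repo code):

# Sketch of the mask handling above with illustrative shapes.
import torch

bsz, n_heads, q_len, head_dim = 1, 8, 4, 64
past_len = 12
kv_seq_len = past_len + q_len                    # key_states.size(2)

# transformers prepares a 4-D additive mask that may be wider than the
# current KV length, so it is sliced down before use.
attention_mask = torch.zeros(bsz, 1, q_len, 32)  # padded to a fixed width
attention_mask[..., kv_seq_len:] = float("-inf")

if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, :kv_seq_len]
else:
    causal_mask = None          # the added fallback: keeps the name bound
                                # when the caller passes no mask at all

print(causal_mask.shape)        # torch.Size([1, 1, 4, 16])

# Downstream, attention accepts either the sliced mask or None.
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
v = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(out.shape)                # torch.Size([1, 8, 4, 64])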