Fix deepseek coder with linear rope type support on GPU (#12709)
* Fix deepseek coder with linear rope type
* Style fix
* Move to optimize_pre
* Small fix
* Small fix
* Small fix to not affect other cases
* Style fixes
* Update function name
* Small fix
* Small fix
* Small fix
* Fix for low transformers version first
* Style fix
* Small fix
This commit is contained in:

parent 36bf3d8e29
commit 9d65dcd7ef

2 changed files with 15 additions and 3 deletions
@@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "llama":
-        from ipex_llm.transformers.models.llama import merge_qkv
+        from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
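Note on the hunk above: `model.apply()` is PyTorch's recursive module visitor, so `pre_compute_inv_freq` is called on every submodule and has to filter by module type itself. A minimal, self-contained sketch of that pattern (the `ToyRotary` module is hypothetical; only the shape of the hook mirrors the patch):

import torch

class ToyRotary(torch.nn.Module):
    # hypothetical stand-in for LlamaLinearScalingRotaryEmbedding
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.tensor([1.0, 0.5]), persistent=False)
        self.scaling_factor = 4.0

def pre_compute_inv_freq(module: torch.nn.Module):
    # called once per submodule by model.apply(); ignore everything that is
    # not a rotary-embedding module carrying a scaling factor
    if isinstance(module, ToyRotary) and hasattr(module, "scaling_factor"):
        module.register_buffer("inv_freq_scaled", None, persistent=False)
        module.inv_freq_scaled = module.inv_freq / module.scaling_factor

model = torch.nn.Sequential(torch.nn.Linear(2, 2), ToyRotary())
model.apply(pre_compute_inv_freq)   # visits the Linear, the ToyRotary, and the Sequential itself
print(model[1].inv_freq_scaled)     # tensor([0.2500, 0.1250])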
@@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, LlamaAttention)
 
 
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        if hasattr(module, "scaling_factor"):
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
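The pre-computation above relies on a simple identity of linear RoPE scaling (the rope type DeepSeek-Coder configures): the rotation angle is position * inv_freq, so dividing inv_freq by scaling_factor once ahead of time is equivalent to dividing every position id by scaling_factor at runtime. A quick stand-alone check of that identity (values are illustrative, not taken from any model config):

import torch

head_dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(16).float()
scaling_factor = 4.0

# angles from scaled positions vs. angles from pre-scaled inverse frequencies
angles_scaled_pos = torch.outer(positions / scaling_factor, inv_freq)
angles_scaled_freq = torch.outer(positions, inv_freq / scaling_factor)
assert torch.allclose(angles_scaled_pos, angles_scaled_freq)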
@@ -147,8 +154,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
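For context on the call being switched above: `xe_addons.rotary_half_inplaced` is the GPU kernel that applies rotary embeddings in place to the query and key states. A rough pure-PyTorch reference, assuming the usual rotate-half convention (a sketch of the idea, not the kernel's actual implementation), shows why handing it a pre-scaled `inv_freq_scaled` is enough and the rest of the forward path needs no change:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_half_inplaced_ref(inv_freq, position_ids, query_states, key_states):
    # query/key: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len]
    angles = position_ids[..., None].float() * inv_freq   # [batch, seq_len, head_dim // 2]
    emb = torch.cat((angles, angles), dim=-1)              # [batch, seq_len, head_dim]
    cos = emb.cos()[:, None].to(query_states.dtype)        # broadcast over the heads dim
    sin = emb.sin()[:, None].to(query_states.dtype)
    query_states.copy_(query_states * cos + rotate_half(query_states) * sin)
    key_states.copy_(key_states * cos + rotate_half(key_states) * sin)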