add rotary_half_with_cache_inplaced to ipex_llm.transformers.models.common (#13143)

* update
* small fix

parent f2598b119e
commit f5d9c49a2a

4 changed files with 14 additions and 8 deletions
@@ -357,3 +357,11 @@ def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch
     import xe_addons
     xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
                                              cos, sin, half_layout)
+
+
+def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                    cos: torch.Tensor, sin: torch.Tensor):
+    import xe_addons
+    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
+    make_cache_contiguous_inplaced(cos, sin)
+    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
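For context, a minimal sketch of how the call sites updated below use the new helper versus what they previously did inline; the wrapper names apply_rope_new/apply_rope_old are hypothetical, and it assumes an XPU build of ipex-llm where xe_addons is importable and the query/key states and cos/sin caches are prepared as the attention forwards below prepare them:

import torch


def apply_rope_new(query_states: torch.Tensor, key_states: torch.Tensor,
                   cos: torch.Tensor, sin: torch.Tensor) -> None:
    # after this commit: a single helper call; cos/sin contiguity is handled
    # inside rotary_half_with_cache_inplaced, so call sites no longer need a
    # direct xe_addons import or make_cache_contiguous_inplaced
    from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
    rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)


def apply_rope_old(query_states: torch.Tensor, key_states: torch.Tensor,
                   cos: torch.Tensor, sin: torch.Tensor) -> None:
    # before this commit: the llama and qwen3 call sites did roughly this inline
    import xe_addons
    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
    make_cache_contiguous_inplaced(cos, sin)
    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)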
@@ -162,9 +162,8 @@ def llama_attention_forward(
                                            query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
-            make_cache_contiguous_inplaced(cos, sin)
-            xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+            from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+            rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
             if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):
@@ -62,8 +62,8 @@ def qwen2_5_omni_attention_forward(
 
     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
@@ -93,9 +93,8 @@ def qwen3_attention_forward(
 
     cos, sin = position_embeddings
     if device.type == "xpu":
-        import xe_addons
-        make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
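Taken together, the three updated call sites share one shape: the fused in-place kernel via the new common helper on XPU, the stock rotary-embedding path elsewhere. A sketch of that dispatch follows; dispatch_rope and the import path of the HF fallback are assumptions for illustration, matching the apply_rotary_pos_emb usage in the qwen3 hunk above:

import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def dispatch_rope(query_states: torch.Tensor, key_states: torch.Tensor,
                  position_embeddings):
    # position_embeddings is the (cos, sin) pair the attention forward receives
    cos, sin = position_embeddings
    if query_states.device.type == "xpu":
        # fused path: rotates query/key states in place via the new common helper
        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
        return query_states, key_states
    # non-XPU fallback: plain HF implementation, returns new tensors
    return apply_rotary_pos_emb(query_states, key_states, cos, sin)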