optimize glm edge again (#12539)
This commit is contained in:
		
							parent
							
								
									6596c18489
								
							
						
					
					
						commit
						15219944b8
					
				
					 2 changed files with 15 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -41,6 +41,7 @@ from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 | 
			
		|||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -94,6 +95,11 @@ def glm_attention_forward(
 | 
			
		|||
                                                        self.num_key_value_heads], dim=1)
 | 
			
		||||
 | 
			
		||||
    cos, sin = position_embeddings
 | 
			
		||||
    if query_states.device.type == "xpu":
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        make_cache_contiguous_inplaced(cos, sin)
 | 
			
		||||
        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
 | 
			
		||||
    else:
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 | 
			
		||||
 | 
			
		||||
    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self):
 | 
			
		|||
    elif hasattr(self, "qkv_proj"):
 | 
			
		||||
        proj = self.qkv_proj
 | 
			
		||||
    return proj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
 | 
			
		||||
    if not cos.is_contiguous():
 | 
			
		||||
        new_cos = cos.contiguous()
 | 
			
		||||
        new_sin = sin.contiguous()
 | 
			
		||||
        cos.set_(new_cos)
 | 
			
		||||
        sin.set_(new_sin)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue