small fix (#12616)
commit 1604b4ead8
parent d841e1dc0d

2 changed files with 3 additions and 6 deletions
```diff
@@ -1784,9 +1784,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.CohereAttention,
                         cohere_attention_forward)
-        convert_forward(model,
-                        module.CohereLayerNorm,
-                        rms_norm_forward)
         convert_forward(model,
                         module.CohereMLP,
                         mlp_silu_forward)
```
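The deletion above drops the CohereLayerNorm -> rms_norm_forward patch while keeping the CohereAttention and CohereMLP replacements. For readers unfamiliar with the pattern, here is a minimal sketch of what a `convert_forward`-style helper does; the body is an assumption for illustration (the actual ipex-llm implementation may differ), and `capped_relu_forward` is a made-up toy forward.

```python
# Minimal sketch (an assumption, not ipex-llm's actual code): a
# convert_forward-style helper walks the module tree and rebinds `forward`
# on every instance of the target class with the optimized function.
import torch
from torch import nn


def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    for sub_module in model.modules():
        if sub_module.__class__ == target_cls:
            # Bind the plain function to this instance so `self` is passed
            # automatically; the instance attribute shadows the class forward.
            sub_module.forward = new_forward.__get__(sub_module, target_cls)


# Hypothetical replacement forward: clamp ReLU activations at 6.
def capped_relu_forward(self, x):
    return torch.clamp(x, min=0.0, max=6.0)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    convert_forward(model, nn.ReLU, capped_relu_forward)
    print(model(torch.randn(2, 4)).shape)
```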
```diff
@@ -144,12 +144,12 @@ def llama_attention_forward(
 
     if query_states.device.type == "xpu":
         import xe_addons
-        if position_embeddings is None:
-            # transformers < 4.43
+        if hasattr(self, "rotary_emb"):
+            # transformers < 4.46
             xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                            query_states, key_states)
         else:
-            # transformers >= 4.43
+            # transformers >= 4.46
             cos, sin = position_embeddings
             make_cache_contiguous_inplaced(cos, sin)
             xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
```
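This hunk keeps both rotary paths but switches the gate from "was `position_embeddings` passed?" to feature-detecting `self.rotary_emb`, with the version comments bumped from transformers 4.43 to 4.46. The pure-PyTorch sketch below illustrates that dispatch pattern only; the helper names, shapes, and unfused rotate-half math are illustrative assumptions, whereas the patched code calls the fused `xe_addons` kernels in place on XPU.

```python
# Illustrative sketch of the dispatch above (an assumption, not the patched
# code itself): pick the rotary path by feature-detecting `rotary_emb` on the
# attention module rather than checking the position_embeddings argument.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin):
    # Standard "rotate half" RoPE; the real XPU kernel does this in place.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


def apply_rotary_dispatch(attn_module, q, k, position_ids, position_embeddings):
    if hasattr(attn_module, "rotary_emb"):
        # Older transformers: the attention module still owns rotary_emb,
        # so cos/sin are derived here from inv_freq and the position ids.
        inv_freq = attn_module.rotary_emb.inv_freq          # [head_dim // 2]
        freqs = position_ids[..., None].float() * inv_freq  # [bsz, seq, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()
    else:
        # Newer transformers: the model computes cos/sin once per forward
        # and hands them down as `position_embeddings`.
        cos, sin = position_embeddings
    # Broadcast over the head dimension of [bsz, heads, seq, head_dim].
    return apply_rotary(q, k, cos.unsqueeze(1), sin.unsqueeze(1))


if __name__ == "__main__":
    from types import SimpleNamespace
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    position_ids = torch.arange(16)[None, :]
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
    old_style_attn = SimpleNamespace(rotary_emb=SimpleNamespace(inv_freq=inv_freq))
    q_rot, k_rot = apply_rotary_dispatch(old_style_attn, q, k, position_ids, None)
    print(q_rot.shape, k_rot.shape)
```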