parent 24de13fc45
commit a31f2cbe13
1 changed file with 6 additions and 13 deletions
@@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         if use_fuse_rope:
-            rope_theta = self.rotary_emb.base
-            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                         key_states,
-                                                                         position_ids,
-                                                                         "llama",
-                                                                         rope_theta=rope_theta)
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
         else:
             if cache_position is not None:
                 # for transformers 4.38.0
@@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            self.layer_idx > 0 and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
@@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         if use_fuse_rope:
-            rope_theta = self.rotary_emb.base
-            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                         key_states,
-                                                                         position_ids,
-                                                                         "llama",
-                                                                         rope_theta=rope_theta)
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
         else:
             if cache_position is not None:
                 # for transformers 4.38.0
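
In both RoPE hunks the non-fused apply_rotary_pos_emb_no_cache_xpu(...) call, which returned new tensors and needed rope_theta plus a model-family tag, is replaced by the fused in-place kernel xe_addons.rotary_half_inplaced, so the reassignment of query_states/key_states disappears; the middle hunk additionally drops the self.layer_idx > 0 guard from the SDP condition. A minimal sketch of the new call pattern follows; the wrapper function name and docstring are illustrative only, while the kernel call and its argument order are taken from the diff.

# Illustrative-only wrapper around the fused RoPE kernel used above.
# Assumption: query_states/key_states are XPU tensors shaped as in the
# attention forward; xe_addons is the project's XPU extension module.
def fused_rope_inplace(inv_freq, position_ids, query_states, key_states):
    """Rotate query/key states in place with the fused XPU kernel."""
    import xe_addons  # XPU-only extension module
    xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                   query_states, key_states)
    # Nothing is returned: the kernel mutates query_states/key_states in
    # place, which is why the old "query_states, key_states = ..."
    # assignment was removed in this change.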