parent 24de13fc45
commit a31f2cbe13
1 changed file with 6 additions and 13 deletions
@@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
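This hunk swaps the Python-level helper apply_rotary_pos_emb_no_cache_xpu (which re-reads rope_theta from self.rotary_emb.base on every call) for the fused xe_addons.rotary_half_inplaced kernel, which rotates query_states and key_states in place directly from self.rotary_emb.inv_freq. For readers without an Intel XPU, here is a minimal pure-PyTorch sketch of the half-rotation RoPE the fused kernel presumably implements; rotate_half is the usual transformers-style helper, and the function names and in-place contract are assumptions drawn from the call site, not the kernel source:

```python
import torch

def rotate_half(x):
    # Standard half-rotation: split the head dim and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_half_reference(inv_freq, position_ids, query_states, key_states):
    # inv_freq: [head_dim // 2]; position_ids: [bsz, seq_len]
    # query_states / key_states: [bsz, n_heads, seq_len, head_dim]
    freqs = position_ids[..., None].float() * inv_freq   # [bsz, seq, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)              # [bsz, seq, head_dim]
    cos = emb.cos()[:, None]                             # broadcast over heads
    sin = emb.sin()[:, None]
    # The fused kernel mutates q/k in place; this sketch does the same.
    query_states.copy_(query_states * cos + rotate_half(query_states) * sin)
    key_states.copy_(key_states * cos + rotate_half(key_states) * sin)
```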
@@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
                                                                        is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            self.layer_idx > 0 and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
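The one-line deletion here drops the self.layer_idx > 0 guard, so the fused SDP path now also covers the first decoder layer. For orientation, xe_addons.sdp plays the role of a causal scaled-dot-product attention kernel; a hedged PyTorch equivalent, where the reference function name and the mask handling are illustrative assumptions:

```python
import torch.nn.functional as F

def sdp_reference(query_states, key_states, value_states, attention_mask=None):
    # Equivalent of a fused causal SDP kernel: softmax(q @ k^T / sqrt(d)) @ v.
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=attention_mask,
        is_causal=attention_mask is None,  # rely on the mask when one is given
    )
```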
@@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
         )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
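The quantized forward gets the identical substitution, so both code paths now share the same fused-RoPE entry point. A hypothetical smoke test of the new call, where the device, shapes, dtype, and the llama-style base of 10000.0 are illustrative assumptions, not taken from this commit:

```python
# Hypothetical usage sketch; needs an ipex-llm build whose xe_addons exposes
# rotary_half_inplaced, and an Intel XPU device.
import torch
import xe_addons

bsz, n_heads, seq_len, head_dim = 1, 32, 128, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
position_ids = torch.arange(seq_len, device="xpu").unsqueeze(0)
# Illustrative llama-style inverse frequencies (base 10000.0 assumed).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="xpu").float() / head_dim))

xe_addons.rotary_half_inplaced(inv_freq, position_ids, q, k)  # rotates q and k in place
```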