parent 24de13fc45
commit a31f2cbe13
1 changed file with 6 additions and 13 deletions
@@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
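For readers unfamiliar with the fused kernel: the hunk above replaces the Python-level apply_rotary_pos_emb_no_cache_xpu call with a single in-place xe_addons kernel driven directly by inv_freq, so rope_theta no longer needs to be read out of the module (the base is already folded into inv_freq). Below is a minimal plain-PyTorch sketch of what half-rotation rotary embedding computes; the helper names and tensor shapes are illustrative assumptions, not the xe_addons implementation.

    import torch

    def rotate_half(x):
        # Rotate the two halves of the head dimension: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_half(q, k, inv_freq, position_ids):
        # Non-fused reference for the rotary step above.
        # q, k: [batch, num_heads, seq_len, head_dim]; inv_freq: [head_dim // 2].
        freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)    # [batch, seq_len, head_dim]
        cos = emb.cos()[:, None, :, :]             # broadcast across heads
        sin = emb.sin()[:, None, :, :]
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

The fused in-place variant can skip materializing the cos/sin tables and the rotated intermediates, which is the point of the change.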
@@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
                                                          is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            self.layer_idx > 0 and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
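Dropping self.layer_idx > 0 from the condition makes the fused SDP kernel eligible for the first decoder layer too, whenever use_sdp(...) approves the shapes. As a rough stand-in for what the fused call computes (standard scaled dot-product attention is an assumption here; the kernel's remaining arguments fall outside the hunk):

    import torch.nn.functional as F

    def sdp_reference(query_states, key_states, value_states, attention_mask=None):
        # Plain-PyTorch stand-in for the fused xe_addons.sdp call:
        # softmax(q @ k^T / sqrt(head_dim)) @ v over
        # [batch, num_heads, seq_len, head_dim] tensors.
        return F.scaled_dot_product_attention(query_states, key_states,
                                              value_states,
                                              attn_mask=attention_mask)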
@@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
         )
     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
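The quantized forward receives the identical rotary substitution. One way to sanity-check the swap on an XPU build is a small parity test; this is a hypothetical sketch, reusing apply_rotary_half from the first snippet, and it assumes a rope base of 10000.0, that the in-place kernel mutates q and k as its name suggests, and an fp16-ish tolerance.

    import torch
    import xe_addons  # present in ipex-llm XPU builds

    q = torch.randn(1, 8, 16, 64, device="xpu", dtype=torch.float16)
    k = torch.randn(1, 8, 16, 64, device="xpu", dtype=torch.float16)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, device="xpu").float() / 64))
    position_ids = torch.arange(16, device="xpu").unsqueeze(0)

    # Reference result from the non-fused sketch above.
    q_ref, k_ref = apply_rotary_half(q.float(), k.float(), inv_freq, position_ids)

    # Fused kernel applies the rotation to q and k in place.
    xe_addons.rotary_half_inplaced(inv_freq, position_ids, q, k)

    assert torch.allclose(q.float(), q_ref, atol=1e-2)
    assert torch.allclose(k.float(), k_ref, atol=1e-2)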