update minicpm.py (#11517)

* update minicpm

* meet code review
This commit is contained in:
Xin Qiu 2024-07-05 15:25:44 +08:00 committed by GitHub
parent 24de13fc45
commit a31f2cbe13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
self.layer_idx > 0 and \
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
attn_output = xe_addons.sdp(query_states, key_states, value_states,
@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0