Yishuo Wang 2024-12-26 11:35:12 +08:00 committed by GitHub
parent d841e1dc0d
commit 1604b4ead8
2 changed files with 3 additions and 6 deletions

@@ -1784,9 +1784,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.CohereAttention,
                         cohere_attention_forward)
-        convert_forward(model,
-                        module.CohereLayerNorm,
-                        rms_norm_forward)
         convert_forward(model,
                         module.CohereMLP,
                         mlp_silu_forward)

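The hunk above stops routing CohereLayerNorm through rms_norm_forward. The commit carries no description, but a plausible reading is that Cohere's layer norm centers each row on its mean before scaling, whereas an RMSNorm does not, so substituting an RMS-norm kernel would silently change the numerics. A minimal standalone sketch of that difference, with hypothetical helper names rather than this repo's code:

```python
import torch

def cohere_style_layer_norm(x, weight, eps=1e-5):
    # LayerNorm without bias: center on the mean, then scale by inverse std.
    mean = x.mean(-1, keepdim=True)
    variance = (x - mean).pow(2).mean(-1, keepdim=True)
    return weight * (x - mean) * torch.rsqrt(variance + eps)

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: no mean subtraction; scale by the inverse root-mean-square.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(2, 8) + 1.0  # a nonzero mean makes the gap visible
w = torch.ones(8)
print(torch.allclose(cohere_style_layer_norm(x, w), rms_norm(x, w)))  # False
```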

@@ -144,12 +144,12 @@ def llama_attention_forward(
     if query_states.device.type == "xpu":
         import xe_addons
-        if position_embeddings is None:
-            # transformers < 4.43
+        if hasattr(self, "rotary_emb"):
+            # transformers < 4.46
             xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                            query_states, key_states)
         else:
-            # transformers >= 4.43
+            # transformers >= 4.46
             cos, sin = position_embeddings
             make_cache_contiguous_inplaced(cos, sin)
             xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
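This hunk switches the dispatch from "did the caller pass position_embeddings" to "does this attention module still own a rotary_emb", and per the updated comments the version boundary is now transformers 4.46. A hedged reading: newer transformers releases compute the (cos, sin) rotary cache once at the model level and pass it down to every layer, so testing position_embeddings is None no longer cleanly separates old from new code paths, while probing for the per-layer rotary_emb attribute does. A minimal sketch of that hasattr feature-detection pattern, using hypothetical classes rather than the repo's code:

```python
# Hypothetical stand-ins for pre- and post-refactor attention modules;
# only the dispatch pattern mirrors the hunk above.
class OldAttention:
    def __init__(self):
        self.rotary_emb = "per-layer rotary cache"  # present in older transformers

class NewAttention:
    pass  # rotary cache lives on the model, not the layer

def apply_rope(attn, position_embeddings=None):
    if hasattr(attn, "rotary_emb"):
        # legacy path: derive cos/sin inside the layer from attn.rotary_emb
        return "legacy rotary path"
    # new path: unpack the precomputed cache handed down by the model
    cos, sin = position_embeddings
    return "precomputed cos/sin path"

print(apply_rope(OldAttention()))              # legacy rotary path
print(apply_rope(NewAttention(), (1.0, 0.0)))  # precomputed cos/sin path
```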