small fix (#12616)
parent d841e1dc0d
commit 1604b4ead8

2 changed files with 3 additions and 6 deletions
@@ -1784,9 +1784,6 @@ def _optimize_post(model, lightweight_bmm=False):
        convert_forward(model,
                        module.CohereAttention,
                        cohere_attention_forward)
        convert_forward(model,
                        module.CohereLayerNorm,
                        rms_norm_forward)
        convert_forward(model,
                        module.CohereMLP,
                        mlp_silu_forward)
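The hunk above touches the block that registers replacement forwards for the Cohere modules. As a rough illustration of what a convert_forward-style helper does, here is a minimal sketch (assuming a PyTorch model; the parameter names target_class and new_forward are illustrative, and this is not the actual ipex-llm implementation):

import types

def convert_forward(model, target_class, new_forward):
    # Hypothetical sketch: walk every submodule of the model and rebind
    # `forward` on instances of `target_class` to `new_forward`.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)

# Usage mirrors the calls in the hunk, e.g.:
#   convert_forward(model, module.CohereMLP, mlp_silu_forward)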
@@ -144,12 +144,12 @@ def llama_attention_forward(
     if query_states.device.type == "xpu":
         import xe_addons
-        if position_embeddings is None:
-            # transformers < 4.43
+        if hasattr(self, "rotary_emb"):
+            # transformers < 4.46
             xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                            query_states, key_states)
         else:
-            # transformers >= 4.43
+            # transformers >= 4.46
             cos, sin = position_embeddings
             make_cache_contiguous_inplaced(cos, sin)
             xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
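For context, the xe_addons calls in the hunk above are fused XPU kernels that apply rotary position embeddings in place. A plain-PyTorch sketch of roughly what the cos/sin branch computes (standard rotate-half RoPE; this is not the ipex-llm kernel itself) looks like this:

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Eager equivalent of a fused rotary-with-cache kernel: rotate the
    # query/key states using precomputed cos/sin caches (out of place here,
    # whereas the fused kernel updates q and k in place).
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed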