small fix (#12616)
This commit is contained in:
parent
d841e1dc0d
commit
1604b4ead8
2 changed files with 3 additions and 6 deletions
|
|
@ -1784,9 +1784,6 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.CohereAttention,
|
module.CohereAttention,
|
||||||
cohere_attention_forward)
|
cohere_attention_forward)
|
||||||
convert_forward(model,
|
|
||||||
module.CohereLayerNorm,
|
|
||||||
rms_norm_forward)
|
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.CohereMLP,
|
module.CohereMLP,
|
||||||
mlp_silu_forward)
|
mlp_silu_forward)
|
||||||
|
|
|
||||||
|
|
@ -144,12 +144,12 @@ def llama_attention_forward(
|
||||||
|
|
||||||
if query_states.device.type == "xpu":
|
if query_states.device.type == "xpu":
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if position_embeddings is None:
|
if hasattr(self, "rotary_emb"):
|
||||||
# transformers < 4.43
|
# transformers < 4.46
|
||||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
query_states, key_states)
|
query_states, key_states)
|
||||||
else:
|
else:
|
||||||
# transformers >= 4.43
|
# transformers >= 4.46
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
make_cache_contiguous_inplaced(cos, sin)
|
make_cache_contiguous_inplaced(cos, sin)
|
||||||
xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
|
xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue