diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 37757fbd..794f910f 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -331,4 +331,7 @@ def optimize(model):
                         module.AquilaAttention,
                         aquila_attention_forward
                         )
+        convert_forward(model,
+                        module.AquilaRMSNorm,
+                        llama_rms_norm_forward)
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/aquila.py b/python/llm/src/bigdl/llm/transformers/models/aquila.py
index 66c891a6..417b20e7 100644
--- a/python/llm/src/bigdl/llm/transformers/models/aquila.py
+++ b/python/llm/src/bigdl/llm/transformers/models/aquila.py
@@ -44,6 +44,7 @@ from torch import nn
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -73,9 +74,15 @@ def aquila_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids, "aquila")
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "aquila")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "aquila")
     # [bsz, nh, t, hd]
 
     if past_key_value is not None:
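
Note on the change above: the new convert_forward call hooks llama_rms_norm_forward onto AquilaRMSNorm, and the inference-only XPU branch in aquila_attention_forward switches to the fused apply_rotary_pos_emb_no_cache_xpu path while keeping the original rotary_emb + apply_rotary_pos_emb path for CPU and for training. The snippet below is a minimal, hypothetical sketch (not the real bigdl.llm.transformers.convert.convert_forward) of the forward-swapping pattern this diff relies on; convert_forward_sketch and its behavior are assumptions made only to illustrate how a replacement forward gets attached to every matching submodule.

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> nn.Module:
    """Hypothetical illustration of the forward-swapping pattern.

    Walks the module tree and rebinds `forward` on every submodule that is an
    instance of `target_cls`, so calls like model(...) dispatch to the
    optimized function (e.g. aquila_attention_forward, llama_rms_norm_forward).
    """
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            # Bind the replacement as an instance method so `self` refers to
            # the existing submodule and its weights are reused unchanged.
            sub_module.forward = types.MethodType(new_forward, sub_module)
    return model

Because only the bound forward is replaced, the submodule's parameters, buffers, and state_dict layout stay untouched, which is why the optimized path in the diff can branch on query_states.device.type at runtime without any model surgery.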