diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 8cbe1b0e..791f5acd 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -39,6 +39,7 @@ import torch.nn.functional as F from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -58,7 +59,7 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256 def llama_rms_norm_forward(self, hidden_states): - if hidden_states.device.type == "xpu": + if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, [self.weight.size(0)], self.weight) else: @@ -116,9 +117,16 @@ def llama_attention_forward_4_31( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "llama") + + if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad): + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "llama") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "llama") if past_key_value is not None: # reuse k, v, self_attention diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index f4f1372b..1aed301f 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -97,3 +97,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): else: invalidInputError(False, f"{model_family} is not supported.") + + +def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family): + if q.device.type != "xpu": + invalidInputError(False, + f"only xpu is supported in this function") + import linear_q4_0 + q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device) + k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device) + if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]: + linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed) + return q_embed, k_embed + else: + invalidInputError(False, + f"{model_family} is not supported.")