From 98b86f83d469a50ed5bce088c953d6cd8367e08a Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Wed, 17 Jan 2024 15:51:38 +0800
Subject: [PATCH] Support fast rope for training (#9745)

* init

* init

* fix style

* add test and fix

* address comment

* update

* merge upstream main
---
 .../llm/transformers/layers/rope_embedding.py |  67 ++++++++++
 .../bigdl/llm/transformers/models/llama.py    | 121 +++++++++++++++++
 .../bigdl/llm/transformers/models/utils.py    |   2 +-
 .../llm/src/bigdl/llm/transformers/qlora.py   |  19 +++
 .../inference_gpu/test_layer_fast_rope.py     | 124 ++++++++++++++++++
 .../llm/test/run-llm-inference-tests-gpu.sh   |  13 +-
 6 files changed, 344 insertions(+), 2 deletions(-)
 create mode 100644 python/llm/src/bigdl/llm/transformers/layers/rope_embedding.py
 create mode 100644 python/llm/test/inference_gpu/test_layer_fast_rope.py

diff --git a/python/llm/src/bigdl/llm/transformers/layers/rope_embedding.py b/python/llm/src/bigdl/llm/transformers/layers/rope_embedding.py
new file mode 100644
index 00000000..b3af61dd
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/layers/rope_embedding.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import logging
+from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.utils.common import invalidInputError
+
+LOG = logging.getLogger("bigdl.llm.rope_embedding")
+
+
+# Fast RoPE for finetuning, split the q and k
+def apply_fast_rope_embedding(q, k, position_ids, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
+                        "mixtral"]:
+        q_embed = FastRopeEmbedding.apply(q, position_ids)
+        k_embed = FastRopeEmbedding.apply(k, position_ids)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
+
+
+# Fast RoPE for finetuning, split the q and k
+class FastRopeEmbedding(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, position_ids):
+        import linear_q4_0
+        x_embed = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        linear_q4_0.apply_rotary_embedding_half_q_or_k(x, position_ids,
+                                                       x_embed, False)
+        ctx.save_for_backward(position_ids)
+        return x_embed
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        import linear_q4_0
+        # LOG.info(f"backward, grad_output: {grad_output}")
+        position_ids, = ctx.saved_tensors
+        dx = torch.empty(grad_output.shape,
+                         dtype=grad_output.dtype,
+                         device=grad_output.device)
+        linear_q4_0.apply_rotary_embedding_half_q_or_k(grad_output,
+                                                       position_ids,
+                                                       dx,
+                                                       True)
+        # LOG.info(f"backward, dx: {dx}, position_ids: {position_ids},
+        # requires_grad: {ctx.needs_input_grad}")
+        return dx, None
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 2d51af8c..179ec18b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -127,6 +127,15 @@ def should_use_fuse_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+# Only for xpu and training
+def should_use_fast_rope(self, query_states, position_ids):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and (self.training or query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+    use_fuse_rope = use_fuse_rope and position_ids is not None
+    return use_fuse_rope
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
@@ -911,3 +920,115 @@ def llama_model_selective_batching_forward_4_31(
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
+
+
+# For training
+def llama_attention_fast_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    use_fast_rope = should_use_fast_rope(self, hidden_states, position_ids)
+
+    # Check for inference
+    if use_cache and past_key_value is not None and q_len == 1:
+        A, past_key_value = llama_attention_forward_4_31(
+            self,
+            hidden_states,
+            past_key_value,
+            position_ids,
+        )
+        return A, None, past_key_value
+
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
+                             self.config.pretraining_tp)
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+        )
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [F.linear(hidden_states, query_slices[i])
+                        for i in range(self.config.pretraining_tp)]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [F.linear(hidden_states, key_slices[i])
+                      for i in range(self.config.pretraining_tp)]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [F.linear(hidden_states, value_slices[i])
+                        for i in range(self.config.pretraining_tp)]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+                                 self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+                                     self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    if use_fast_rope:
+        from bigdl.llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
+        query_states, key_states = apply_fast_rope_embedding(query_states,
+                                                             key_states,
+                                                             position_ids,
+                                                             "llama")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "llama")
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                           attention_mask,
+                                           bsz, q_len, kv_seq_len,
+                                           self.head_dim, self.num_heads)
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
+                           for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 0191724a..10c2ffca 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -170,7 +170,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
                         "mixtral"]:
-        linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
         return q_embed, k_embed
     else:
         invalidInputError(False,
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index 739529ff..f41d1afb 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -49,6 +49,7 @@
 # limitations under the License.
 
 import torch
+import logging
 from torch.nn import Linear, Embedding
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
 from peft.tuners.lora import LoraLayer
@@ -58,6 +59,8 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 import functools
 from bigdl.llm.transformers import training_patch
 
+LOG = logging.getLogger("bigdl.llm.qlora")
+
 
 class LoraLowBitLinear(LowBitLinear, LoraLayer):
     # Lora implemented in a dense layer
@@ -252,6 +255,7 @@ def get_peft_model(*args, **kwargs):
 
     if model.device.type == "xpu":
         cast_lora_weight(model, torch.bfloat16)
+        _optimize_post(model)
         torch.xpu.synchronize()
 
     return model
@@ -345,3 +349,18 @@ def cast_lora_weight(model, dtype=torch.bfloat16):
             if hasattr(module, 'weight'):
                 if module.weight.dtype == torch.float32:
                     module = module.to(dtype)
+
+
+def _optimize_post(model):
+    import transformers
+    from packaging import version
+    from bigdl.llm.transformers.convert import convert_forward
+    from bigdl.llm.transformers.models.llama import llama_attention_fast_forward
+
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.31.0"):
+        LOG.info("Optimizing Llama finetuning....")
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaAttention,
+            llama_attention_fast_forward,)
diff --git a/python/llm/test/inference_gpu/test_layer_fast_rope.py b/python/llm/test/inference_gpu/test_layer_fast_rope.py
new file mode 100644
index 00000000..9861c913
--- /dev/null
+++ b/python/llm/test/inference_gpu/test_layer_fast_rope.py
@@ -0,0 +1,124 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#
+# This file is adapted from
+# https://github.com/Dao-AILab/flash-attention/blob/main/tests/layers/test_rotary.py
+#
+# Copyright (c) 2023, Tri Dao.
+#
+
+import os
+import pytest
+import torch
+import intel_extension_for_pytorch as ipex
+import torch.nn.functional as F
+from einops import rearrange
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+from transformers.models.llama.modeling_llama import (
+    apply_rotary_pos_emb as apply_rotary_pos_emb_llama,
+)
+from bigdl.llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
+
+device = os.environ['DEVICE']
+print(f'Running on {device}')
+if 'xpu' not in device:
+    print(f"The layer.fast_rope test should running on xpu")
+
+# llama-style rotary embedding
+@pytest.mark.parametrize("seqlen_offset", [0, 711])
+@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
+def test_rotary(rotary_emb_fraction, seqlen_offset):
+    device = "xpu"
+    dtype = torch.float16
+    rtol, atol = (1e-3, 5e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen_total = 2048
+    seqlen = seqlen_total - seqlen_offset
+    seqlen_offset = torch.tensor([[seqlen_offset]], device=device)
+    nheads = 32
+    headdim = 128
+    rotary_dim = int(headdim * rotary_emb_fraction)
+    qkv = torch.randn(
+        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+        requires_grad=True
+    )
+    rotary_llama = LlamaRotaryEmbedding(rotary_dim, seqlen_total, device=device)
+    # Doesn't matter what tensor we pass in, rotary_llama only uses the device
+    # of the tensor
+    cos_llama, sin_llama = rotary_llama(qkv, seq_len=seqlen_total)
+    cos_llama, sin_llama = cos_llama.to(dtype=dtype), sin_llama.to(dtype=dtype)
+    q_pt = (
+        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
+        .detach()
+        .clone()
+        .requires_grad_(True)
+    )
+    k_pt = (
+        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
+        .detach()
+        .clone()
+        .requires_grad_(True)
+    )
+    q_pt_fast = (
+        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
+        .detach()
+        .clone()
+        .requires_grad_(True)
+    )
+    k_pt_fast = (
+        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
+        .detach()
+        .clone()
+        .requires_grad_(True)
+    )
+    q_llama, k_llama = apply_rotary_pos_emb_llama(q_pt, k_pt, cos_llama,
+                                                  sin_llama, position_ids=seqlen_offset)
+    q_fast, k_fast = apply_fast_rope_embedding(q_pt_fast, k_pt_fast,
+                                               position_ids=seqlen_offset,
+                                               model_family="llama")
+    assert torch.allclose(
+        rearrange(q_llama, "b h s d -> b s h d"),
+        rearrange(q_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
+    )
+    assert torch.allclose(
+        rearrange(k_llama, "b h s d -> b s h d"),
+        rearrange(k_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
+    )
+
+    g = torch.randn_like(q_fast)
+    q_fast.backward(g)
+    k_fast.backward(g)
+    q_llama.backward(g)
+    k_llama.backward(g)
+
+    assert torch.allclose(
+        q_pt.grad,
+        q_pt_fast.grad,
+        rtol=rtol,
+        atol=atol,
+    )
+
+    assert torch.allclose(
+        k_pt.grad,
+        k_pt_fast.grad,
+        rtol=rtol,
+        atol=atol,
+    )
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/python/llm/test/run-llm-inference-tests-gpu.sh b/python/llm/test/run-llm-inference-tests-gpu.sh
index 2430535b..03f583b4 100644
--- a/python/llm/test/run-llm-inference-tests-gpu.sh
+++ b/python/llm/test/run-llm-inference-tests-gpu.sh
@@ -22,5 +22,16 @@ pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api.py -v -s
 
 now=$(date "+%s")
 time=$((now-start))
-echo "Bigdl-llm gpu tests finished"
+echo "Bigdl-llm gpu inference tests finished"
+echo "Time used:$time seconds"
+
+echo "# Start testing layers.fast_rope_embedding"
+start=$(date "+%s")
+
+pytest ${LLM_INFERENCE_TEST_DIR}/test_layer_fast_rope.py -v -s
+
+now=$(date "+%s")
+time=$((now-start))
+
+echo "Bigdl-llm gpu layers.fast_rope_embedding tests finished"
 echo "Time used:$time seconds"