deepspeed zero3 QLoRA finetuning (#11625)
* deepspeed zero3 QLoRA finetuning * Update convert.py * Update low_bit_linear.py * Update utils.py * Update qlora_finetune_llama2_13b_arch_2_card.sh * Update low_bit_linear.py * Update alpaca_qlora_finetuning.py * Update low_bit_linear.py * Update utils.py * Update convert.py * Update alpaca_qlora_finetuning.py * Update alpaca_qlora_finetuning.py * Update low_bit_linear.py * Update deepspeed_zero3.json * Update qlora_finetune_llama2_13b_arch_2_card.sh * Update low_bit_linear.py * Update low_bit_linear.py * Update utils.py * fix style * fix style * Update alpaca_qlora_finetuning.py * Update qlora_finetune_llama2_13b_arch_2_card.sh * Update convert.py * Update low_bit_linear.py * Update model.py * Update alpaca_qlora_finetuning.py * Update low_bit_linear.py * Update low_bit_linear.py * Update low_bit_linear.py
This commit is contained in:
		
							parent
							
								
									a184b120c9
								
							
						
					
					
						commit
						70c828b87c
					
				
					 5 changed files with 119 additions and 14 deletions
				
			
		| 
						 | 
				
			
			@ -144,6 +144,14 @@ def train(
 | 
			
		|||
 | 
			
		||||
    prompter = Prompter(prompt_template_name)
 | 
			
		||||
 | 
			
		||||
    if deepspeed is not None and "zero3" in deepspeed:
 | 
			
		||||
        from ipex_llm.transformers.utils \
 | 
			
		||||
            import _constant_buffered_norm2
 | 
			
		||||
        from ipex_llm.llm_patching import replace_attr
 | 
			
		||||
        import deepspeed as ds
 | 
			
		||||
        replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
 | 
			
		||||
                     "_constant_buffered_norm2", _constant_buffered_norm2)
 | 
			
		||||
 | 
			
		||||
    device_map = "auto"
 | 
			
		||||
    world_size = int(os.environ.get("WORLD_SIZE", 1))
 | 
			
		||||
    ddp = world_size != 1
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +169,7 @@ def train(
 | 
			
		|||
            optimize_model=False,
 | 
			
		||||
            torch_dtype=torch.bfloat16,
 | 
			
		||||
            modules_to_not_convert=["lm_head"],
 | 
			
		||||
            trust_remote_code=True,
 | 
			
		||||
            trust_remote_code=True
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
 | 
			
		||||
| 
						 | 
				
			
			@ -186,9 +194,10 @@ def train(
 | 
			
		|||
        #     # device_map=device_map,
 | 
			
		||||
        #     modules_to_not_convert=["lm_head"],
 | 
			
		||||
        # )
 | 
			
		||||
    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
 | 
			
		||||
    model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
 | 
			
		||||
    print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
 | 
			
		||||
    if deepspeed is not None and not "zero3" in deepspeed:
 | 
			
		||||
        print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
 | 
			
		||||
        model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
 | 
			
		||||
        print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
 | 
			
		||||
    print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,15 @@
 | 
			
		|||
{
 | 
			
		||||
    "zero_optimization": {
 | 
			
		||||
      "stage": 3,
 | 
			
		||||
      "contiguous_gradients": true,
 | 
			
		||||
      "overlap_comm": true,
 | 
			
		||||
      "offload_optimizer": {"device": "cpu"}
 | 
			
		||||
    },
 | 
			
		||||
    "bf16": {
 | 
			
		||||
      "enabled": true
 | 
			
		||||
    },
 | 
			
		||||
    "world_size": 2,
 | 
			
		||||
    "train_batch_size": 32,
 | 
			
		||||
    "train_micro_batch_size_per_gpu": 2,
 | 
			
		||||
    "gradient_accumulation_steps": 8
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,41 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
export MASTER_ADDR=127.0.0.1
 | 
			
		||||
export MASTER_PORT=29503
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_ZE_IPC_EXCHANGE=sockets
 | 
			
		||||
export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
 | 
			
		||||
basekit_root=/opt/intel/oneapi
 | 
			
		||||
source $basekit_root/setvars.sh --force
 | 
			
		||||
source $basekit_root/ccl/latest/env/vars.sh --force
 | 
			
		||||
 | 
			
		||||
NUM_GPUS=2 # number of used GPU
 | 
			
		||||
export USE_XETLA=OFF
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 | 
			
		||||
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
 | 
			
		||||
export DS_SKIP_CUDA_CHECK=1
 | 
			
		||||
 | 
			
		||||
mpirun -n $NUM_GPUS \
 | 
			
		||||
          python -u ./alpaca_qlora_finetuning.py \
 | 
			
		||||
          --base_model "meta-llama/Llama-2-13b-hf" \
 | 
			
		||||
          --data_path "yahma/alpaca-cleaned" \
 | 
			
		||||
          --output_dir "./ipex-llm-qlora-alpaca" \
 | 
			
		||||
          --gradient_checkpointing True \
 | 
			
		||||
          --micro_batch_size 2 \
 | 
			
		||||
          --batch_size 32 \
 | 
			
		||||
          --deepspeed ./deepspeed_zero3.json
 | 
			
		||||
| 
						 | 
				
			
			@ -229,6 +229,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		|||
        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
 | 
			
		||||
        scale = torch.empty(n // k, dtype=torch.float32,
 | 
			
		||||
                            device=device)
 | 
			
		||||
    elif qtype == NF4:
 | 
			
		||||
        # Deepspeed zero3 requires unified dtype,
 | 
			
		||||
        # thus here uses bfloat16 consistent to other layers
 | 
			
		||||
        # dst_size above is computed based on uint8, and for bfloat16,
 | 
			
		||||
        # buffer size should be half
 | 
			
		||||
        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
 | 
			
		||||
                                 device=device)
 | 
			
		||||
    else:
 | 
			
		||||
        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
 | 
			
		||||
                                 device=device)
 | 
			
		||||
| 
						 | 
				
			
			@ -260,12 +267,15 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
 | 
			
		||||
 | 
			
		||||
    invalidInputError(tensor.dtype == torch.uint8,
 | 
			
		||||
                      "Input tensor must be uint8")
 | 
			
		||||
    if qtype == NF4:
 | 
			
		||||
        invalidInputError(tensor.dtype == torch.bfloat16,
 | 
			
		||||
                          "NF4 Input tensor must be bfloat16")
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(tensor.dtype == torch.uint8,
 | 
			
		||||
                          "Input tensor except NF4 must be uint8")
 | 
			
		||||
 | 
			
		||||
    invalidInputError(tensor.device == torch.device('cpu'),
 | 
			
		||||
                      "Input tensor must be uint8")
 | 
			
		||||
                      "Input tensor must be on cpu")
 | 
			
		||||
 | 
			
		||||
    src = ctypes.c_void_p(tensor.data.data_ptr())
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -370,7 +380,6 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
 | 
			
		|||
 | 
			
		||||
# Rename to FP4Params to trigger initializing
 | 
			
		||||
# the params layer with all parameters on the CPU
 | 
			
		||||
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
 | 
			
		||||
class FP4Params(torch.nn.Parameter):
 | 
			
		||||
    def __new__(cls,
 | 
			
		||||
                data=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -582,7 +591,13 @@ class MatMulLowBit(torch.autograd.Function):
 | 
			
		|||
    def forward(ctx, A, weight, input_seq_size):
 | 
			
		||||
        ctx.is_empty = False
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
 | 
			
		||||
        if weight.qtype == NF4:
 | 
			
		||||
            result = xe_linear.forward_new(A,
 | 
			
		||||
                                           weight.data.view(torch.uint8),
 | 
			
		||||
                                           weight.qtype,
 | 
			
		||||
                                           input_seq_size)
 | 
			
		||||
        else:
 | 
			
		||||
            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
 | 
			
		||||
        if any(ctx.needs_input_grad[:2]):
 | 
			
		||||
            ctx.tensors = (A, weight)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -602,7 +617,12 @@ class MatMulLowBit(torch.autograd.Function):
 | 
			
		|||
        if req_gradA:
 | 
			
		||||
            if torch.xpu.is_autocast_xpu_enabled():
 | 
			
		||||
                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
 | 
			
		||||
            dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
 | 
			
		||||
            if weight.qtype == NF4:
 | 
			
		||||
                dequant_weight = xe_linear.dequant(A,
 | 
			
		||||
                                                   weight.data.view(torch.uint8),
 | 
			
		||||
                                                   weight.qtype)
 | 
			
		||||
            else:
 | 
			
		||||
                dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
 | 
			
		||||
            grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
 | 
			
		||||
 | 
			
		||||
        return grad_A, grad_weight, None
 | 
			
		||||
| 
						 | 
				
			
			@ -737,9 +757,16 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
                if x_2d.requires_grad:
 | 
			
		||||
                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
 | 
			
		||||
                else:
 | 
			
		||||
                    result = xe_linear.forward_new(x_2d, self.weight.data,
 | 
			
		||||
                                                   self.weight.qtype,
 | 
			
		||||
                                                   input_seq_size)
 | 
			
		||||
                    if self.weight.qtype == NF4:
 | 
			
		||||
                        result = xe_linear.forward_new(x_2d,
 | 
			
		||||
                                                       self.weight.data.view(torch.uint8),
 | 
			
		||||
                                                       self.weight.qtype,
 | 
			
		||||
                                                       input_seq_size)
 | 
			
		||||
                    else:
 | 
			
		||||
                        result = xe_linear.forward_new(x_2d,
 | 
			
		||||
                                                       self.weight.data,
 | 
			
		||||
                                                       self.weight.qtype,
 | 
			
		||||
                                                       input_seq_size)
 | 
			
		||||
            elif self.enable_xetla:
 | 
			
		||||
                x_2d = x_2d.half()
 | 
			
		||||
                result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -382,3 +382,16 @@ def check_hidden_size(qtype, hidden_size):
 | 
			
		|||
                        "required for fq6_k - using fallback quantization fp6.")
 | 
			
		||||
            return ggml_tensor_qtype["fp6"]
 | 
			
		||||
    return qtype
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Arc platfrom does not support FP64,
 | 
			
		||||
# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2  method
 | 
			
		||||
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1365
 | 
			
		||||
def _constant_buffered_norm2(self, input, buffer_size=250000000):
 | 
			
		||||
    norm = None
 | 
			
		||||
    for part in input.view(-1).split(buffer_size):
 | 
			
		||||
        if norm is None:
 | 
			
		||||
            norm = part.data.norm(2)**2.0
 | 
			
		||||
        else:
 | 
			
		||||
            norm += part.data.norm(2)**2.0
 | 
			
		||||
    return norm**0.5
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue