From 4ceefc9b18576f582099d61b22c6369815bab045 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Thu, 4 Jan 2024 11:23:16 +0800
Subject: [PATCH] LLM: Support bitsandbytes config on qlora finetune (#9715)

* test support bitsandbytesconfig
* update style
* update cpu example
* update example
* update readme
* update unit test
* use bfloat16
* update logic
* use int4
* set default bnb_4bit_use_double_quant
* update
* update example
* update model.py
* update
* support lora example
---
 .github/workflows/llm_unit_tests.yml          |  1 +
 .../example/CPU/QLoRA-FineTuning/README.md    |  1 +
 .../QLoRA-FineTuning/alpaca-qlora/README.md   |  1 +
 .../alpaca_qlora_finetuning_cpu.py            | 22 ++++---
 .../QLoRA-FineTuning/qlora_finetuning_cpu.py  | 22 +++++++--
 .../example/GPU/QLoRA-FineTuning/README.md    |  1 +
 .../QLoRA-FineTuning/alpaca-qlora/README.md   |  1 +
 .../alpaca-qlora/alpaca_qlora_finetuning.py   | 35 +++++++++----
 .../GPU/QLoRA-FineTuning/qlora_finetuning.py  | 20 ++++++--
 .../llm/src/bigdl/llm/transformers/model.py   | 49 +++++++++++++++++--
 10 files changed, 126 insertions(+), 27 deletions(-)
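
Usage note: after this change, `bigdl.llm.transformers.AutoModelForCausalLM.from_pretrained` accepts a standard `transformers.BitsAndBytesConfig` via `quantization_config` and maps it onto BigDL-LLM's low-bit formats (nf4, sym_int4, bf16). Below is a minimal sketch of that path, mirroring the updated examples in this patch; the model id and target device are placeholders, and `nf4` assumes an XPU run since the CPU examples fall back to `int4`.

```python
# Minimal sketch of the new quantization_config path (placeholders marked below).
import torch
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,   # double quant is not supported; BigDL-LLM warns and ignores it
    bnb_4bit_quant_type="nf4",         # maps to BigDL-LLM nf4; "int4" maps to sym_int4 (use it on CPU)
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder model id
    quantization_config=bnb_config,
)
model = model.to("xpu")                # or "cpu" for the CPU examples
```
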
diff --git a/.github/workflows/llm_unit_tests.yml b/.github/workflows/llm_unit_tests.yml
index 8a648822..b0905442 100644
--- a/.github/workflows/llm_unit_tests.yml
+++ b/.github/workflows/llm_unit_tests.yml
@@ -304,5 +304,6 @@ jobs:
         shell: bash
         run: |
           python -m pip install transformers==4.34.0 peft==0.5.0 accelerate==0.23.0
+          python -m pip install bitsandbytes scipy
           source /opt/intel/oneapi/setvars.sh
           bash python/llm/test/run-llm-example-tests-gpu.sh
\ No newline at end of file
diff --git a/python/llm/example/CPU/QLoRA-FineTuning/README.md b/python/llm/example/CPU/QLoRA-FineTuning/README.md
index 37c34b14..01de8ba1 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/README.md
+++ b/python/llm/example/CPU/QLoRA-FineTuning/README.md
@@ -23,6 +23,7 @@ pip install transformers==4.34.0
 pip install peft==0.5.0
 pip install datasets
 pip install accelerate==0.23.0
+pip install bitsandbytes scipy
 ```
 
 ### 2. Finetune model
diff --git a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/README.md
index d604c660..3aa4f47a 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/README.md
+++ b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/README.md
@@ -11,6 +11,7 @@ pip install --pre --upgrade bigdl-llm[all]
 pip install datasets transformers==4.35.0
 pip install fire peft==0.5.0
 pip install accelerate==0.23.0
+pip install bitsandbytes scipy
 ```
 
 ### 2. Configures environment variables
diff --git a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
index 93ef8196..3e3fd34d 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
+++ b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
@@ -46,6 +46,7 @@ from peft import (
 )
 from utils.prompter import Prompter
 
+from transformers import BitsAndBytesConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
@@ -177,13 +178,22 @@ def train(
         )
     else:
         # Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_low_bit="sym_int4",  # not support "nf4"
-            optimize_model=False,
-            torch_dtype=torch.bfloat16,
-            modules_to_not_convert=["lm_head"],
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=False,
+            bnb_4bit_quant_type="int4",  # nf4 not supported on cpu yet
+            bnb_4bit_compute_dtype=torch.bfloat16
         )
+        model = AutoModelForCausalLM.from_pretrained(base_model,
+                                                     quantization_config=bnb_config, )
+        # below is also supported
+        # model = AutoModelForCausalLM.from_pretrained(
+        #     base_model,
+        #     load_in_low_bit="sym_int4",  # nf4 not supported on cpu yet
+        #     optimize_model=False,
+        #     torch_dtype=torch.bfloat16,
+        #     modules_to_not_convert=["lm_head"],
+        # )
         print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
         model = model.to("cpu")
         print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
diff --git a/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py b/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
index b6e1adc5..d0d47152 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
+++ b/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
@@ -20,6 +20,7 @@ import os
 import transformers
 from transformers import LlamaTokenizer
+from transformers import BitsAndBytesConfig
 from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, LoraConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
@@ -45,11 +46,24 @@ if __name__ == "__main__":
     data['train'] = data['train'].map(merge)
     # use the max_length to reduce memory usage, should be adjusted by different datasets
     data = data.map(lambda samples: tokenizer(samples["prediction"], max_length=256), batched=True)
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_type="int4",  # nf4 not supported on cpu yet
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 load_in_low_bit="sym_int4",
-                                                 optimize_model=False,
-                                                 torch_dtype=torch.float16,
-                                                 modules_to_not_convert=["lm_head"], )
+                                                 quantization_config=bnb_config, )
+
+    # below is also supported
+    # model = AutoModelForCausalLM.from_pretrained(model_path,
+    #                                              # nf4 not supported on cpu yet
+    #                                              load_in_low_bit="sym_int4",
+    #                                              optimize_model=False,
+    #                                              torch_dtype=torch.bfloat16,
+    #                                              modules_to_not_convert=["lm_head"], )
+
     model = model.to('cpu')
     model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
     model.enable_input_require_grads()
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/README.md b/python/llm/example/GPU/QLoRA-FineTuning/README.md
index 14667b57..28777b52 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/README.md
+++ b/python/llm/example/GPU/QLoRA-FineTuning/README.md
@@ -20,6 +20,7 @@ pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-w
 pip install datasets transformers==4.34.0
 pip install peft==0.5.0
 pip install accelerate==0.23.0
+pip install bitsandbytes scipy
 ```
 
 ### 2. Configures OneAPI environment variables
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
index 29b81e82..8cb6bd63 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
@@ -17,6 +17,7 @@ pip install datasets transformers==4.34.0
 pip install fire peft==0.5.0
 pip install oneccl_bind_pt==2.0.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
 pip install accelerate==0.23.0
+pip install bitsandbytes scipy
 ```
 
 ### 2. Configures OneAPI environment variables
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
index 36108cf0..cfed03fc 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
@@ -47,6 +47,7 @@ from peft import (
 )
 from utils.prompter import Prompter
 import intel_extension_for_pytorch as ipex
+from transformers import BitsAndBytesConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
 from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
@@ -197,20 +198,36 @@ def train(
     # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
     # Default 4-bit format for qa-lora is sym_int4
     if training_mode == "qalora":
-        low_bit_format = "sym_int4"
+        low_bit_format = "int4"
     elif training_mode == "lora":
         low_bit_format = "bf16"
     else:
         low_bit_format = "nf4"
-    # Load the base model from a directory or the HF Hub to 4-bit format
-    model = AutoModelForCausalLM.from_pretrained(
-        base_model,
-        load_in_low_bit=low_bit_format,
-        optimize_model=False,
-        torch_dtype=torch.bfloat16,
-        # device_map=device_map,
-        modules_to_not_convert=["lm_head"],
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_type=low_bit_format,
+        bnb_4bit_compute_dtype=torch.bfloat16
     )
+    model = AutoModelForCausalLM.from_pretrained(base_model,
+                                                 quantization_config=bnb_config, )
+
+    # below is also supported
+    # Load the base model from a directory or the HF Hub to 4-bit format
+    # if training_mode == "qalora":
+    #     low_bit_format = "sym_int4"
+    # elif training_mode == "lora":
+    #     low_bit_format = "bf16"
+    # else:
+    #     low_bit_format = "nf4"
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     base_model,
+    #     load_in_low_bit=low_bit_format,
+    #     optimize_model=False,
+    #     torch_dtype=torch.bfloat16,
+    #     # device_map=device_map,
+    #     modules_to_not_convert=["lm_head"],
+    # )
     print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
     model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
     print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
index 21dbeaad..bab936c5 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
@@ -21,6 +21,7 @@ import transformers
 from transformers import LlamaTokenizer
 import intel_extension_for_pytorch as ipex
 from peft import LoraConfig
+from transformers import BitsAndBytesConfig
 from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
@@ -41,11 +42,22 @@ if __name__ == "__main__":
 
     data = load_dataset(dataset_path)
     data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 load_in_low_bit="nf4",
-                                                 optimize_model=False,
-                                                 torch_dtype=torch.float16,
-                                                 modules_to_not_convert=["lm_head"],)
+                                                 quantization_config=bnb_config, )
+
+    # below is also supported
+    # model = AutoModelForCausalLM.from_pretrained(model_path,
+    #                                              load_in_low_bit="nf4",
+    #                                              optimize_model=False,
+    #                                              torch_dtype=torch.bfloat16,
+    #                                              modules_to_not_convert=["lm_head"],)
     model = model.to('xpu')
     # Enable gradient_checkpointing if your memory is not enough,
     # it will slowdown the training speed
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index bb5ce960..e84776ba 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -88,7 +88,6 @@ def save_low_bit(self, *args, **kwargs):
 
 
 class _BaseAutoModelClass:
-    HF_MODEL = None
 
     @classmethod
@@ -136,6 +135,48 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
 
+        if user_quantization_config is not None and \
+                "BitsAndBytesConfig" in str(user_quantization_config.__class__):
+            if user_quantization_config.bnb_4bit_quant_type is not None:
+                bnb_4bit_type = user_quantization_config.bnb_4bit_quant_type
+                if bnb_4bit_type == "nf4":
+                    load_in_low_bit = "nf4"
+                elif bnb_4bit_type == "fp4":
+                    warnings.warn(
+                        "BigDL LLM QLoRA does not support fp4 now, use default nf4", FutureWarning)
+                    load_in_low_bit = "nf4"
+                elif bnb_4bit_type == "int4":
+                    load_in_low_bit = "sym_int4"
+                elif bnb_4bit_type == "bf16":
+                    load_in_low_bit = "bf16"
+                else:
+                    invalidInputError(False,
+                                      "Only nf4, int4 or bf16 is supported for bnb_4bit_quant_type")
+            else:
+                warnings.warn(
+                    "bnb_4bit_quant_type is None, use default int4", FutureWarning)
+                load_in_low_bit = "sym_int4"
+            if user_quantization_config.bnb_4bit_use_double_quant is True:
+                warnings.warn(
+                    "BigDL LLM QLoRA does not support double quant now, set to False",
+                    FutureWarning)
+            if user_quantization_config.bnb_4bit_compute_dtype is not None:
+                bnb_dtype = user_quantization_config.bnb_4bit_compute_dtype
+                if bnb_dtype == torch.float32:
+                    kwargs["torch_dtype"] = bnb_dtype
+                elif bnb_dtype == torch.bfloat16:
+                    kwargs["torch_dtype"] = bnb_dtype
+                else:
+                    invalidInputError(False,
+                                      "Only float32 or bfloat16"
+                                      " is supported for bnb_4bit_compute_dtype")
+            else:
+                warnings.warn(
+                    "bnb_4bit_compute_dtype is None, use default float32", FutureWarning)
+                kwargs["torch_dtype"] = torch.float32
+            optimize_model = False
+            kwargs["modules_to_not_convert"] = ["lm_head"]
+
         if load_in_4bit or load_in_low_bit:
             if config_dict.get("quantization_config", None) is not None:
@@ -253,9 +294,9 @@ class _BaseAutoModelClass:
             # The latest transformers only support cuda version
             # This load awq ckpt logic is copied from
             # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
-            from accelerate import init_empty_weights, infer_auto_device_map,\
+            from accelerate import init_empty_weights, infer_auto_device_map, \
                 load_checkpoint_in_model
-            from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers,\
+            from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers, \
                 get_layer_type, _load_config
             awq_config = quant_config
             model_weights_path, config = _load_config(args[0], '', max_new_tokens=None,
@@ -397,7 +438,7 @@ class _BaseAutoModelClass:
         if has_remote_code and trust_remote_code:
             class_ref = config.auto_map[cls.HF_Model.__name__]
             model_class = get_class_from_dynamic_module(
-                    class_ref, pretrained_model_name_or_path, **kwargs
+                class_ref, pretrained_model_name_or_path, **kwargs
             )
             if os.path.isdir(pretrained_model_name_or_path):
                 model_class.register_for_auto_class(cls.HF_Model.__name__)
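
For completeness, the example scripts touched above go on to wrap the quantized model with BigDL-LLM's QLoRA-compatible PEFT helpers. A short sketch of that follow-up step is below, assuming a `model` loaded with `quantization_config` as in the examples; the LoRA hyperparameters and `target_modules` names are illustrative assumptions, not values taken from this patch.

```python
# Sketch: continues from a model loaded with quantization_config as in the examples above.
# r, lora_alpha, lora_dropout and target_modules below are illustrative assumptions.
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, LoraConfig

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
model.enable_input_require_grads()

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],  # assumed Llama-style projection names
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # standard PEFT method on the returned PeftModel
```
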