LLM: Support bitsandbytes config on qlora finetune (#9715)

* test support bitsandbytesconfig

* update style

* update cpu example

* update example

* update readme

* update unit test

* use bfloat16

* update logic

* use int4

* set default bnb_4bit_use_double_quant

* update

* update example

* update model.py

* update

* support lora example
Wang, Jian4 2024-01-04 11:23:16 +08:00 committed by GitHub
parent 9a14465560
commit 4ceefc9b18
10 changed files with 126 additions and 27 deletions
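
In short, this change lets the QLoRA/LoRA finetuning examples pass a standard `transformers` `BitsAndBytesConfig` to BigDL-LLM's `AutoModelForCausalLM.from_pretrained` instead of the BigDL-specific `load_in_low_bit` argument. Below is a minimal sketch of the new loading pattern, assembled from the diffs that follow; the model path is a placeholder:

```python
import torch
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,       # double quantization is not supported yet
    bnb_4bit_quant_type="nf4",             # use "int4" on CPU, where nf4 is not supported yet
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained("path/to/base_model",  # placeholder
                                             quantization_config=bnb_config)
```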

View file

@@ -304,5 +304,6 @@ jobs:
shell: bash
run: |
python -m pip install transformers==4.34.0 peft==0.5.0 accelerate==0.23.0
python -m pip install bitsandbytes scipy
source /opt/intel/oneapi/setvars.sh
bash python/llm/test/run-llm-example-tests-gpu.sh

View file

@@ -23,6 +23,7 @@ pip install transformers==4.34.0
pip install peft==0.5.0
pip install datasets
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Finetune model

View file

@@ -11,6 +11,7 @@ pip install --pre --upgrade bigdl-llm[all]
pip install datasets transformers==4.35.0
pip install fire peft==0.5.0
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure environment variables

View file

@@ -46,6 +46,7 @@ from peft import (
)
from utils.prompter import Prompter
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
@@ -177,13 +178,22 @@ def train(
)
else:
# Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="sym_int4", # not support "nf4"
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="int4", # nf4 not supported on cpu yet
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# below is also supported
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# load_in_low_bit="sym_int4", # nf4 not supported on cpu yet
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# modules_to_not_convert=["lm_head"],
# )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to("cpu")
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")

View file

@@ -20,6 +20,7 @@ import os
import transformers
from transformers import LlamaTokenizer
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from datasets import load_dataset
@@ -45,11 +46,24 @@ if __name__ == "__main__":
data['train'] = data['train'].map(merge)
# use max_length to reduce memory usage; adjust it for different datasets
data = data.map(lambda samples: tokenizer(samples["prediction"], max_length=256), batched=True)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="int4", # nf4 not supported on cpu yet
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit="sym_int4",
optimize_model=False,
torch_dtype=torch.float16,
modules_to_not_convert=["lm_head"], )
quantization_config=bnb_config, )
# below is also supported
# model = AutoModelForCausalLM.from_pretrained(model_path,
# # nf4 not supported on cpu yet
# load_in_low_bit="sym_int4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# modules_to_not_convert=["lm_head"], )
model = model.to('cpu')
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
model.enable_input_require_grads()

View file

@@ -20,6 +20,7 @@ pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-w
pip install datasets transformers==4.34.0
pip install peft==0.5.0
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure OneAPI environment variables

View file

@@ -17,6 +17,7 @@ pip install datasets transformers==4.34.0
pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.0.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure OneAPI environment variables

View file

@@ -47,6 +47,7 @@ from peft import (
from utils.prompter import Prompter
import intel_extension_for_pytorch as ipex
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
@@ -197,20 +198,36 @@ def train(
# According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
# Default 4-bit format for qa-lora is sym_int4
if training_mode == "qalora":
low_bit_format = "sym_int4"
low_bit_format = "int4"
elif training_mode == "lora":
low_bit_format = "bf16"
else:
low_bit_format = "nf4"
# Load the base model from a directory or the HF Hub to 4-bit format
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit=low_bit_format,
optimize_model=False,
torch_dtype=torch.bfloat16,
# device_map=device_map,
modules_to_not_convert=["lm_head"],
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type=low_bit_format,
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# below is also supported
# Load the base model from a directory or the HF Hub to 4-bit format
# if training_mode == "qalora":
# low_bit_format = "sym_int4"
# elif training_mode == "lora":
# low_bit_format = "bf16"
# else:
# low_bit_format = "nf4"
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# load_in_low_bit=low_bit_format,
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# # device_map=device_map,
# modules_to_not_convert=["lm_head"],
# )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")

View file

@@ -21,6 +21,7 @@ import transformers
from transformers import LlamaTokenizer
import intel_extension_for_pytorch as ipex
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM
from datasets import load_dataset
@@ -41,11 +42,22 @@ if __name__ == "__main__":
data = load_dataset(dataset_path)
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit="nf4",
optimize_model=False,
torch_dtype=torch.float16,
modules_to_not_convert=["lm_head"],)
quantization_config=bnb_config, )
# below is also supported
# model = AutoModelForCausalLM.from_pretrained(model_path,
# load_in_low_bit="nf4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# modules_to_not_convert=["lm_head"],)
model = model.to('xpu')
# Enable gradient_checkpointing if you do not have enough memory;
# it will slow down the training speed

View file

@@ -88,7 +88,6 @@ def save_low_bit(self, *args, **kwargs):
class _BaseAutoModelClass:
HF_MODEL = None
@classmethod
@@ -136,6 +135,48 @@ class _BaseAutoModelClass:
optimize_model = kwargs.pop("optimize_model", True)
user_quantization_config = kwargs.pop("quantization_config", None)
if user_quantization_config is not None and \
"BitsAndBytesConfig" in str(user_quantization_config.__class__):
if user_quantization_config.bnb_4bit_quant_type is not None:
bnb_4bit_type = user_quantization_config.bnb_4bit_quant_type
if bnb_4bit_type == "nf4":
load_in_low_bit = "nf4"
elif bnb_4bit_type == "fp4":
warnings.warn(
"BigDL LLM QLoRA does not support fp4 now, use default nf4", FutureWarning)
load_in_low_bit = "nf4"
elif bnb_4bit_type == "int4":
load_in_low_bit = "sym_int4"
elif bnb_4bit_type == "bf16":
load_in_low_bit = "bf16"
else:
invalidInputError(False,
"Only nf4 or int4 is supported for bnb_4bit_quant_type")
else:
warnings.warn(
"bnb_4bit_quant_type is None, use default int4", FutureWarning)
load_in_low_bit = "sym_int4"
if user_quantization_config.bnb_4bit_use_double_quant is True:
warnings.warn(
"BigDL LLM QLoRA does not support double quant now, set to False",
FutureWarning)
if user_quantization_config.bnb_4bit_compute_dtype is not None:
bnb_dtype = user_quantization_config.bnb_4bit_compute_dtype
if bnb_dtype == torch.float32:
kwargs["torch_dtype"] = bnb_dtype
elif bnb_dtype == torch.bfloat16:
kwargs["torch_dtype"] = bnb_dtype
else:
invalidInputError(False,
"Only float32 or bfloat16"
" is supported for bnb_4bit_compute_dtype")
else:
warnings.warn(
"torch_dtype is None, use default float32", FutureWarning)
kwargs["torch_dtype"] = torch.float32
optimize_model = False
kwargs["modules_to_not_convert"] = ["lm_head"]
if load_in_4bit or load_in_low_bit:
if config_dict.get("quantization_config", None) is not None:
@@ -253,9 +294,9 @@ class _BaseAutoModelClass:
# The latest transformers only supports the CUDA version
# This AWQ ckpt loading logic is copied from
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
from accelerate import init_empty_weights, infer_auto_device_map,\
from accelerate import init_empty_weights, infer_auto_device_map, \
load_checkpoint_in_model
from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers,\
from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers, \
get_layer_type, _load_config
awq_config = quant_config
model_weights_path, config = _load_config(args[0], '', max_new_tokens=None,
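
Stepping back from the diff: the new `BitsAndBytesConfig` handling added to `_BaseAutoModelClass.from_pretrained` (the large hunk above) reduces to a small translation onto BigDL-LLM's existing `load_in_low_bit` path. The sketch below is illustrative only; the dict and constant names are not part of the actual code:

```python
import torch

# How bnb_4bit_quant_type is mapped onto load_in_low_bit (hypothetical names):
BNB_QUANT_TYPE_TO_LOW_BIT = {
    "nf4": "nf4",
    "fp4": "nf4",        # fp4 is not supported; falls back to nf4 with a FutureWarning
    "int4": "sym_int4",
    "bf16": "bf16",
    None: "sym_int4",    # default when bnb_4bit_quant_type is unset
}

# bnb_4bit_compute_dtype must be float32 or bfloat16 and becomes torch_dtype;
# when unset it defaults to float32. bnb_4bit_use_double_quant=True only emits
# a warning -- double quantization is not applied. optimize_model is forced to
# False and lm_head is excluded from low-bit conversion.
SUPPORTED_COMPUTE_DTYPES = (torch.float32, torch.bfloat16)
```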