LLM: update qlora alpaca example to change lora usage (#9835)

* update example

* fix style
Ruonan Wang 2024-01-04 15:22:20 +08:00 committed by GitHub
parent 05b681fa85
commit 8504a2bbca

@@ -196,20 +196,28 @@ def train(
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
         # Default 4-bit format for qa-lora is sym_int4
-        if training_mode == "qalora":
-            low_bit_format = "int4"
-        elif training_mode == "lora":
-            low_bit_format = "bf16"
+        if training_mode == "lora":
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_low_bit="bf16",
+                optimize_model=False,
+                torch_dtype=torch.bfloat16,
+                modules_to_not_convert=["lm_head"],
+            )
         else:
-            low_bit_format = "nf4"
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_quant_type=low_bit_format,
-            bnb_4bit_compute_dtype=torch.bfloat16
-        )
-        model = AutoModelForCausalLM.from_pretrained(base_model,
-                                                     quantization_config=bnb_config, )
+            # use bnb_config for qlora/qalora/relora, which use 4bit for base model
+            if training_mode == "qalora":
+                low_bit_format = "int4"
+            else:
+                low_bit_format = "nf4"
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=False,
+                bnb_4bit_quant_type=low_bit_format,
+                bnb_4bit_compute_dtype=torch.bfloat16
+            )
+            model = AutoModelForCausalLM.from_pretrained(base_model,
+                                                         quantization_config=bnb_config, )
 
         # below is also supported
         # Load the base model from a directory or the HF Hub to 4-bit format
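
For reference, the change above amounts to the following minimal sketch of the two loading paths: plain LoRA now loads the base model directly in bf16, while qlora/qalora/relora keep a 4-bit base model via BitsAndBytesConfig. The import lines and the load_base_model wrapper are assumptions for illustration only; the actual import locations are in the rest of the example file, which is not part of this hunk.

# Minimal sketch, not the full example; imports are assumed, adjust to the real file.
import torch
from transformers import BitsAndBytesConfig                 # assumed import
from bigdl.llm.transformers import AutoModelForCausalLM     # assumed import

def load_base_model(base_model: str, training_mode: str):   # hypothetical helper
    if training_mode == "lora":
        # Plain LoRA keeps the base model in bf16 instead of 4-bit
        return AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_low_bit="bf16",
            optimize_model=False,
            torch_dtype=torch.bfloat16,
            modules_to_not_convert=["lm_head"],
        )
    # qlora / qalora / relora use a 4-bit base model
    low_bit_format = "int4" if training_mode == "qalora" else "nf4"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type=low_bit_format,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return AutoModelForCausalLM.from_pretrained(base_model, quantization_config=bnb_config)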