LLM: update qlora alpaca example to change lora usage (#9835)
* update example
* fix style
This commit is contained in:
parent 05b681fa85
commit 8504a2bbca
1 changed file with 21 additions and 13 deletions
@@ -196,10 +196,18 @@ def train(
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
         # Default 4-bit format for qa-lora is sym_int4
-        if training_mode == "qalora":
-            low_bit_format = "int4"
-        elif training_mode == "lora":
-            low_bit_format = "bf16"
-        else:
-            low_bit_format = "nf4"
-        bnb_config = BitsAndBytesConfig(
+        if training_mode == "lora":
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_low_bit="bf16",
+                optimize_model=False,
+                torch_dtype=torch.bfloat16,
+                modules_to_not_convert=["lm_head"],
+            )
+        else:
+            # use bnb_config for qlora/qalora/relora, which use 4bit for base model
+            if training_mode == "qalora":
+                low_bit_format = "int4"
+            else:
+                low_bit_format = "nf4"
+            bnb_config = BitsAndBytesConfig(
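The hunk cuts off at the opening of the BitsAndBytesConfig( call, so below is a sketch of how the whole model-loading branch plausibly reads after this patch. The imports and the BitsAndBytesConfig keyword arguments are assumptions based on the usual transformers/bigdl-llm QLoRA example pattern, not part of this diff; base_model and training_mode are parameters of the enclosing train() function named in the hunk header.

    import torch
    from transformers import BitsAndBytesConfig
    # Assumed import: the example's low-bit from_pretrained comes from bigdl-llm.
    from bigdl.llm.transformers import AutoModelForCausalLM

    if training_mode == "lora":
        # New in this commit: plain LoRA keeps the base model in bf16
        # instead of quantizing it.
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_low_bit="bf16",
            optimize_model=False,
            torch_dtype=torch.bfloat16,
            modules_to_not_convert=["lm_head"],  # leave the output head unconverted
        )
    else:
        # qlora/qalora/relora still quantize the base model to 4 bit.
        if training_mode == "qalora":
            low_bit_format = "int4"
        else:
            # Per the QLoRA paper, "nf4" can yield better quality than "int4".
            low_bit_format = "nf4"
        # Assumed keyword arguments: the diff truncates inside this call.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=low_bit_format,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
        )

Net effect of the change: running the example with training_mode="lora" now fine-tunes LoRA adapters against a bf16 base model and bypasses BitsAndBytesConfig entirely, while the 4-bit quantized path is reserved for qlora, qalora, and relora.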