diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
index 4c9b0421..0a9c41b8 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
@@ -196,20 +196,28 @@ def train(
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
         # Default 4-bit format for qa-lora is sym_int4
-        if training_mode == "qalora":
-            low_bit_format = "int4"
-        elif training_mode == "lora":
-            low_bit_format = "bf16"
+        if training_mode == "lora":
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_low_bit="bf16",
+                optimize_model=False,
+                torch_dtype=torch.bfloat16,
+                modules_to_not_convert=["lm_head"],
+            )
         else:
-            low_bit_format = "nf4"
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_quant_type=low_bit_format,
-            bnb_4bit_compute_dtype=torch.bfloat16
-        )
-        model = AutoModelForCausalLM.from_pretrained(base_model,
-                                                     quantization_config=bnb_config, )
+            # use bnb_config for qlora/qalora/relora, which use 4bit for base model
+            if training_mode == "qalora":
+                low_bit_format = "int4"
+            else:
+                low_bit_format = "nf4"
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=False,
+                bnb_4bit_quant_type=low_bit_format,
+                bnb_4bit_compute_dtype=torch.bfloat16
+            )
+            model = AutoModelForCausalLM.from_pretrained(base_model,
+                                                         quantization_config=bnb_config, )
         # below is also supported
         # Load the base model from a directory or the HF Hub to 4-bit format
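
For reference, a minimal sketch of the mode-to-format selection this hunk introduces, written as a standalone function; the helper name select_low_bit_format is hypothetical and not part of the patch:

def select_low_bit_format(training_mode: str) -> str:
    # Mapping taken from the hunk above: plain LoRA keeps the base model in bf16,
    # qa-lora defaults to sym_int4 ("int4"), and the remaining modes (qlora/relora)
    # load the base model in NF4, which the QLoRA paper reports can yield better
    # model quality than "int4".
    if training_mode == "lora":
        return "bf16"
    if training_mode == "qalora":
        return "int4"
    return "nf4"

assert select_low_bit_format("lora") == "bf16"
assert select_low_bit_format("qalora") == "int4"
assert select_low_bit_format("relora") == "nf4"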