Remove zero3 context manager from LoRA (#11346)

This commit is contained in:
Heyang Sun 2024-06-18 17:24:43 +08:00 committed by GitHub
parent f6cd628cd8
commit 67a1e05876
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -158,51 +158,28 @@ def train(
# Check if parameter passed or if set within environ
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
if deepspeed_zero3:
deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
if saved_low_bit_model is not None:
# Load the low bit optimized model if provide the saved path
if deepspeed_zero3:
import deepspeed as ds
with ds.zero.Init(config_dict_or_path=deepspeed):
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="bf16",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
if deepspeed_zero3:
deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
else:
if deepspeed_zero3:
import deepspeed as ds
with ds.zero.Init(config_dict_or_path=deepspeed):
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="bf16",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="bf16",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
)
if not deepspeed_zero3:
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")