Remove zero3 context manager from LoRA (#11346)
This commit is contained in:
parent
f6cd628cd8
commit
67a1e05876
1 changed files with 19 additions and 42 deletions
|
|
@ -158,51 +158,28 @@ def train(
|
|||
# Check if parameter passed or if set within environ
|
||||
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
|
||||
|
||||
if deepspeed_zero3:
|
||||
deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
|
||||
|
||||
if saved_low_bit_model is not None:
|
||||
# Load the low bit optimized model if provide the saved path
|
||||
if deepspeed_zero3:
|
||||
import deepspeed as ds
|
||||
with ds.zero.Init(config_dict_or_path=deepspeed):
|
||||
model = AutoModelForCausalLM.load_low_bit(
|
||||
saved_low_bit_model,
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.load_low_bit(
|
||||
saved_low_bit_model,
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model = AutoModelForCausalLM.load_low_bit(
|
||||
saved_low_bit_model,
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_low_bit="bf16",
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
if deepspeed_zero3:
|
||||
deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
|
||||
else:
|
||||
if deepspeed_zero3:
|
||||
import deepspeed as ds
|
||||
with ds.zero.Init(config_dict_or_path=deepspeed):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_low_bit="bf16",
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_low_bit="bf16",
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.bfloat16,
|
||||
modules_to_not_convert=["lm_head"],
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if not deepspeed_zero3:
|
||||
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
|
||||
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
|
||||
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue