Remove zero3 context manager from LoRA (#11346)
commit 67a1e05876
parent f6cd628cd8
1 changed file with 19 additions and 42 deletions
@@ -158,51 +158,28 @@ def train(
     # Check if parameter passed or if set within environ
     use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
 
-    if deepspeed_zero3:
-        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
-
     if saved_low_bit_model is not None:
         # Load the low bit optimized model if provide the saved path
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.load_low_bit(
-                    saved_low_bit_model,
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.load_low_bit(
-                saved_low_bit_model,
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.load_low_bit(
+            saved_low_bit_model,
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_low_bit="bf16",
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+
+    if deepspeed_zero3:
+        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
     else:
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.from_pretrained(
-                    base_model,
-                    load_in_low_bit="bf16",
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                load_in_low_bit="bf16",
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
-    if not deepspeed_zero3:
         print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
         model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
         print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
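With the wrapper gone, the script keeps deepspeed only as a config path, which suggests ZeRO-3 setup is deferred to the training framework. A hedged sketch of the usual pattern that makes a manual zero.Init redundant, assuming the path ends up in transformers' TrainingArguments (values other than the deepspeed path are illustrative):

    import transformers

    # Constructing TrainingArguments with a ZeRO-3 config registers it globally;
    # a from_pretrained call made afterwards can load directly into ZeRO-3
    # shards, and Trainer initializes DeepSpeed itself when training starts.
    args = transformers.TrainingArguments(
        output_dir="./lora-out",                    # illustrative
        bf16=True,
        deepspeed="./deepspeed_zero3_config.json",  # default path from this diff
    )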