LLM: Update for cpu qlora mpirun (#9548)
This commit is contained in:
parent
5f5ca38b74
commit
b824754256
1 changed files with 15 additions and 0 deletions
|
|
@ -52,6 +52,21 @@ from bigdl.llm.transformers import AutoModelForCausalLM
|
|||
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
def get_int_from_env(env_keys, default):
|
||||
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
||||
for e in env_keys:
|
||||
val = int(os.environ.get(e, -1))
|
||||
if val >= 0:
|
||||
return val
|
||||
return default
|
||||
|
||||
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
|
||||
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
|
||||
port = get_int_from_env(["MASTER_PORT"], 29500)
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["RANK"] = str(local_rank)
|
||||
os.environ["MASTER_PORT"] = str(port)
|
||||
|
||||
def train(
|
||||
# model/data params
|
||||
|
|
|
|||
Loading…
Reference in a new issue